diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 8173bee58e8e24..7c632f8a34d56a 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -76,7 +76,7 @@ jobs: - name: Run Workflow run: poetry run -C api bash dev/pytest/pytest_workflow.sh - - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale) + - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch) uses: hoverkraft-tech/compose-action@v2.0.0 with: compose-file: | @@ -90,5 +90,6 @@ jobs: pgvecto-rs pgvector chroma + elasticsearch - name: Test Vector Stores run: poetry run -C api bash dev/pytest/pytest_vdb.sh diff --git a/.github/workflows/expose_service_ports.sh b/.github/workflows/expose_service_ports.sh index 3418bf0c6f6688..ae3e0ee69d8cfb 100755 --- a/.github/workflows/expose_service_ports.sh +++ b/.github/workflows/expose_service_ports.sh @@ -6,5 +6,6 @@ yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml yq eval '.services.pgvector.ports += ["5433:5432"]' -i docker/docker-compose.yaml yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compose.yaml +yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml -echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs." \ No newline at end of file +echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch" \ No newline at end of file diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index f6092c86337d85..d681dc66276dd1 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -45,6 +45,10 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example + - name: Ruff formatter check + if: steps.changed-files.outputs.any_changed == 'true' + run: poetry run -C api ruff format --check ./api + - name: Lint hints if: failure() run: echo "Please run 'dev/reformat' to fix the fixable linting errors." diff --git a/.gitignore b/.gitignore index c52b9d8bbf5bca..22aba6536479e8 100644 --- a/.gitignore +++ b/.gitignore @@ -178,3 +178,4 @@ pyrightconfig.json api/.vscode .idea/ +.vscode \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f810584f24115c..8f57cd545ee308 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -8,7 +8,7 @@ In terms of licensing, please take a minute to read our short [License and Contr ## Before you jump in -[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types: +[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types: ### Feature requests: diff --git a/CONTRIBUTING_CN.md b/CONTRIBUTING_CN.md index 303c2513f53b9b..7cd2bb60ebca04 100644 --- a/CONTRIBUTING_CN.md +++ b/CONTRIBUTING_CN.md @@ -8,7 +8,7 @@ ## 在开始之前 -[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:closed)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类: +[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:open)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类: ### 功能请求: diff --git a/CONTRIBUTING_JA.md b/CONTRIBUTING_JA.md index 6d5bfb205c31dc..a68bdeddbc830f 100644 --- a/CONTRIBUTING_JA.md +++ b/CONTRIBUTING_JA.md @@ -10,7 +10,7 @@ Dify にコントリビュートしたいとお考えなのですね。それは ## 飛び込む前に -[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。 +[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:open) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。 ### 機能リクエスト diff --git a/CONTRIBUTING_VI.md b/CONTRIBUTING_VI.md new file mode 100644 index 00000000000000..80e68a046ec5fc --- /dev/null +++ b/CONTRIBUTING_VI.md @@ -0,0 +1,156 @@ +Thật tuyệt vời khi bạn muốn đóng góp cho Dify! Chúng tôi rất mong chờ được thấy những gì bạn sẽ làm. Là một startup với nguồn nhân lực và tài chính hạn chế, chúng tôi có tham vọng lớn là thiết kế quy trình trực quan nhất để xây dựng và quản lý các ứng dụng LLM. Mọi sự giúp đỡ từ cộng đồng đều rất quý giá đối với chúng tôi. + +Chúng tôi cần linh hoạt và làm việc nhanh chóng, nhưng đồng thời cũng muốn đảm bảo các cộng tác viên như bạn có trải nghiệm đóng góp thuận lợi nhất có thể. Chúng tôi đã tạo ra hướng dẫn đóng góp này nhằm giúp bạn làm quen với codebase và cách chúng tôi làm việc với các cộng tác viên, để bạn có thể nhanh chóng bắt tay vào phần thú vị. + +Hướng dẫn này, cũng như bản thân Dify, đang trong quá trình cải tiến liên tục. Chúng tôi rất cảm kích sự thông cảm của bạn nếu đôi khi nó không theo kịp dự án thực tế, và chúng tôi luôn hoan nghênh mọi phản hồi để cải thiện. + +Về vấn đề cấp phép, xin vui lòng dành chút thời gian đọc qua [Thỏa thuận Cấp phép và Đóng góp](./LICENSE) ngắn gọn của chúng tôi. Cộng đồng cũng tuân thủ [quy tắc ứng xử](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md). + +## Trước khi bắt đầu + +[Tìm kiếm](https://github.com/langgenius/dify/issues?q=is:issue+is:open) một vấn đề hiện có, hoặc [tạo mới](https://github.com/langgenius/dify/issues/new/choose) một vấn đề. Chúng tôi phân loại các vấn đề thành 2 loại: + +### Yêu cầu tính năng: + +* Nếu bạn đang tạo một yêu cầu tính năng mới, chúng tôi muốn bạn giải thích tính năng đề xuất sẽ đạt được điều gì và cung cấp càng nhiều thông tin chi tiết càng tốt. [@perzeusss](https://github.com/perzeuss) đã tạo một [Trợ lý Yêu cầu Tính năng](https://udify.app/chat/MK2kVSnw1gakVwMX) rất hữu ích để giúp bạn soạn thảo nhu cầu của mình. Hãy thử dùng nó nhé. + +* Nếu bạn muốn chọn một vấn đề từ danh sách hiện có, chỉ cần để lại bình luận dưới vấn đề đó nói rằng bạn sẽ làm. + + Một thành viên trong nhóm làm việc trong lĩnh vực liên quan sẽ được thông báo. Nếu mọi thứ ổn, họ sẽ cho phép bạn bắt đầu code. Chúng tôi yêu cầu bạn chờ đợi cho đến lúc đó trước khi bắt tay vào làm tính năng, để không lãng phí công sức của bạn nếu chúng tôi đề xuất thay đổi. + + Tùy thuộc vào lĩnh vực mà tính năng đề xuất thuộc về, bạn có thể nói chuyện với các thành viên khác nhau trong nhóm. Dưới đây là danh sách các lĩnh vực mà các thành viên trong nhóm chúng tôi đang làm việc hiện tại: + + | Thành viên | Phạm vi | + | ------------------------------------------------------------ | ---------------------------------------------------- | + | [@yeuoly](https://github.com/Yeuoly) | Thiết kế kiến trúc Agents | + | [@jyong](https://github.com/JohnJyong) | Thiết kế quy trình RAG | + | [@GarfieldDai](https://github.com/GarfieldDai) | Xây dựng quy trình làm việc | + | [@iamjoel](https://github.com/iamjoel) & [@zxhlyh](https://github.com/zxhlyh) | Làm cho giao diện người dùng dễ sử dụng | + | [@guchenhe](https://github.com/guchenhe) & [@crazywoola](https://github.com/crazywoola) | Trải nghiệm nhà phát triển, đầu mối liên hệ cho mọi vấn đề | + | [@takatost](https://github.com/takatost) | Định hướng và kiến trúc tổng thể sản phẩm | + + Cách chúng tôi ưu tiên: + + | Loại tính năng | Mức độ ưu tiên | + | ------------------------------------------------------------ | -------------- | + | Tính năng ưu tiên cao được gắn nhãn bởi thành viên trong nhóm | Ưu tiên cao | + | Yêu cầu tính năng phổ biến từ [bảng phản hồi cộng đồng](https://github.com/langgenius/dify/discussions/categories/feedbacks) của chúng tôi | Ưu tiên trung bình | + | Tính năng không quan trọng và cải tiến nhỏ | Ưu tiên thấp | + | Có giá trị nhưng không cấp bách | Tính năng tương lai | + +### Những vấn đề khác (ví dụ: báo cáo lỗi, tối ưu hiệu suất, sửa lỗi chính tả): + +* Bắt đầu code ngay lập tức. + + Cách chúng tôi ưu tiên: + + | Loại vấn đề | Mức độ ưu tiên | + | ------------------------------------------------------------ | -------------- | + | Lỗi trong các chức năng chính (không thể đăng nhập, ứng dụng không hoạt động, lỗ hổng bảo mật) | Nghiêm trọng | + | Lỗi không quan trọng, cải thiện hiệu suất | Ưu tiên trung bình | + | Sửa lỗi nhỏ (lỗi chính tả, giao diện người dùng gây nhầm lẫn nhưng vẫn hoạt động) | Ưu tiên thấp | + + +## Cài đặt + +Dưới đây là các bước để thiết lập Dify cho việc phát triển: + +### 1. Fork repository này + +### 2. Clone repository + + Clone repository đã fork từ terminal của bạn: + +``` +git clone git@github.com:/dify.git +``` + +### 3. Kiểm tra các phụ thuộc + +Dify yêu cầu các phụ thuộc sau để build, hãy đảm bảo chúng đã được cài đặt trên hệ thống của bạn: + +- [Docker](https://www.docker.com/) +- [Docker Compose](https://docs.docker.com/compose/install/) +- [Node.js v18.x (LTS)](http://nodejs.org) +- [npm](https://www.npmjs.com/) phiên bản 8.x.x hoặc [Yarn](https://yarnpkg.com/) +- [Python](https://www.python.org/) phiên bản 3.10.x + +### 4. Cài đặt + +Dify bao gồm một backend và một frontend. Đi đến thư mục backend bằng lệnh `cd api/`, sau đó làm theo hướng dẫn trong [README của Backend](api/README.md) để cài đặt. Trong một terminal khác, đi đến thư mục frontend bằng lệnh `cd web/`, sau đó làm theo hướng dẫn trong [README của Frontend](web/README.md) để cài đặt. + +Kiểm tra [FAQ về cài đặt](https://docs.dify.ai/learn-more/faq/self-host-faq) để xem danh sách các vấn đề thường gặp và các bước khắc phục. + +### 5. Truy cập Dify trong trình duyệt của bạn + +Để xác nhận cài đặt của bạn, hãy truy cập [http://localhost:3000](http://localhost:3000) (địa chỉ mặc định, hoặc URL và cổng bạn đã cấu hình) trong trình duyệt. Bạn sẽ thấy Dify đang chạy. + +## Phát triển + +Nếu bạn đang thêm một nhà cung cấp mô hình, [hướng dẫn này](https://github.com/langgenius/dify/blob/main/api/core/model_runtime/README.md) dành cho bạn. + +Nếu bạn đang thêm một nhà cung cấp công cụ cho Agent hoặc Workflow, [hướng dẫn này](./api/core/tools/README.md) dành cho bạn. + +Để giúp bạn nhanh chóng định hướng phần đóng góp của mình, dưới đây là một bản phác thảo ngắn gọn về cấu trúc backend & frontend của Dify: + +### Backend + +Backend của Dify được viết bằng Python sử dụng [Flask](https://flask.palletsprojects.com/en/3.0.x/). Nó sử dụng [SQLAlchemy](https://www.sqlalchemy.org/) cho ORM và [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) cho hàng đợi tác vụ. Logic xác thực được thực hiện thông qua Flask-login. + +``` +[api/] +├── constants // Các cài đặt hằng số được sử dụng trong toàn bộ codebase. +├── controllers // Định nghĩa các route API và logic xử lý yêu cầu. +├── core // Điều phối ứng dụng cốt lõi, tích hợp mô hình và công cụ. +├── docker // Cấu hình liên quan đến Docker & containerization. +├── events // Xử lý và xử lý sự kiện +├── extensions // Mở rộng với các framework/nền tảng bên thứ 3. +├── fields // Định nghĩa trường cho serialization/marshalling. +├── libs // Thư viện và tiện ích có thể tái sử dụng. +├── migrations // Script cho việc di chuyển cơ sở dữ liệu. +├── models // Mô hình cơ sở dữ liệu & định nghĩa schema. +├── services // Xác định logic nghiệp vụ. +├── storage // Lưu trữ khóa riêng tư. +├── tasks // Xử lý các tác vụ bất đồng bộ và công việc nền. +└── tests +``` + +### Frontend + +Website được khởi tạo trên boilerplate [Next.js](https://nextjs.org/) bằng Typescript và sử dụng [Tailwind CSS](https://tailwindcss.com/) cho styling. [React-i18next](https://react.i18next.com/) được sử dụng cho việc quốc tế hóa. + +``` +[web/] +├── app // layouts, pages và components +│ ├── (commonLayout) // layout chung được sử dụng trong toàn bộ ứng dụng +│ ├── (shareLayout) // layouts được chia sẻ cụ thể cho các phiên dựa trên token +│ ├── activate // trang kích hoạt +│ ├── components // được chia sẻ bởi các trang và layouts +│ ├── install // trang cài đặt +│ ├── signin // trang đăng nhập +│ └── styles // styles được chia sẻ toàn cục +├── assets // Tài nguyên tĩnh +├── bin // scripts chạy ở bước build +├── config // cài đặt và tùy chọn có thể điều chỉnh +├── context // contexts được chia sẻ bởi các phần khác nhau của ứng dụng +├── dictionaries // File dịch cho từng ngôn ngữ +├── docker // cấu hình container +├── hooks // Hooks có thể tái sử dụng +├── i18n // Cấu hình quốc tế hóa +├── models // mô tả các mô hình dữ liệu & hình dạng của phản hồi API +├── public // tài nguyên meta như favicon +├── service // xác định hình dạng của các hành động API +├── test +├── types // mô tả các tham số hàm và giá trị trả về +└── utils // Các hàm tiện ích được chia sẻ +``` + +## Gửi PR của bạn + +Cuối cùng, đã đến lúc mở một pull request (PR) đến repository của chúng tôi. Đối với các tính năng lớn, chúng tôi sẽ merge chúng vào nhánh `deploy/dev` để kiểm tra trước khi đưa vào nhánh `main`. Nếu bạn gặp vấn đề như xung đột merge hoặc không biết cách mở pull request, hãy xem [hướng dẫn về pull request của GitHub](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests). + +Và thế là xong! Khi PR của bạn được merge, bạn sẽ được giới thiệu là một người đóng góp trong [README](https://github.com/langgenius/dify/blob/main/README.md) của chúng tôi. + +## Nhận trợ giúp + +Nếu bạn gặp khó khăn hoặc có câu hỏi cấp bách trong quá trình đóng góp, hãy đặt câu hỏi của bạn trong vấn đề GitHub liên quan, hoặc tham gia [Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi để trò chuyện nhanh chóng. \ No newline at end of file diff --git a/README.md b/README.md index 0dd0d862549afc..1c49c415fe09a9 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ Dify Cloud · Self-hosting · Documentation · - Enterprise inquiry + Enterprise inquiry

@@ -38,6 +38,7 @@ README in Korean README بالعربية Türkçe README + README Tiếng Việt

@@ -65,7 +66,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com Extensive RAG capabilities that cover everything from document ingestion to retrieval, with out-of-box support for text extraction from PDFs, PPTs, and other common document formats. **5. Agent capabilities**: - You can define agents based on LLM Function Calling or ReAct, and add pre-built or custom tools for the agent. Dify provides 50+ built-in tools for AI agents, such as Google Search, DELL·E, Stable Diffusion and WolframAlpha. + You can define agents based on LLM Function Calling or ReAct, and add pre-built or custom tools for the agent. Dify provides 50+ built-in tools for AI agents, such as Google Search, DALL·E, Stable Diffusion and WolframAlpha. **6. LLMOps**: Monitor and analyze application logs and performance over time. You could continuously improve prompts, datasets, and models based on production data and annotations. @@ -151,7 +152,7 @@ Quickly get Dify running in your environment with this [starter guide](#quick-st Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions. - **Dify for enterprise / organizations
** -We provide additional enterprise-centric features. [Schedule a meeting with us](https://cal.com/guchenhe/30min) or [send us an email](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) to discuss enterprise needs.
+We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) to discuss enterprise needs.
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one-click. It's an affordable AMI offering with the option to create apps with custom logo and branding. @@ -220,23 +221,6 @@ At the same time, please consider supporting Dify by sharing it on social media * [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community. * [Twitter](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community. -Or, schedule a meeting directly with a team member: - - - - - - - - - - - - - - -
Point of ContactPurpose
Git-Hub-README-Button-3xBusiness enquiries & product feedback
Git-Hub-README-Button-2xContributions, issues & feature requests
- ## Star history [![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) diff --git a/README_AR.md b/README_AR.md index 25229ef460dd24..10d572cc49a83b 100644 --- a/README_AR.md +++ b/README_AR.md @@ -4,7 +4,7 @@ Dify Cloud · الاستضافة الذاتية · التوثيق · - استفسارات الشركات + استفسار الشركات (للإنجليزية فقط)

@@ -38,6 +38,7 @@ README in Korean README بالعربية Türkçe README + README Tiếng Việt

@@ -57,7 +58,7 @@ **4. خط أنابيب RAG**: قدرات RAG الواسعة التي تغطي كل شيء من استيعاب الوثائق إلى الاسترجاع، مع الدعم الفوري لاستخراج النص من ملفات PDF و PPT وتنسيقات الوثائق الشائعة الأخرى. -**5. قدرات الوكيل**: يمكنك تعريف الوكلاء بناءً على أمر وظيفة LLM أو ReAct، وإضافة أدوات مدمجة أو مخصصة للوكيل. توفر Dify أكثر من 50 أداة مدمجة لوكلاء الذكاء الاصطناعي، مثل البحث في Google و DELL·E وStable Diffusion و WolframAlpha. +**5. قدرات الوكيل**: يمكنك تعريف الوكلاء بناءً على أمر وظيفة LLM أو ReAct، وإضافة أدوات مدمجة أو مخصصة للوكيل. توفر Dify أكثر من 50 أداة مدمجة لوكلاء الذكاء الاصطناعي، مثل البحث في Google و DALL·E وStable Diffusion و WolframAlpha. **6. الـ LLMOps**: راقب وتحلل سجلات التطبيق والأداء على مر الزمن. يمكنك تحسين الأوامر والبيانات والنماذج باستمرار استنادًا إلى البيانات الإنتاجية والتعليقات. @@ -203,23 +204,6 @@ docker compose up -d * [Discord](https://discord.gg/FngNHpbcY7). الأفضل لـ: مشاركة تطبيقاتك والترفيه مع المجتمع. * [تويتر](https://twitter.com/dify_ai). الأفضل لـ: مشاركة تطبيقاتك والترفيه مع المجتمع. -أو، قم بجدولة اجتماع مباشرة مع أحد أعضاء الفريق: - - - - - - - - - - - - - - -
نقطة الاتصالالغرض
Git-Hub-README-Button-3xاستفسارات الأعمال واقتراحات حول المنتج
Git-Hub-README-Button-2xالمساهمات والمشكلات وطلبات الميزات
- ## تاريخ النجمة [![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) diff --git a/README_CN.md b/README_CN.md index ebbbe2590209dc..32551fcc313932 100644 --- a/README_CN.md +++ b/README_CN.md @@ -4,7 +4,7 @@ Dify 云服务 · 自托管 · 文档 · - 预约演示 + (需用英文)常见问题解答 / 联系团队

@@ -29,14 +29,16 @@

- 上个月的提交次数 - 上个月的提交次数 - 上个月的提交次数 - 上个月的提交次数 - 上个月的提交次数 - 上个月的提交次数 - 上个月的提交次数 + README in English + 简体中文版自述文件 + 日本語のREADME + README en Español + README en Français + README tlhIngan Hol + README in Korean + README بالعربية Türkçe README + README Tiếng Việt
@@ -70,7 +72,7 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI 广泛的 RAG 功能,涵盖从文档摄入到检索的所有内容,支持从 PDF、PPT 和其他常见文档格式中提取文本的开箱即用的支持。 **5. Agent 智能体**: - 您可以基于 LLM 函数调用或 ReAct 定义 Agent,并为 Agent 添加预构建或自定义工具。Dify 为 AI Agent 提供了50多种内置工具,如谷歌搜索、DELL·E、Stable Diffusion 和 WolframAlpha 等。 + 您可以基于 LLM 函数调用或 ReAct 定义 Agent,并为 Agent 添加预构建或自定义工具。Dify 为 AI Agent 提供了50多种内置工具,如谷歌搜索、DALL·E、Stable Diffusion 和 WolframAlpha 等。 **6. LLMOps**: 随时间监视和分析应用程序日志和性能。您可以根据生产数据和标注持续改进提示、数据集和模型。 @@ -156,7 +158,7 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI 使用我们的[文档](https://docs.dify.ai)进行进一步的参考和更深入的说明。 - **面向企业/组织的 Dify
** -我们提供额外的面向企业的功能。[与我们安排会议](https://cal.com/guchenhe/30min)或[给我们发送电子邮件](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)讨论企业需求。
+我们提供额外的面向企业的功能。[给我们发送电子邮件](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)讨论企业需求。
> 对于使用 AWS 的初创公司和中小型企业,请查看 [AWS Marketplace 上的 Dify 高级版](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6),并使用一键部署到您自己的 AWS VPC。它是一个价格实惠的 AMI 产品,提供了使用自定义徽标和品牌创建应用程序的选项。 ## 保持领先 diff --git a/README_ES.md b/README_ES.md index 52b78da47d0f40..2ae044b32883b9 100644 --- a/README_ES.md +++ b/README_ES.md @@ -4,7 +4,7 @@ Dify Cloud · Auto-alojamiento · Documentación · - Programar demostración + Consultas empresariales (en inglés)

@@ -29,14 +29,16 @@

- Actividad de Commits el último mes - Actividad de Commits el último mes - Actividad de Commits el último mes - Actividad de Commits el último mes - Actividad de Commits el último mes - Actividad de Commits el último mes - Actividad de Commits el último mes + README in English + 简体中文版自述文件 + 日本語のREADME + README en Español + README en Français + README tlhIngan Hol + README in Korean + README بالعربية Türkçe README + README Tiếng Việt

# @@ -70,7 +72,7 @@ Dify es una plataforma de desarrollo de aplicaciones de LLM de código abierto. **5. Capacidades de agente**: Puedes definir agent -es basados en LLM Function Calling o ReAct, y agregar herramientas preconstruidas o personalizadas para el agente. Dify proporciona más de 50 herramientas integradas para agentes de IA, como Búsqueda de Google, DELL·E, Difusión Estable y WolframAlpha. +es basados en LLM Function Calling o ReAct, y agregar herramientas preconstruidas o personalizadas para el agente. Dify proporciona más de 50 herramientas integradas para agentes de IA, como Búsqueda de Google, DALL·E, Difusión Estable y WolframAlpha. **6. LLMOps**: Supervisa y analiza registros de aplicaciones y rendimiento a lo largo del tiempo. Podrías mejorar continuamente prompts, conjuntos de datos y modelos basados en datos de producción y anotaciones. @@ -156,7 +158,7 @@ Pon rápidamente Dify en funcionamiento en tu entorno con esta [guía de inicio Usa nuestra [documentación](https://docs.dify.ai) para más referencias e instrucciones más detalladas. - **Dify para Empresas / Organizaciones
** -Proporcionamos características adicionales centradas en la empresa. [Programa una reunión con nosotros](https://cal.com/guchenhe/30min) o [envíanos un correo electrónico](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) para discutir las necesidades empresariales.
+Proporcionamos características adicionales centradas en la empresa. [Envíanos un correo electrónico](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) para discutir las necesidades empresariales.
> Para startups y pequeñas empresas que utilizan AWS, echa un vistazo a [Dify Premium en AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) e impleméntalo en tu propio VPC de AWS con un clic. Es una AMI asequible que ofrece la opción de crear aplicaciones con logotipo y marca personalizados. @@ -228,23 +230,6 @@ Al mismo tiempo, considera apoyar a Dify compartiéndolo en redes sociales y en * [Discord](https://discord.gg/FngNHpbcY7). Lo mejor para: compartir tus aplicaciones y pasar el rato con la comunidad. * [Twitter](https://twitter.com/dify_ai). Lo mejor para: compartir tus aplicaciones y pasar el rato con la comunidad. -O, programa una reunión directamente con un miembro del equipo: - - - - - - - - - - - - - - -
Punto de ContactoPropósito
Git-Hub-README-Button-3xConsultas comerciales y retroalimentación del producto
Git-Hub-README-Button-2xContribuciones, problemas y solicitudes de características
- ## Historial de Estrellas [![Gráfico de Historial de Estrellas](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) @@ -256,4 +241,4 @@ Para proteger tu privacidad, evita publicar problemas de seguridad en GitHub. En ## Licencia -Este repositorio está disponible bajo la [Licencia de Código Abierto de Dify](LICENSE), que es esencialmente Apache 2.0 con algunas restricciones adicionales. \ No newline at end of file +Este repositorio está disponible bajo la [Licencia de Código Abierto de Dify](LICENSE), que es esencialmente Apache 2.0 con algunas restricciones adicionales. diff --git a/README_FR.md b/README_FR.md index 17a08812848443..681d596749c9e7 100644 --- a/README_FR.md +++ b/README_FR.md @@ -4,7 +4,7 @@ Dify Cloud · Auto-hébergement · Documentation · - Planifier une démo + Demande d’entreprise (en anglais seulement)

@@ -29,14 +29,16 @@

- Commits le mois dernier - Commits le mois dernier - Commits le mois dernier - Commits le mois dernier - Commits le mois dernier - Commits le mois dernier - Commits le mois dernier + README in English + 简体中文版自述文件 + 日本語のREADME + README en Español + README en Français + README tlhIngan Hol + README in Korean + README بالعربية Türkçe README + README Tiếng Việt

# @@ -70,7 +72,7 @@ Dify est une plateforme de développement d'applications LLM open source. Son in **5. Capac ités d'agent**: - Vous pouvez définir des agents basés sur l'appel de fonction LLM ou ReAct, et ajouter des outils pré-construits ou personnalisés pour l'agent. Dify fournit plus de 50 outils intégrés pour les agents d'IA, tels que la recherche Google, DELL·E, Stable Diffusion et WolframAlpha. + Vous pouvez définir des agents basés sur l'appel de fonction LLM ou ReAct, et ajouter des outils pré-construits ou personnalisés pour l'agent. Dify fournit plus de 50 outils intégrés pour les agents d'IA, tels que la recherche Google, DALL·E, Stable Diffusion et WolframAlpha. **6. LLMOps**: Surveillez et analysez les journaux d'application et les performances au fil du temps. Vous pouvez continuellement améliorer les prompts, les ensembles de données et les modèles en fonction des données de production et des annotations. @@ -156,7 +158,7 @@ Lancez rapidement Dify dans votre environnement avec ce [guide de démarrage](#q Utilisez notre [documentation](https://docs.dify.ai) pour plus de références et des instructions plus détaillées. - **Dify pour les entreprises / organisations
** -Nous proposons des fonctionnalités supplémentaires adaptées aux entreprises. [Planifiez une réunion avec nous](https://cal.com/guchenhe/30min) ou [envoyez-nous un e-mail](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) pour discuter des besoins de l'entreprise.
+Nous proposons des fonctionnalités supplémentaires adaptées aux entreprises. [Envoyez-nous un e-mail](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) pour discuter des besoins de l'entreprise.
> Pour les startups et les petites entreprises utilisant AWS, consultez [Dify Premium sur AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) et déployez-le dans votre propre VPC AWS en un clic. C'est une offre AMI abordable avec la possibilité de créer des applications avec un logo et une marque personnalisés. @@ -226,23 +228,6 @@ Dans le même temps, veuillez envisager de soutenir Dify en le partageant sur le * [Discord](https://discord.gg/FngNHpbcY7). Meilleur pour: partager vos applications et passer du temps avec la communauté. * [Twitter](https://twitter.com/dify_ai). Meilleur pour: partager vos applications et passer du temps avec la communauté. -Ou, planifiez directement une réunion avec un membre de l'équipe: - - - - - - - - - - - - - - -
Point de contactObjectif
Git-Hub-README-Button-3xDemandes commerciales & retours produit
Git-Hub-README-Button-2xContributions, problèmes & demandes de fonctionnalités
- ## Historique des étoiles [![Graphique de l'historique des étoiles](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) diff --git a/README_JA.md b/README_JA.md index 5828379a74cf0f..e6a8621e7baae5 100644 --- a/README_JA.md +++ b/README_JA.md @@ -4,7 +4,7 @@ Dify Cloud · セルフホスティング · ドキュメント · - デモの予約 + 企業のお問い合わせ(英語のみ)

@@ -29,14 +29,16 @@

- 先月のコミット - 先月のコミット - 先月のコミット - 先月のコミット - 先月のコミット - 先月のコミット - 先月のコミット + README in English + 简体中文版自述文件 + 日本語のREADME + README en Español + README en Français + README tlhIngan Hol + README in Korean + README بالعربية Türkçe README + README Tiếng Việt

# @@ -69,7 +71,7 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ ドキュメントの取り込みから検索までをカバーする広範なRAG機能ができます。ほかにもPDF、PPT、その他の一般的なドキュメントフォーマットからのテキスト抽出のサーポイントも提供します。 **5. エージェント機能**: - LLM Function CallingやReActに基づくエージェントの定義が可能で、AIエージェント用のプリビルトまたはカスタムツールを追加できます。Difyには、Google検索、DELL·E、Stable Diffusion、WolframAlphaなどのAIエージェント用の50以上の組み込みツールが提供します。 + LLM Function CallingやReActに基づくエージェントの定義が可能で、AIエージェント用のプリビルトまたはカスタムツールを追加できます。Difyには、Google検索、DALL·E、Stable Diffusion、WolframAlphaなどのAIエージェント用の50以上の組み込みツールが提供します。 **6. LLMOps**: アプリケーションのログやパフォーマンスを監視と分析し、生産のデータと注釈に基づいて、プロンプト、データセット、モデルを継続的に改善できます。 @@ -155,7 +157,7 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ 詳しくは[ドキュメント](https://docs.dify.ai)をご覧ください。 - **企業/組織向けのDify
** -企業中心の機能を提供しています。[こちらからミーティングを予約](https://cal.com/guchenhe/30min)したり、[メールを送信](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)して企業のニーズについて相談してください。
+企業中心の機能を提供しています。[メールを送信](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)して企業のニーズについて相談してください。
> AWSを使用しているスタートアップ企業や中小企業の場合は、[AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6)のDify Premiumをチェックして、ワンクリックで自分のAWS VPCにデプロイできます。さらに、手頃な価格のAMIオファリングどして、ロゴやブランディングをカスタマイズしてアプリケーションを作成するオプションがあります。 @@ -225,28 +227,6 @@ docker compose up -d * [Discord](https://discord.gg/FngNHpbcY7). 主に: アプリケーションの共有やコミュニティとの交流。 * [Twitter](https://twitter.com/dify_ai). 主に: アプリケーションの共有やコミュニティとの交流。 -または、直接チームメンバーとミーティングをスケジュール: - - - - - - - - - - - - - - - - - - -
連絡先目的
ミーティング無料の30分間のミーティングをスケジュール
技術サポート技術的な問題やサポートに関する質問
営業担当法人ライセンスに関するお問い合わせ
## ライセンス diff --git a/README_KL.md b/README_KL.md index 64d2d2485886fb..04620d42bbec8a 100644 --- a/README_KL.md +++ b/README_KL.md @@ -4,7 +4,7 @@ Dify Cloud · Self-hosting · Documentation · - Schedule demo + Commercial enquiries

@@ -29,14 +29,16 @@

- Commits last month - Commits last month - Commits last month - Commits last month - Commits last month - Commits last month - Commits last month + README in English + 简体中文版自述文件 + 日本語のREADME + README en Español + README en Français + README tlhIngan Hol + README in Korean + README بالعربية Türkçe README + README Tiếng Việt

# @@ -68,7 +70,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com Extensive RAG capabilities that cover everything from document ingestion to retrieval, with out-of-box support for text extraction from PDFs, PPTs, and other common document formats. **5. Agent capabilities**: - You can define agents based on LLM Function Calling or ReAct, and add pre-built or custom tools for the agent. Dify provides 50+ built-in tools for AI agents, such as Google Search, DELL·E, Stable Diffusion and WolframAlpha. + You can define agents based on LLM Function Calling or ReAct, and add pre-built or custom tools for the agent. Dify provides 50+ built-in tools for AI agents, such as Google Search, DALL·E, Stable Diffusion and WolframAlpha. **6. LLMOps**: Monitor and analyze application logs and performance over time. You could continuously improve prompts, datasets, and models based on production data and annotations. @@ -156,7 +158,7 @@ Quickly get Dify running in your environment with this [starter guide](#quick-st Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions. - **Dify for Enterprise / Organizations
** -We provide additional enterprise-centric features. [Schedule a meeting with us](https://cal.com/guchenhe/30min) or [send us an email](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) to discuss enterprise needs.
+We provide additional enterprise-centric features. [Send us an email](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) to discuss enterprise needs.
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one-click. It's an affordable AMI offering with the option to create apps with custom logo and branding. @@ -228,23 +230,6 @@ At the same time, please consider supporting Dify by sharing it on social media * [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community. * [Twitter](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community. -Or, schedule a meeting directly with a team member: - - - - - - - - - - - - - - -
Point of ContactPurpose
Git-Hub-README-Button-3xBusiness enquiries & product feedback
Git-Hub-README-Button-2xContributions, issues & feature requests
- ## Star History [![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) @@ -256,4 +241,4 @@ To protect your privacy, please avoid posting security issues on GitHub. Instead ## License -This repository is available under the [Dify Open Source License](LICENSE), which is essentially Apache 2.0 with a few additional restrictions. \ No newline at end of file +This repository is available under the [Dify Open Source License](LICENSE), which is essentially Apache 2.0 with a few additional restrictions. diff --git a/README_KR.md b/README_KR.md index 2d7db3df4cac67..a5f3bc68d04d74 100644 --- a/README_KR.md +++ b/README_KR.md @@ -4,7 +4,7 @@ Dify 클라우드 · 셀프-호스팅 · 문서 · - 기업 문의 + 기업 문의 (영어만 가능)

@@ -35,8 +35,10 @@ README en Español README en Français README tlhIngan Hol - 한국어 README + README in Korean + README بالعربية Türkçe README + README Tiếng Việt

@@ -64,7 +66,7 @@ 문서 수집부터 검색까지 모든 것을 다루며, PDF, PPT 및 기타 일반적인 문서 형식에서 텍스트 추출을 위한 기본 지원이 포함되어 있는 광범위한 RAG 기능을 제공합니다. **5. 에이전트 기능**: - LLM 함수 호출 또는 ReAct를 기반으로 에이전트를 정의하고 에이전트에 대해 사전 구축된 도구나 사용자 정의 도구를 추가할 수 있습니다. Dify는 Google Search, DELL·E, Stable Diffusion, WolframAlpha 등 AI 에이전트를 위한 50개 이상의 내장 도구를 제공합니다. + LLM 함수 호출 또는 ReAct를 기반으로 에이전트를 정의하고 에이전트에 대해 사전 구축된 도구나 사용자 정의 도구를 추가할 수 있습니다. Dify는 Google Search, DALL·E, Stable Diffusion, WolframAlpha 등 AI 에이전트를 위한 50개 이상의 내장 도구를 제공합니다. **6. LLMOps**: 시간 경과에 따른 애플리케이션 로그와 성능을 모니터링하고 분석합니다. 생산 데이터와 주석을 기반으로 프롬프트, 데이터세트, 모델을 지속적으로 개선할 수 있습니다. @@ -149,7 +151,7 @@ 추가 참조 및 더 심층적인 지침은 [문서](https://docs.dify.ai)를 사용하세요. - **기업 / 조직을 위한 Dify
** - 우리는 추가적인 기업 중심 기능을 제공합니다. 당사와 [미팅일정](https://cal.com/guchenhe/30min)을 잡거나 [이메일 보내기](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)를 통해 기업 요구 사항을 논의하십시오.
+ 우리는 추가적인 기업 중심 기능을 제공합니다. 잡거나 [이메일 보내기](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)를 통해 기업 요구 사항을 논의하십시오.
> AWS를 사용하는 스타트업 및 중소기업의 경우 [AWS Marketplace에서 Dify Premium](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6)을 확인하고 한 번의 클릭으로 자체 AWS VPC에 배포하십시오. 맞춤형 로고와 브랜딩이 포함된 앱을 생성할 수 있는 옵션이 포함된 저렴한 AMI 제품입니다. @@ -218,22 +220,6 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했 * [디스코드](https://discord.gg/FngNHpbcY7). 애플리케이션 공유 및 커뮤니티와 소통하기에 적합합니다. * [트위터](https://twitter.com/dify_ai). 애플리케이션 공유 및 커뮤니티와 소통하기에 적합합니다. -또는 팀원과 직접 미팅을 예약하세요: - - - - - - - - - - - - - - -
연락처목적
Git-Hub-README-Button-3x비즈니스 문의 및 제품 피드백
Git-Hub-README-Button-2x기여, 이슈 및 기능 요청
## Star 히스토리 diff --git a/README_TR.md b/README_TR.md index 2ae7d440a81373..54b6db3f823717 100644 --- a/README_TR.md +++ b/README_TR.md @@ -4,7 +4,7 @@ Dify Bulut · Kendi Sunucunuzda Barındırma · Dokümantasyon · - Kurumsal Sorgu + Yalnızca İngilizce: Kurumsal Sorgulama

@@ -38,6 +38,7 @@ README in Korean README بالعربية Türkçe README + README Tiếng Việt

@@ -155,7 +156,7 @@ Bu [başlangıç kılavuzu](#quick-start) ile Dify'ı kendi ortamınızda hızl Daha fazla referans ve detaylı talimatlar için [dokümantasyonumuzu](https://docs.dify.ai) kullanın. - **Kurumlar / organizasyonlar için Dify
** -Ek kurumsal odaklı özellikler sunuyoruz. Kurumsal ihtiyaçları görüşmek için [bizimle bir toplantı planlayın](https://cal.com/guchenhe/30min) veya [bize bir e-posta gönderin](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry).
+Ek kurumsal odaklı özellikler sunuyoruz. Kurumsal ihtiyaçları görüşmek için [bize bir e-posta gönderin](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry).
> AWS kullanan startuplar ve küçük işletmeler için, [AWS Marketplace'deki Dify Premium'a](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) göz atın ve tek tıklamayla kendi AWS VPC'nize dağıtın. Bu, özel logo ve marka ile uygulamalar oluşturma seçeneğine sahip uygun fiyatlı bir AMI teklifdir. ## Güncel Kalma @@ -223,23 +224,6 @@ Aynı zamanda, lütfen Dify'ı sosyal medyada, etkinliklerde ve konferanslarda p * [Discord](https://discord.gg/FngNHpbcY7). En uygun: uygulamalarınızı paylaşmak ve toplulukla vakit geçirmek için. * [Twitter](https://twitter.com/dify_ai). En uygun: uygulamalarınızı paylaşmak ve toplulukla vakit geçirmek için. -Veya doğrudan bir ekip üyesiyle toplantı planlayın: - - - - - - - - - - - - - - -
İletişim NoktasıAmaç
Git-Hub-README-Button-3xİş sorgulamaları & ürün geri bildirimleri
Git-Hub-README-Button-2xKatkılar, sorunlar & özellik istekleri
- ## Star history [![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) diff --git a/README_VI.md b/README_VI.md new file mode 100644 index 00000000000000..6d4035eceb06de --- /dev/null +++ b/README_VI.md @@ -0,0 +1,234 @@ +![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) + +

+ Dify Cloud · + Tự triển khai · + Tài liệu · + Yêu cầu doanh nghiệp +

+ +

+ + Static Badge + + Static Badge + + chat trên Discord + + theo dõi trên Twitter + + Docker Pulls + + Commits tháng trước + + Vấn đề đã đóng + + Bài thảo luận +

+ +

+ README in English + 简体中文版自述文件 + 日本語のREADME + README en Español + README en Français + README tlhIngan Hol + README in Korean + README بالعربية + Türkçe README + README Tiếng Việt +

+ + +Dify là một nền tảng phát triển ứng dụng LLM mã nguồn mở. Giao diện trực quan kết hợp quy trình làm việc AI, mô hình RAG, khả năng tác nhân, quản lý mô hình, tính năng quan sát và hơn thế nữa, cho phép bạn nhanh chóng chuyển từ nguyên mẫu sang sản phẩm. Đây là danh sách các tính năng cốt lõi: +

+ +**1. Quy trình làm việc**: + Xây dựng và kiểm tra các quy trình làm việc AI mạnh mẽ trên một canvas trực quan, tận dụng tất cả các tính năng sau đây và hơn thế nữa. + + + https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa + + + +**2. Hỗ trợ mô hình toàn diện**: + Tích hợp liền mạch với hàng trăm mô hình LLM độc quyền / mã nguồn mở từ hàng chục nhà cung cấp suy luận và giải pháp tự lưu trữ, bao gồm GPT, Mistral, Llama3, và bất kỳ mô hình tương thích API OpenAI nào. Danh sách đầy đủ các nhà cung cấp mô hình được hỗ trợ có thể được tìm thấy [tại đây](https://docs.dify.ai/getting-started/readme/model-providers). + +![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) + + +**3. IDE Prompt**: + Giao diện trực quan để tạo prompt, so sánh hiệu suất mô hình và thêm các tính năng bổ sung như chuyển văn bản thành giọng nói cho một ứng dụng dựa trên trò chuyện. + +**4. Mô hình RAG**: + Khả năng RAG mở rộng bao gồm mọi thứ từ nhập tài liệu đến truy xuất, với hỗ trợ sẵn có cho việc trích xuất văn bản từ PDF, PPT và các định dạng tài liệu phổ biến khác. + +**5. Khả năng tác nhân**: + Bạn có thể định nghĩa các tác nhân dựa trên LLM Function Calling hoặc ReAct, và thêm các công cụ được xây dựng sẵn hoặc tùy chỉnh cho tác nhân. Dify cung cấp hơn 50 công cụ tích hợp sẵn cho các tác nhân AI, như Google Search, DALL·E, Stable Diffusion và WolframAlpha. + +**6. LLMOps**: + Giám sát và phân tích nhật ký và hiệu suất ứng dụng theo thời gian. Bạn có thể liên tục cải thiện prompt, bộ dữ liệu và mô hình dựa trên dữ liệu sản xuất và chú thích. + +**7. Backend-as-a-Service**: + Tất cả các dịch vụ của Dify đều đi kèm với các API tương ứng, vì vậy bạn có thể dễ dàng tích hợp Dify vào logic kinh doanh của riêng mình. + + +## So sánh tính năng + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Tính năngDify.AILangChainFlowiseOpenAI Assistants API
Phương pháp lập trìnhHướng API + Ứng dụngMã PythonHướng ứng dụngHướng API
LLMs được hỗ trợĐa dạng phong phúĐa dạng phong phúĐa dạng phong phúChỉ OpenAI
RAG Engine
Agent
Quy trình làm việc
Khả năng quan sát
Tính năng doanh nghiệp (SSO/Kiểm soát truy cập)
Triển khai cục bộ
+ +## Sử dụng Dify + +- **Cloud
** +Chúng tôi lưu trữ dịch vụ [Dify Cloud](https://dify.ai) cho bất kỳ ai muốn thử mà không cần cài đặt. Nó cung cấp tất cả các khả năng của phiên bản tự triển khai và bao gồm 200 lượt gọi GPT-4 miễn phí trong gói sandbox. + +- **Tự triển khai Dify Community Edition
** +Nhanh chóng chạy Dify trong môi trường của bạn với [hướng dẫn bắt đầu](#quick-start) này. +Sử dụng [tài liệu](https://docs.dify.ai) của chúng tôi để tham khảo thêm và nhận hướng dẫn chi tiết hơn. + +- **Dify cho doanh nghiệp / tổ chức
** +Chúng tôi cung cấp các tính năng bổ sung tập trung vào doanh nghiệp. [Ghi lại câu hỏi của bạn cho chúng tôi thông qua chatbot này](https://udify.app/chat/22L1zSxg6yW1cWQg) hoặc [gửi email cho chúng tôi](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) để thảo luận về nhu cầu doanh nghiệp.
+ > Đối với các công ty khởi nghiệp và doanh nghiệp nhỏ sử dụng AWS, hãy xem [Dify Premium trên AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) và triển khai nó vào AWS VPC của riêng bạn chỉ với một cú nhấp chuột. Đây là một AMI giá cả phải chăng với tùy chọn tạo ứng dụng với logo và thương hiệu tùy chỉnh. + + +## Luôn cập nhật + +Yêu thích Dify trên GitHub và được thông báo ngay lập tức về các bản phát hành mới. + +![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) + + + +## Bắt đầu nhanh +> Trước khi cài đặt Dify, hãy đảm bảo máy của bạn đáp ứng các yêu cầu hệ thống tối thiểu sau: +> +>- CPU >= 2 Core +>- RAM >= 4GB + +
+ +Cách dễ nhất để khởi động máy chủ Dify là chạy tệp [docker-compose.yml](docker/docker-compose.yaml) của chúng tôi. Trước khi chạy lệnh cài đặt, hãy đảm bảo rằng [Docker](https://docs.docker.com/get-docker/) và [Docker Compose](https://docs.docker.com/compose/install/) đã được cài đặt trên máy của bạn: + +```bash +cd docker +cp .env.example .env +docker compose up -d +``` + +Sau khi chạy, bạn có thể truy cập bảng điều khiển Dify trong trình duyệt của bạn tại [http://localhost/install](http://localhost/install) và bắt đầu quá trình khởi tạo. + +> Nếu bạn muốn đóng góp cho Dify hoặc phát triển thêm, hãy tham khảo [hướng dẫn triển khai từ mã nguồn](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code) của chúng tôi + +## Các bước tiếp theo + +Nếu bạn cần tùy chỉnh cấu hình, vui lòng tham khảo các nhận xét trong tệp [.env.example](docker/.env.example) của chúng tôi và cập nhật các giá trị tương ứng trong tệp `.env` của bạn. Ngoài ra, bạn có thể cần điều chỉnh tệp `docker-compose.yaml`, chẳng hạn như thay đổi phiên bản hình ảnh, ánh xạ cổng hoặc gắn kết khối lượng, dựa trên môi trường triển khai cụ thể và yêu cầu của bạn. Sau khi thực hiện bất kỳ thay đổi nào, vui lòng chạy lại `docker-compose up -d`. Bạn có thể tìm thấy danh sách đầy đủ các biến môi trường có sẵn [tại đây](https://docs.dify.ai/getting-started/install-self-hosted/environments). + +Nếu bạn muốn cấu hình một cài đặt có độ sẵn sàng cao, có các [Helm Charts](https://helm.sh/) và tệp YAML do cộng đồng đóng góp cho phép Dify được triển khai trên Kubernetes. + +- [Helm Chart bởi @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) +- [Helm Chart bởi @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) +- [Tệp YAML bởi @Winson-030](https://github.com/Winson-030/dify-kubernetes) + +#### Sử dụng Terraform để Triển khai + +##### Azure Global +Triển khai Dify lên Azure chỉ với một cú nhấp chuột bằng cách sử dụng [terraform](https://www.terraform.io/). +- [Azure Terraform bởi @nikawang](https://github.com/nikawang/dify-azure-terraform) + +## Đóng góp + +Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi. +Đồng thời, vui lòng xem xét hỗ trợ Dify bằng cách chia sẻ nó trên mạng xã hội và tại các sự kiện và hội nghị. + + +> Chúng tôi đang tìm kiếm người đóng góp để giúp dịch Dify sang các ngôn ngữ khác ngoài tiếng Trung hoặc tiếng Anh. Nếu bạn quan tâm đến việc giúp đỡ, vui lòng xem [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) để biết thêm thông tin và để lại bình luận cho chúng tôi trong kênh `global-users` của [Máy chủ Cộng đồng Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi. + +**Người đóng góp** + + + + + +## Cộng đồng & liên hệ + +* [Thảo luận GitHub](https://github.com/langgenius/dify/discussions). Tốt nhất cho: chia sẻ phản hồi và đặt câu hỏi. +* [Vấn đề GitHub](https://github.com/langgenius/dify/issues). Tốt nhất cho: lỗi bạn gặp phải khi sử dụng Dify.AI và đề xuất tính năng. Xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi. +* [Discord](https://discord.gg/FngNHpbcY7). Tốt nhất cho: chia sẻ ứng dụng của bạn và giao lưu với cộng đồng. +* [Twitter](https://twitter.com/dify_ai). Tốt nhất cho: chia sẻ ứng dụng của bạn và giao lưu với cộng đồng. + +## Lịch sử Yêu thích + +[![Biểu đồ Lịch sử Yêu thích](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) + +## Tiết lộ bảo mật + +Để bảo vệ quyền riêng tư của bạn, vui lòng tránh đăng các vấn đề bảo mật trên GitHub. Thay vào đó, hãy gửi câu hỏi của bạn đến security@dify.ai và chúng tôi sẽ cung cấp cho bạn câu trả lời chi tiết hơn. + +## Giấy phép + +Kho lưu trữ này có sẵn theo [Giấy phép Mã nguồn Mở Dify](LICENSE), về cơ bản là Apache 2.0 với một vài hạn chế bổ sung. \ No newline at end of file diff --git a/api/.env.example b/api/.env.example index cf3a0f302d60cc..e41e2271d52d24 100644 --- a/api/.env.example +++ b/api/.env.example @@ -60,7 +60,8 @@ ALIYUN_OSS_SECRET_KEY=your-secret-key ALIYUN_OSS_ENDPOINT=your-endpoint ALIYUN_OSS_AUTH_VERSION=v1 ALIYUN_OSS_REGION=your-region - +# Don't start with '/'. OSS doesn't support leading slash in object names. +ALIYUN_OSS_PATH=your-path # Google Storage configuration GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string @@ -130,6 +131,12 @@ TENCENT_VECTOR_DB_DATABASE=dify TENCENT_VECTOR_DB_SHARD=1 TENCENT_VECTOR_DB_REPLICAS=2 +# ElasticSearch configuration +ELASTICSEARCH_HOST=127.0.0.1 +ELASTICSEARCH_PORT=9200 +ELASTICSEARCH_USERNAME=elastic +ELASTICSEARCH_PASSWORD=elastic + # PGVECTO_RS configuration PGVECTO_RS_HOST=localhost PGVECTO_RS_PORT=5431 @@ -241,8 +248,8 @@ API_TOOL_DEFAULT_READ_TIMEOUT=60 HTTP_REQUEST_MAX_CONNECT_TIMEOUT=300 HTTP_REQUEST_MAX_READ_TIMEOUT=600 HTTP_REQUEST_MAX_WRITE_TIMEOUT=600 -HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 # 10MB -HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 # 1MB +HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 +HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 # Log file path LOG_FILE= @@ -261,4 +268,13 @@ APP_MAX_ACTIVE_REQUESTS=0 # Celery beat configuration -CELERY_BEAT_SCHEDULER_TIME=1 \ No newline at end of file +CELERY_BEAT_SCHEDULER_TIME=1 + +# Position configuration +POSITION_TOOL_PINS= +POSITION_TOOL_INCLUDES= +POSITION_TOOL_EXCLUDES= + +POSITION_PROVIDER_PINS= +POSITION_PROVIDER_INCLUDES= +POSITION_PROVIDER_EXCLUDES= diff --git a/.idea/icon.png b/api/.idea/icon.png similarity index 100% rename from .idea/icon.png rename to api/.idea/icon.png diff --git a/.idea/vcs.xml b/api/.idea/vcs.xml similarity index 88% rename from .idea/vcs.xml rename to api/.idea/vcs.xml index ae8b1755c52a17..b7af618884ac3b 100644 --- a/.idea/vcs.xml +++ b/api/.idea/vcs.xml @@ -12,5 +12,6 @@ + - \ No newline at end of file + diff --git a/.vscode/launch.json b/api/.vscode/launch.json.example similarity index 83% rename from .vscode/launch.json rename to api/.vscode/launch.json.example index e4eb6aef932faf..e9f8e42dd5341c 100644 --- a/.vscode/launch.json +++ b/api/.vscode/launch.json.example @@ -5,8 +5,8 @@ "name": "Python: Flask", "type": "debugpy", "request": "launch", - "python": "${workspaceFolder}/api/.venv/bin/python", - "cwd": "${workspaceFolder}/api", + "python": "${workspaceFolder}/.venv/bin/python", + "cwd": "${workspaceFolder}", "envFile": ".env", "module": "flask", "justMyCode": true, @@ -18,15 +18,15 @@ "args": [ "run", "--host=0.0.0.0", - "--port=5001", + "--port=5001" ] }, { "name": "Python: Celery", "type": "debugpy", "request": "launch", - "python": "${workspaceFolder}/api/.venv/bin/python", - "cwd": "${workspaceFolder}/api", + "python": "${workspaceFolder}/.venv/bin/python", + "cwd": "${workspaceFolder}", "module": "celery", "justMyCode": true, "envFile": ".env", diff --git a/api/Dockerfile b/api/Dockerfile index 55776f80e136c8..82b89ad77ba1ac 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -5,6 +5,10 @@ WORKDIR /app/api # Install Poetry ENV POETRY_VERSION=1.8.3 + +# if you located in China, you can use aliyun mirror to speed up +# RUN pip install --no-cache-dir poetry==${POETRY_VERSION} -i https://mirrors.aliyun.com/pypi/simple/ + RUN pip install --no-cache-dir poetry==${POETRY_VERSION} # Configure Poetry @@ -12,9 +16,13 @@ ENV POETRY_CACHE_DIR=/tmp/poetry_cache ENV POETRY_NO_INTERACTION=1 ENV POETRY_VIRTUALENVS_IN_PROJECT=true ENV POETRY_VIRTUALENVS_CREATE=true +ENV POETRY_REQUESTS_TIMEOUT=15 FROM base AS packages +# if you located in China, you can use aliyun mirror to speed up +# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources + RUN apt-get update \ && apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev @@ -41,8 +49,14 @@ ENV TZ=UTC WORKDIR /app/api RUN apt-get update \ - && apt-get install -y --no-install-recommends curl wget vim nodejs ffmpeg libgmp-dev libmpfr-dev libmpc-dev \ - && apt-get autoremove \ + && apt-get install -y --no-install-recommends curl nodejs libgmp-dev libmpfr-dev libmpc-dev \ + # if you located in China, you can use aliyun mirror to speed up + # && echo "deb http://mirrors.aliyun.com/debian testing main" > /etc/apt/sources.list \ + && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \ + && apt-get update \ + # For Security + && apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \ + && apt-get autoremove -y \ && rm -rf /var/lib/apt/lists/* # Copy Python environment and packages @@ -50,6 +64,9 @@ ENV VIRTUAL_ENV=/app/api/.venv COPY --from=packages ${VIRTUAL_ENV} ${VIRTUAL_ENV} ENV PATH="${VIRTUAL_ENV}/bin:${PATH}" +# Download nltk data +RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')" + # Copy source code COPY . /app/api/ diff --git a/api/app.py b/api/app.py index 50441cb81da1f4..ad219ca0d67459 100644 --- a/api/app.py +++ b/api/app.py @@ -1,6 +1,6 @@ import os -if os.environ.get("DEBUG", "false").lower() != 'true': +if os.environ.get("DEBUG", "false").lower() != "true": from gevent import monkey monkey.patch_all() @@ -57,7 +57,7 @@ if os.name == "nt": os.system('tzutil /s "UTC"') else: - os.environ['TZ'] = 'UTC' + os.environ["TZ"] = "UTC" time.tzset() @@ -70,13 +70,14 @@ class DifyApp(Flask): # ------------- -config_type = os.getenv('EDITION', default='SELF_HOSTED') # ce edition first +config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first # ---------------------------- # Application Factory Function # ---------------------------- + def create_flask_app_with_configs() -> Flask: """ create a raw flask app @@ -92,7 +93,7 @@ def create_flask_app_with_configs() -> Flask: elif isinstance(value, int | float | bool): os.environ[key] = str(value) elif value is None: - os.environ[key] = '' + os.environ[key] = "" return dify_app @@ -100,10 +101,10 @@ def create_flask_app_with_configs() -> Flask: def create_app() -> Flask: app = create_flask_app_with_configs() - app.secret_key = app.config['SECRET_KEY'] + app.secret_key = app.config["SECRET_KEY"] log_handlers = None - log_file = app.config.get('LOG_FILE') + log_file = app.config.get("LOG_FILE") if log_file: log_dir = os.path.dirname(log_file) os.makedirs(log_dir, exist_ok=True) @@ -111,23 +112,24 @@ def create_app() -> Flask: RotatingFileHandler( filename=log_file, maxBytes=1024 * 1024 * 1024, - backupCount=5 + backupCount=5, ), - logging.StreamHandler(sys.stdout) + logging.StreamHandler(sys.stdout), ] logging.basicConfig( - level=app.config.get('LOG_LEVEL'), - format=app.config.get('LOG_FORMAT'), - datefmt=app.config.get('LOG_DATEFORMAT'), + level=app.config.get("LOG_LEVEL"), + format=app.config.get("LOG_FORMAT"), + datefmt=app.config.get("LOG_DATEFORMAT"), handlers=log_handlers, - force=True + force=True, ) - log_tz = app.config.get('LOG_TZ') + log_tz = app.config.get("LOG_TZ") if log_tz: from datetime import datetime import pytz + timezone = pytz.timezone(log_tz) def time_converter(seconds): @@ -162,24 +164,24 @@ def initialize_extensions(app): @login_manager.request_loader def load_user_from_request(request_from_flask_login): """Load user based on the request.""" - if request.blueprint not in ['console', 'inner_api']: + if request.blueprint not in ["console", "inner_api"]: return None # Check if the user_id contains a dot, indicating the old format - auth_header = request.headers.get('Authorization', '') + auth_header = request.headers.get("Authorization", "") if not auth_header: - auth_token = request.args.get('_token') + auth_token = request.args.get("_token") if not auth_token: - raise Unauthorized('Invalid Authorization token.') + raise Unauthorized("Invalid Authorization token.") else: - if ' ' not in auth_header: - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != 'bearer': - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") decoded = PassportService().verify(auth_token) - user_id = decoded.get('user_id') + user_id = decoded.get("user_id") account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token) if account: @@ -190,10 +192,11 @@ def load_user_from_request(request_from_flask_login): @login_manager.unauthorized_handler def unauthorized_handler(): """Handle unauthorized requests.""" - return Response(json.dumps({ - 'code': 'unauthorized', - 'message': "Unauthorized." - }), status=401, content_type="application/json") + return Response( + json.dumps({"code": "unauthorized", "message": "Unauthorized."}), + status=401, + content_type="application/json", + ) # register blueprint routers @@ -204,38 +207,36 @@ def register_blueprints(app): from controllers.service_api import bp as service_api_bp from controllers.web import bp as web_bp - CORS(service_api_bp, - allow_headers=['Content-Type', 'Authorization', 'X-App-Code'], - methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'] - ) + CORS( + service_api_bp, + allow_headers=["Content-Type", "Authorization", "X-App-Code"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + ) app.register_blueprint(service_api_bp) - CORS(web_bp, - resources={ - r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}}, - supports_credentials=True, - allow_headers=['Content-Type', 'Authorization', 'X-App-Code'], - methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'], - expose_headers=['X-Version', 'X-Env'] - ) + CORS( + web_bp, + resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}}, + supports_credentials=True, + allow_headers=["Content-Type", "Authorization", "X-App-Code"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + expose_headers=["X-Version", "X-Env"], + ) app.register_blueprint(web_bp) - CORS(console_app_bp, - resources={ - r"/*": {"origins": app.config['CONSOLE_CORS_ALLOW_ORIGINS']}}, - supports_credentials=True, - allow_headers=['Content-Type', 'Authorization'], - methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'], - expose_headers=['X-Version', 'X-Env'] - ) + CORS( + console_app_bp, + resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}}, + supports_credentials=True, + allow_headers=["Content-Type", "Authorization"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + expose_headers=["X-Version", "X-Env"], + ) app.register_blueprint(console_app_bp) - CORS(files_bp, - allow_headers=['Content-Type'], - methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'] - ) + CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"]) app.register_blueprint(files_bp) app.register_blueprint(inner_api_bp) @@ -245,29 +246,29 @@ def register_blueprints(app): app = create_app() celery = app.extensions["celery"] -if app.config.get('TESTING'): +if app.config.get("TESTING"): print("App is running in TESTING mode") @app.after_request def after_request(response): """Add Version headers to the response.""" - response.set_cookie('remember_token', '', expires=0) - response.headers.add('X-Version', app.config['CURRENT_VERSION']) - response.headers.add('X-Env', app.config['DEPLOY_ENV']) + response.set_cookie("remember_token", "", expires=0) + response.headers.add("X-Version", app.config["CURRENT_VERSION"]) + response.headers.add("X-Env", app.config["DEPLOY_ENV"]) return response -@app.route('/health') +@app.route("/health") def health(): - return Response(json.dumps({ - 'pid': os.getpid(), - 'status': 'ok', - 'version': app.config['CURRENT_VERSION'] - }), status=200, content_type="application/json") + return Response( + json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}), + status=200, + content_type="application/json", + ) -@app.route('/threads') +@app.route("/threads") def threads(): num_threads = threading.active_count() threads = threading.enumerate() @@ -278,32 +279,34 @@ def threads(): thread_id = thread.ident is_alive = thread.is_alive() - thread_list.append({ - 'name': thread_name, - 'id': thread_id, - 'is_alive': is_alive - }) + thread_list.append( + { + "name": thread_name, + "id": thread_id, + "is_alive": is_alive, + } + ) return { - 'pid': os.getpid(), - 'thread_num': num_threads, - 'threads': thread_list + "pid": os.getpid(), + "thread_num": num_threads, + "threads": thread_list, } -@app.route('/db-pool-stat') +@app.route("/db-pool-stat") def pool_stat(): engine = db.engine return { - 'pid': os.getpid(), - 'pool_size': engine.pool.size(), - 'checked_in_connections': engine.pool.checkedin(), - 'checked_out_connections': engine.pool.checkedout(), - 'overflow_connections': engine.pool.overflow(), - 'connection_timeout': engine.pool.timeout(), - 'recycle_time': db.engine.pool._recycle + "pid": os.getpid(), + "pool_size": engine.pool.size(), + "checked_in_connections": engine.pool.checkedin(), + "checked_out_connections": engine.pool.checkedout(), + "overflow_connections": engine.pool.overflow(), + "connection_timeout": engine.pool.timeout(), + "recycle_time": db.engine.pool._recycle, } -if __name__ == '__main__': - app.run(host='0.0.0.0', port=5001) +if __name__ == "__main__": + app.run(host="0.0.0.0", port=5001) diff --git a/api/commands.py b/api/commands.py index c7ffb47b512246..3bf8bc0ecc45f2 100644 --- a/api/commands.py +++ b/api/commands.py @@ -27,32 +27,29 @@ from services.account_service import RegisterService, TenantService -@click.command('reset-password', help='Reset the account password.') -@click.option('--email', prompt=True, help='The email address of the account whose password you need to reset') -@click.option('--new-password', prompt=True, help='the new password.') -@click.option('--password-confirm', prompt=True, help='the new password confirm.') +@click.command("reset-password", help="Reset the account password.") +@click.option("--email", prompt=True, help="The email address of the account whose password you need to reset") +@click.option("--new-password", prompt=True, help="the new password.") +@click.option("--password-confirm", prompt=True, help="the new password confirm.") def reset_password(email, new_password, password_confirm): """ Reset password of owner account Only available in SELF_HOSTED mode """ if str(new_password).strip() != str(password_confirm).strip(): - click.echo(click.style('sorry. The two passwords do not match.', fg='red')) + click.echo(click.style("sorry. The two passwords do not match.", fg="red")) return - account = db.session.query(Account). \ - filter(Account.email == email). \ - one_or_none() + account = db.session.query(Account).filter(Account.email == email).one_or_none() if not account: - click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red')) + click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red")) return try: valid_password(new_password) except: - click.echo( - click.style('sorry. The passwords must match {} '.format(password_pattern), fg='red')) + click.echo(click.style("sorry. The passwords must match {} ".format(password_pattern), fg="red")) return # generate password salt @@ -65,80 +62,87 @@ def reset_password(email, new_password, password_confirm): account.password = base64_password_hashed account.password_salt = base64_salt db.session.commit() - click.echo(click.style('Congratulations! Password has been reset.', fg='green')) + click.echo(click.style("Congratulations! Password has been reset.", fg="green")) -@click.command('reset-email', help='Reset the account email.') -@click.option('--email', prompt=True, help='The old email address of the account whose email you need to reset') -@click.option('--new-email', prompt=True, help='the new email.') -@click.option('--email-confirm', prompt=True, help='the new email confirm.') +@click.command("reset-email", help="Reset the account email.") +@click.option("--email", prompt=True, help="The old email address of the account whose email you need to reset") +@click.option("--new-email", prompt=True, help="the new email.") +@click.option("--email-confirm", prompt=True, help="the new email confirm.") def reset_email(email, new_email, email_confirm): """ Replace account email :return: """ if str(new_email).strip() != str(email_confirm).strip(): - click.echo(click.style('Sorry, new email and confirm email do not match.', fg='red')) + click.echo(click.style("Sorry, new email and confirm email do not match.", fg="red")) return - account = db.session.query(Account). \ - filter(Account.email == email). \ - one_or_none() + account = db.session.query(Account).filter(Account.email == email).one_or_none() if not account: - click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red')) + click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red")) return try: email_validate(new_email) except: - click.echo( - click.style('sorry. {} is not a valid email. '.format(email), fg='red')) + click.echo(click.style("sorry. {} is not a valid email. ".format(email), fg="red")) return account.email = new_email db.session.commit() - click.echo(click.style('Congratulations!, email has been reset.', fg='green')) - - -@click.command('reset-encrypt-key-pair', help='Reset the asymmetric key pair of workspace for encrypt LLM credentials. ' - 'After the reset, all LLM credentials will become invalid, ' - 'requiring re-entry.' - 'Only support SELF_HOSTED mode.') -@click.confirmation_option(prompt=click.style('Are you sure you want to reset encrypt key pair?' - ' this operation cannot be rolled back!', fg='red')) + click.echo(click.style("Congratulations!, email has been reset.", fg="green")) + + +@click.command( + "reset-encrypt-key-pair", + help="Reset the asymmetric key pair of workspace for encrypt LLM credentials. " + "After the reset, all LLM credentials will become invalid, " + "requiring re-entry." + "Only support SELF_HOSTED mode.", +) +@click.confirmation_option( + prompt=click.style( + "Are you sure you want to reset encrypt key pair?" " this operation cannot be rolled back!", fg="red" + ) +) def reset_encrypt_key_pair(): """ Reset the encrypted key pair of workspace for encrypt LLM credentials. After the reset, all LLM credentials will become invalid, requiring re-entry. Only support SELF_HOSTED mode. """ - if dify_config.EDITION != 'SELF_HOSTED': - click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red')) + if dify_config.EDITION != "SELF_HOSTED": + click.echo(click.style("Sorry, only support SELF_HOSTED mode.", fg="red")) return tenants = db.session.query(Tenant).all() for tenant in tenants: if not tenant: - click.echo(click.style('Sorry, no workspace found. Please enter /install to initialize.', fg='red')) + click.echo(click.style("Sorry, no workspace found. Please enter /install to initialize.", fg="red")) return tenant.encrypt_public_key = generate_key_pair(tenant.id) - db.session.query(Provider).filter(Provider.provider_type == 'custom', Provider.tenant_id == tenant.id).delete() + db.session.query(Provider).filter(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete() db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete() db.session.commit() - click.echo(click.style('Congratulations! ' - 'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green')) + click.echo( + click.style( + "Congratulations! " "the asymmetric key pair of workspace {} has been reset.".format(tenant.id), + fg="green", + ) + ) -@click.command('vdb-migrate', help='migrate vector db.') -@click.option('--scope', default='all', prompt=False, help='The scope of vector database to migrate, Default is All.') +@click.command("vdb-migrate", help="migrate vector db.") +@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.") def vdb_migrate(scope: str): - if scope in ['knowledge', 'all']: + if scope in ["knowledge", "all"]: migrate_knowledge_vector_database() - if scope in ['annotation', 'all']: + if scope in ["annotation", "all"]: migrate_annotation_vector_database() @@ -146,7 +150,7 @@ def migrate_annotation_vector_database(): """ Migrate annotation datas to target vector database . """ - click.echo(click.style('Start migrate annotation data.', fg='green')) + click.echo(click.style("Start migrate annotation data.", fg="green")) create_count = 0 skipped_count = 0 total_count = 0 @@ -154,98 +158,103 @@ def migrate_annotation_vector_database(): while True: try: # get apps info - apps = db.session.query(App).filter( - App.status == 'normal' - ).order_by(App.created_at.desc()).paginate(page=page, per_page=50) + apps = ( + db.session.query(App) + .filter(App.status == "normal") + .order_by(App.created_at.desc()) + .paginate(page=page, per_page=50) + ) except NotFound: break page += 1 for app in apps: total_count = total_count + 1 - click.echo(f'Processing the {total_count} app {app.id}. ' - + f'{create_count} created, {skipped_count} skipped.') + click.echo( + f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped." + ) try: - click.echo('Create app annotation index: {}'.format(app.id)) - app_annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app.id - ).first() + click.echo("Create app annotation index: {}".format(app.id)) + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first() + ) if not app_annotation_setting: skipped_count = skipped_count + 1 - click.echo('App annotation setting is disabled: {}'.format(app.id)) + click.echo("App annotation setting is disabled: {}".format(app.id)) continue # get dataset_collection_binding info - dataset_collection_binding = db.session.query(DatasetCollectionBinding).filter( - DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id - ).first() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id) + .first() + ) if not dataset_collection_binding: - click.echo('App annotation collection binding is not exist: {}'.format(app.id)) + click.echo("App annotation collection binding is not exist: {}".format(app.id)) continue annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all() dataset = Dataset( id=app.id, tenant_id=app.tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) documents = [] if annotations: for annotation in annotations: document = Document( page_content=annotation.question, - metadata={ - "annotation_id": annotation.id, - "app_id": app.id, - "doc_id": annotation.id - } + metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id}, ) documents.append(document) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) click.echo(f"Start to migrate annotation, app_id: {app.id}.") try: vector.delete() - click.echo( - click.style(f'Successfully delete vector index for app: {app.id}.', - fg='green')) + click.echo(click.style(f"Successfully delete vector index for app: {app.id}.", fg="green")) except Exception as e: - click.echo( - click.style(f'Failed to delete vector index for app {app.id}.', - fg='red')) + click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red")) raise e if documents: try: - click.echo(click.style( - f'Start to created vector index with {len(documents)} annotations for app {app.id}.', - fg='green')) - vector.create(documents) click.echo( - click.style(f'Successfully created vector index for app {app.id}.', fg='green')) + click.style( + f"Start to created vector index with {len(documents)} annotations for app {app.id}.", + fg="green", + ) + ) + vector.create(documents) + click.echo(click.style(f"Successfully created vector index for app {app.id}.", fg="green")) except Exception as e: - click.echo(click.style(f'Failed to created vector index for app {app.id}.', fg='red')) + click.echo(click.style(f"Failed to created vector index for app {app.id}.", fg="red")) raise e - click.echo(f'Successfully migrated app annotation {app.id}.') + click.echo(f"Successfully migrated app annotation {app.id}.") create_count += 1 except Exception as e: click.echo( - click.style('Create app annotation index error: {} {}'.format(e.__class__.__name__, str(e)), - fg='red')) + click.style( + "Create app annotation index error: {} {}".format(e.__class__.__name__, str(e)), fg="red" + ) + ) continue click.echo( - click.style(f'Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.', - fg='green')) + click.style( + f"Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.", + fg="green", + ) + ) def migrate_knowledge_vector_database(): """ Migrate vector database datas to target vector database . """ - click.echo(click.style('Start migrate vector db.', fg='green')) + click.echo(click.style("Start migrate vector db.", fg="green")) create_count = 0 skipped_count = 0 total_count = 0 @@ -253,87 +262,77 @@ def migrate_knowledge_vector_database(): page = 1 while True: try: - datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \ - .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50) + datasets = ( + db.session.query(Dataset) + .filter(Dataset.indexing_technique == "high_quality") + .order_by(Dataset.created_at.desc()) + .paginate(page=page, per_page=50) + ) except NotFound: break page += 1 for dataset in datasets: total_count = total_count + 1 - click.echo(f'Processing the {total_count} dataset {dataset.id}. ' - + f'{create_count} created, {skipped_count} skipped.') + click.echo( + f"Processing the {total_count} dataset {dataset.id}. " + + f"{create_count} created, {skipped_count} skipped." + ) try: - click.echo('Create dataset vdb index: {}'.format(dataset.id)) + click.echo("Create dataset vdb index: {}".format(dataset.id)) if dataset.index_struct_dict: - if dataset.index_struct_dict['type'] == vector_type: + if dataset.index_struct_dict["type"] == vector_type: skipped_count = skipped_count + 1 continue - collection_name = '' + collection_name = "" if vector_type == VectorType.WEAVIATE: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.WEAVIATE, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": VectorType.WEAVIATE, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.QDRANT: if dataset.collection_binding_id: - dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ - filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \ - one_or_none() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter(DatasetCollectionBinding.id == dataset.collection_binding_id) + .one_or_none() + ) if dataset_collection_binding: collection_name = dataset_collection_binding.collection_name else: - raise ValueError('Dataset Collection Bindings is not exist!') + raise ValueError("Dataset Collection Bindings is not exist!") else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.QDRANT, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": VectorType.QDRANT, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.MILVUS: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.MILVUS, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": VectorType.MILVUS, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.RELYT: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": 'relyt', - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": "relyt", "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.TENCENT: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.TENCENT, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": VectorType.TENCENT, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.PGVECTOR: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.PGVECTOR, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": VectorType.PGVECTOR, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.OPENSEARCH: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) index_struct_dict = { "type": VectorType.OPENSEARCH, - "vector_store": {"class_prefix": collection_name} + "vector_store": {"class_prefix": collection_name}, } dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.ANALYTICDB: @@ -341,9 +340,14 @@ def migrate_knowledge_vector_database(): collection_name = Dataset.gen_collection_name_by_id(dataset_id) index_struct_dict = { "type": VectorType.ANALYTICDB, - "vector_store": {"class_prefix": collection_name} + "vector_store": {"class_prefix": collection_name}, } dataset.index_struct = json.dumps(index_struct_dict) + elif vector_type == VectorType.ELASTICSEARCH: + dataset_id = dataset.id + index_name = Dataset.gen_collection_name_by_id(dataset_id) + index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}} + dataset.index_struct = json.dumps(index_struct_dict) else: raise ValueError(f"Vector store {vector_type} is not supported.") @@ -353,29 +357,41 @@ def migrate_knowledge_vector_database(): try: vector.delete() click.echo( - click.style(f'Successfully delete vector index {collection_name} for dataset {dataset.id}.', - fg='green')) + click.style( + f"Successfully delete vector index {collection_name} for dataset {dataset.id}.", fg="green" + ) + ) except Exception as e: click.echo( - click.style(f'Failed to delete vector index {collection_name} for dataset {dataset.id}.', - fg='red')) + click.style( + f"Failed to delete vector index {collection_name} for dataset {dataset.id}.", fg="red" + ) + ) raise e - dataset_documents = db.session.query(DatasetDocument).filter( - DatasetDocument.dataset_id == dataset.id, - DatasetDocument.indexing_status == 'completed', - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).all() + dataset_documents = ( + db.session.query(DatasetDocument) + .filter( + DatasetDocument.dataset_id == dataset.id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + .all() + ) documents = [] segments_count = 0 for dataset_document in dataset_documents: - segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.status == 'completed', - DocumentSegment.enabled == True - ).all() + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + ) + .all() + ) for segment in segments: document = Document( @@ -385,7 +401,7 @@ def migrate_knowledge_vector_database(): "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) @@ -393,37 +409,43 @@ def migrate_knowledge_vector_database(): if documents: try: - click.echo(click.style( - f'Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.', - fg='green')) + click.echo( + click.style( + f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.", + fg="green", + ) + ) vector.create(documents) click.echo( - click.style(f'Successfully created vector index for dataset {dataset.id}.', fg='green')) + click.style(f"Successfully created vector index for dataset {dataset.id}.", fg="green") + ) except Exception as e: - click.echo(click.style(f'Failed to created vector index for dataset {dataset.id}.', fg='red')) + click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red")) raise e db.session.add(dataset) db.session.commit() - click.echo(f'Successfully migrated dataset {dataset.id}.') + click.echo(f"Successfully migrated dataset {dataset.id}.") create_count += 1 except Exception as e: db.session.rollback() click.echo( - click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), - fg='red')) + click.style("Create dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red") + ) continue click.echo( - click.style(f'Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.', - fg='green')) + click.style( + f"Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.", fg="green" + ) + ) -@click.command('convert-to-agent-apps', help='Convert Agent Assistant to Agent App.') +@click.command("convert-to-agent-apps", help="Convert Agent Assistant to Agent App.") def convert_to_agent_apps(): """ Convert Agent Assistant to Agent App. """ - click.echo(click.style('Start convert to agent apps.', fg='green')) + click.echo(click.style("Start convert to agent apps.", fg="green")) proceeded_app_ids = [] @@ -458,7 +480,7 @@ def convert_to_agent_apps(): break for app in apps: - click.echo('Converting app: {}'.format(app.id)) + click.echo("Converting app: {}".format(app.id)) try: app.mode = AppMode.AGENT_CHAT.value @@ -470,137 +492,142 @@ def convert_to_agent_apps(): ) db.session.commit() - click.echo(click.style('Converted app: {}'.format(app.id), fg='green')) + click.echo(click.style("Converted app: {}".format(app.id), fg="green")) except Exception as e: - click.echo( - click.style('Convert app error: {} {}'.format(e.__class__.__name__, - str(e)), fg='red')) + click.echo(click.style("Convert app error: {} {}".format(e.__class__.__name__, str(e)), fg="red")) - click.echo(click.style('Congratulations! Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green')) + click.echo(click.style("Congratulations! Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green")) -@click.command('add-qdrant-doc-id-index', help='add qdrant doc_id index.') -@click.option('--field', default='metadata.doc_id', prompt=False, help='index field , default is metadata.doc_id.') +@click.command("add-qdrant-doc-id-index", help="add qdrant doc_id index.") +@click.option("--field", default="metadata.doc_id", prompt=False, help="index field , default is metadata.doc_id.") def add_qdrant_doc_id_index(field: str): - click.echo(click.style('Start add qdrant doc_id index.', fg='green')) + click.echo(click.style("Start add qdrant doc_id index.", fg="green")) vector_type = dify_config.VECTOR_STORE if vector_type != "qdrant": - click.echo(click.style('Sorry, only support qdrant vector store.', fg='red')) + click.echo(click.style("Sorry, only support qdrant vector store.", fg="red")) return create_count = 0 try: bindings = db.session.query(DatasetCollectionBinding).all() if not bindings: - click.echo(click.style('Sorry, no dataset collection bindings found.', fg='red')) + click.echo(click.style("Sorry, no dataset collection bindings found.", fg="red")) return import qdrant_client from qdrant_client.http.exceptions import UnexpectedResponse from qdrant_client.http.models import PayloadSchemaType from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig + for binding in bindings: if dify_config.QDRANT_URL is None: - raise ValueError('Qdrant url is required.') + raise ValueError("Qdrant url is required.") qdrant_config = QdrantConfig( endpoint=dify_config.QDRANT_URL, api_key=dify_config.QDRANT_API_KEY, root_path=current_app.root_path, timeout=dify_config.QDRANT_CLIENT_TIMEOUT, grpc_port=dify_config.QDRANT_GRPC_PORT, - prefer_grpc=dify_config.QDRANT_GRPC_ENABLED + prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, ) try: client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params()) # create payload index - client.create_payload_index(binding.collection_name, field, - field_schema=PayloadSchemaType.KEYWORD) + client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD) create_count += 1 except UnexpectedResponse as e: # Collection does not exist, so return if e.status_code == 404: - click.echo(click.style(f'Collection not found, collection_name:{binding.collection_name}.', fg='red')) + click.echo( + click.style(f"Collection not found, collection_name:{binding.collection_name}.", fg="red") + ) continue # Some other error occurred, so re-raise the exception else: - click.echo(click.style(f'Failed to create qdrant index, collection_name:{binding.collection_name}.', fg='red')) + click.echo( + click.style( + f"Failed to create qdrant index, collection_name:{binding.collection_name}.", fg="red" + ) + ) except Exception as e: - click.echo(click.style('Failed to create qdrant client.', fg='red')) + click.echo(click.style("Failed to create qdrant client.", fg="red")) - click.echo( - click.style(f'Congratulations! Create {create_count} collection indexes.', - fg='green')) + click.echo(click.style(f"Congratulations! Create {create_count} collection indexes.", fg="green")) -@click.command('create-tenant', help='Create account and tenant.') -@click.option('--email', prompt=True, help='The email address of the tenant account.') -@click.option('--language', prompt=True, help='Account language, default: en-US.') -def create_tenant(email: str, language: Optional[str] = None): +@click.command("create-tenant", help="Create account and tenant.") +@click.option("--email", prompt=True, help="The email address of the tenant account.") +@click.option("--name", prompt=True, help="The workspace name of the tenant account.") +@click.option("--language", prompt=True, help="Account language, default: en-US.") +def create_tenant(email: str, language: Optional[str] = None, name: Optional[str] = None): """ Create tenant account """ if not email: - click.echo(click.style('Sorry, email is required.', fg='red')) + click.echo(click.style("Sorry, email is required.", fg="red")) return # Create account email = email.strip() - if '@' not in email: - click.echo(click.style('Sorry, invalid email address.', fg='red')) + if "@" not in email: + click.echo(click.style("Sorry, invalid email address.", fg="red")) return - account_name = email.split('@')[0] + account_name = email.split("@")[0] if language not in languages: - language = 'en-US' + language = "en-US" + + name = name.strip() # generate random password new_password = secrets.token_urlsafe(16) # register account - account = RegisterService.register( - email=email, - name=account_name, - password=new_password, - language=language - ) + account = RegisterService.register(email=email, name=account_name, password=new_password, language=language) - TenantService.create_owner_tenant_if_not_exist(account) + TenantService.create_owner_tenant_if_not_exist(account, name) - click.echo(click.style('Congratulations! Account and tenant created.\n' - 'Account: {}\nPassword: {}'.format(email, new_password), fg='green')) + click.echo( + click.style( + "Congratulations! Account and tenant created.\n" "Account: {}\nPassword: {}".format(email, new_password), + fg="green", + ) + ) -@click.command('upgrade-db', help='upgrade the database') +@click.command("upgrade-db", help="upgrade the database") def upgrade_db(): - click.echo('Preparing database migration...') - lock = redis_client.lock(name='db_upgrade_lock', timeout=60) + click.echo("Preparing database migration...") + lock = redis_client.lock(name="db_upgrade_lock", timeout=60) if lock.acquire(blocking=False): try: - click.echo(click.style('Start database migration.', fg='green')) + click.echo(click.style("Start database migration.", fg="green")) # run db migration import flask_migrate + flask_migrate.upgrade() - click.echo(click.style('Database migration successful!', fg='green')) + click.echo(click.style("Database migration successful!", fg="green")) except Exception as e: - logging.exception(f'Database migration failed, error: {e}') + logging.exception(f"Database migration failed, error: {e}") finally: lock.release() else: - click.echo('Database migration skipped') + click.echo("Database migration skipped") -@click.command('fix-app-site-missing', help='Fix app related site missing issue.') +@click.command("fix-app-site-missing", help="Fix app related site missing issue.") def fix_app_site_missing(): """ Fix app related site missing issue. """ - click.echo(click.style('Start fix app related site missing issue.', fg='green')) + click.echo(click.style("Start fix app related site missing issue.", fg="green")) failed_app_ids = [] while True: @@ -631,15 +658,14 @@ def fix_app_site_missing(): app_was_created.send(app, account=account) except Exception as e: failed_app_ids.append(app_id) - click.echo(click.style('Fix app {} related site missing issue failed!'.format(app_id), fg='red')) - logging.exception(f'Fix app related site missing issue failed, error: {e}') + click.echo(click.style("Fix app {} related site missing issue failed!".format(app_id), fg="red")) + logging.exception(f"Fix app related site missing issue failed, error: {e}") continue if not processed_count: break - - click.echo(click.style('Congratulations! Fix app related site missing issue successful!', fg='green')) + click.echo(click.style("Congratulations! Fix app related site missing issue successful!", fg="green")) def register_commands(app): diff --git a/api/configs/__init__.py b/api/configs/__init__.py index c0e28c34e1e1ea..3a172601c96382 100644 --- a/api/configs/__init__.py +++ b/api/configs/__init__.py @@ -1,3 +1,3 @@ from .app_config import DifyConfig -dify_config = DifyConfig() \ No newline at end of file +dify_config = DifyConfig() diff --git a/api/configs/app_config.py b/api/configs/app_config.py index a5a4fc788d0d19..61de73c8689f8b 100644 --- a/api/configs/app_config.py +++ b/api/configs/app_config.py @@ -1,4 +1,3 @@ -from pydantic import Field, computed_field from pydantic_settings import SettingsConfigDict from configs.deploy import DeploymentConfig @@ -12,58 +11,28 @@ class DifyConfig( # Packaging info PackagingInfo, - # Deployment configs DeploymentConfig, - # Feature configs FeatureConfig, - # Middleware configs MiddlewareConfig, - # Extra service configs ExtraServiceConfig, - # Enterprise feature configs # **Before using, please contact business@dify.ai by email to inquire about licensing matters.** EnterpriseFeatureConfig, ): - DEBUG: bool = Field(default=False, description='whether to enable debug mode.') - model_config = SettingsConfigDict( # read from dotenv format config file - env_file='.env', - env_file_encoding='utf-8', + env_file=".env", + env_file_encoding="utf-8", frozen=True, - # ignore extra attributes - extra='ignore', + extra="ignore", ) - CODE_MAX_NUMBER: int = 9223372036854775807 - CODE_MIN_NUMBER: int = -9223372036854775808 - CODE_MAX_STRING_LENGTH: int = 80000 - CODE_MAX_STRING_ARRAY_LENGTH: int = 30 - CODE_MAX_OBJECT_ARRAY_LENGTH: int = 30 - CODE_MAX_NUMBER_ARRAY_LENGTH: int = 1000 - - HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = 300 - HTTP_REQUEST_MAX_READ_TIMEOUT: int = 600 - HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = 600 - HTTP_REQUEST_NODE_MAX_BINARY_SIZE: int = 1024 * 1024 * 10 - - @computed_field - def HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE(self) -> str: - return f'{self.HTTP_REQUEST_NODE_MAX_BINARY_SIZE / 1024 / 1024:.2f}MB' - - HTTP_REQUEST_NODE_MAX_TEXT_SIZE: int = 1024 * 1024 - - @computed_field - def HTTP_REQUEST_NODE_READABLE_MAX_TEXT_SIZE(self) -> str: - return f'{self.HTTP_REQUEST_NODE_MAX_TEXT_SIZE / 1024 / 1024:.2f}MB' - - SSRF_PROXY_HTTP_URL: str | None = None - SSRF_PROXY_HTTPS_URL: str | None = None - - MODERATION_BUFFER_SIZE: int = Field(default=300, description='The buffer size for moderation.') + # Before adding any config, + # please consider to arrange it in the proper config group of existed or added + # for better readability and maintainability. + # Thanks for your concentration and consideration. diff --git a/api/configs/deploy/__init__.py b/api/configs/deploy/__init__.py index 219b315784323e..10271483c46697 100644 --- a/api/configs/deploy/__init__.py +++ b/api/configs/deploy/__init__.py @@ -6,22 +6,28 @@ class DeploymentConfig(BaseSettings): """ Deployment configs """ + APPLICATION_NAME: str = Field( - description='application name', - default='langgenius/dify', + description="application name", + default="langgenius/dify", + ) + + DEBUG: bool = Field( + description="whether to enable debug mode.", + default=False, ) TESTING: bool = Field( - description='', + description="", default=False, ) EDITION: str = Field( - description='deployment edition', - default='SELF_HOSTED', + description="deployment edition", + default="SELF_HOSTED", ) DEPLOY_ENV: str = Field( - description='deployment environment, default to PRODUCTION.', - default='PRODUCTION', + description="deployment environment, default to PRODUCTION.", + default="PRODUCTION", ) diff --git a/api/configs/enterprise/__init__.py b/api/configs/enterprise/__init__.py index b5d884e10e7b4d..c661593a44264d 100644 --- a/api/configs/enterprise/__init__.py +++ b/api/configs/enterprise/__init__.py @@ -7,13 +7,14 @@ class EnterpriseFeatureConfig(BaseSettings): Enterprise feature configs. **Before using, please contact business@dify.ai by email to inquire about licensing matters.** """ + ENTERPRISE_ENABLED: bool = Field( - description='whether to enable enterprise features.' - 'Before using, please contact business@dify.ai by email to inquire about licensing matters.', + description="whether to enable enterprise features." + "Before using, please contact business@dify.ai by email to inquire about licensing matters.", default=False, ) CAN_REPLACE_LOGO: bool = Field( - description='whether to allow replacing enterprise logo.', + description="whether to allow replacing enterprise logo.", default=False, ) diff --git a/api/configs/extra/notion_config.py b/api/configs/extra/notion_config.py index b77e8adaaeba52..bd1268fa4565f8 100644 --- a/api/configs/extra/notion_config.py +++ b/api/configs/extra/notion_config.py @@ -8,27 +8,28 @@ class NotionConfig(BaseSettings): """ Notion integration configs """ + NOTION_CLIENT_ID: Optional[str] = Field( - description='Notion client ID', + description="Notion client ID", default=None, ) NOTION_CLIENT_SECRET: Optional[str] = Field( - description='Notion client secret key', + description="Notion client secret key", default=None, ) NOTION_INTEGRATION_TYPE: Optional[str] = Field( - description='Notion integration type, default to None, available values: internal.', + description="Notion integration type, default to None, available values: internal.", default=None, ) NOTION_INTERNAL_SECRET: Optional[str] = Field( - description='Notion internal secret key', + description="Notion internal secret key", default=None, ) NOTION_INTEGRATION_TOKEN: Optional[str] = Field( - description='Notion integration token', + description="Notion integration token", default=None, ) diff --git a/api/configs/extra/sentry_config.py b/api/configs/extra/sentry_config.py index e6517f730a577a..ea9ea60ffbd6c7 100644 --- a/api/configs/extra/sentry_config.py +++ b/api/configs/extra/sentry_config.py @@ -8,17 +8,18 @@ class SentryConfig(BaseSettings): """ Sentry configs """ + SENTRY_DSN: Optional[str] = Field( - description='Sentry DSN', + description="Sentry DSN", default=None, ) SENTRY_TRACES_SAMPLE_RATE: NonNegativeFloat = Field( - description='Sentry trace sample rate', + description="Sentry trace sample rate", default=1.0, ) SENTRY_PROFILES_SAMPLE_RATE: NonNegativeFloat = Field( - description='Sentry profiles sample rate', + description="Sentry profiles sample rate", default=1.0, ) diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 369b25d788a440..303bce2aa5050c 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1,6 +1,6 @@ -from typing import Optional +from typing import Annotated, Optional -from pydantic import AliasChoices, Field, NonNegativeInt, PositiveInt, computed_field +from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field from pydantic_settings import BaseSettings from configs.feature.hosted_service import HostedServiceConfig @@ -10,16 +10,17 @@ class SecurityConfig(BaseSettings): """ Secret Key configs """ + SECRET_KEY: Optional[str] = Field( - description='Your App secret key will be used for securely signing the session cookie' - 'Make sure you are changing this key for your deployment with a strong key.' - 'You can generate a strong key using `openssl rand -base64 42`.' - 'Alternatively you can set it with `SECRET_KEY` environment variable.', + description="Your App secret key will be used for securely signing the session cookie" + "Make sure you are changing this key for your deployment with a strong key." + "You can generate a strong key using `openssl rand -base64 42`." + "Alternatively you can set it with `SECRET_KEY` environment variable.", default=None, ) RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field( - description='Expiry time in hours for reset token', + description="Expiry time in hours for reset token", default=24, ) @@ -28,12 +29,13 @@ class AppExecutionConfig(BaseSettings): """ App Execution configs """ + APP_MAX_EXECUTION_TIME: PositiveInt = Field( - description='execution timeout in seconds for app execution', + description="execution timeout in seconds for app execution", default=1200, ) APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field( - description='max active request per app, 0 means unlimited', + description="max active request per app, 0 means unlimited", default=0, ) @@ -42,14 +44,70 @@ class CodeExecutionSandboxConfig(BaseSettings): """ Code Execution Sandbox configs """ - CODE_EXECUTION_ENDPOINT: str = Field( - description='endpoint URL of code execution servcie', - default='http://sandbox:8194', + + CODE_EXECUTION_ENDPOINT: HttpUrl = Field( + description="endpoint URL of code execution servcie", + default="http://sandbox:8194", ) CODE_EXECUTION_API_KEY: str = Field( - description='API key for code execution service', - default='dify-sandbox', + description="API key for code execution service", + default="dify-sandbox", + ) + + CODE_EXECUTION_CONNECT_TIMEOUT: Optional[float] = Field( + description="connect timeout in seconds for code execution request", + default=10.0, + ) + + CODE_EXECUTION_READ_TIMEOUT: Optional[float] = Field( + description="read timeout in seconds for code execution request", + default=60.0, + ) + + CODE_EXECUTION_WRITE_TIMEOUT: Optional[float] = Field( + description="write timeout in seconds for code execution request", + default=10.0, + ) + + CODE_MAX_NUMBER: PositiveInt = Field( + description="max depth for code execution", + default=9223372036854775807, + ) + + CODE_MIN_NUMBER: NegativeInt = Field( + description="", + default=-9223372036854775807, + ) + + CODE_MAX_DEPTH: PositiveInt = Field( + description="max depth for code execution", + default=5, + ) + + CODE_MAX_PRECISION: PositiveInt = Field( + description="max precision digits for float type in code execution", + default=20, + ) + + CODE_MAX_STRING_LENGTH: PositiveInt = Field( + description="max string length for code execution", + default=80000, + ) + + CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field( + description="", + default=30, + ) + + CODE_MAX_OBJECT_ARRAY_LENGTH: PositiveInt = Field( + description="", + default=30, + ) + + CODE_MAX_NUMBER_ARRAY_LENGTH: PositiveInt = Field( + description="", + default=1000, ) @@ -57,28 +115,27 @@ class EndpointConfig(BaseSettings): """ Module URL configs """ + CONSOLE_API_URL: str = Field( - description='The backend URL prefix of the console API.' - 'used to concatenate the login authorization callback or notion integration callback.', - default='', + description="The backend URL prefix of the console API." + "used to concatenate the login authorization callback or notion integration callback.", + default="", ) CONSOLE_WEB_URL: str = Field( - description='The front-end URL prefix of the console web.' - 'used to concatenate some front-end addresses and for CORS configuration use.', - default='', + description="The front-end URL prefix of the console web." + "used to concatenate some front-end addresses and for CORS configuration use.", + default="", ) SERVICE_API_URL: str = Field( - description='Service API Url prefix.' - 'used to display Service API Base Url to the front-end.', - default='', + description="Service API Url prefix." "used to display Service API Base Url to the front-end.", + default="", ) APP_WEB_URL: str = Field( - description='WebApp Url prefix.' - 'used to display WebAPP API Base Url to the front-end.', - default='', + description="WebApp Url prefix." "used to display WebAPP API Base Url to the front-end.", + default="", ) @@ -86,17 +143,18 @@ class FileAccessConfig(BaseSettings): """ File Access configs """ + FILES_URL: str = Field( - description='File preview or download Url prefix.' - ' used to display File preview or download Url to the front-end or as Multi-model inputs;' - 'Url is signed and has expiration time.', - validation_alias=AliasChoices('FILES_URL', 'CONSOLE_API_URL'), + description="File preview or download Url prefix." + " used to display File preview or download Url to the front-end or as Multi-model inputs;" + "Url is signed and has expiration time.", + validation_alias=AliasChoices("FILES_URL", "CONSOLE_API_URL"), alias_priority=1, - default='', + default="", ) FILES_ACCESS_TIMEOUT: int = Field( - description='timeout in seconds for file accessing', + description="timeout in seconds for file accessing", default=300, ) @@ -105,23 +163,24 @@ class FileUploadConfig(BaseSettings): """ File Uploading configs """ + UPLOAD_FILE_SIZE_LIMIT: NonNegativeInt = Field( - description='size limit in Megabytes for uploading files', + description="size limit in Megabytes for uploading files", default=15, ) UPLOAD_FILE_BATCH_LIMIT: NonNegativeInt = Field( - description='batch size limit for uploading files', + description="batch size limit for uploading files", default=5, ) UPLOAD_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field( - description='image file size limit in Megabytes for uploading files', + description="image file size limit in Megabytes for uploading files", default=10, ) BATCH_UPLOAD_LIMIT: NonNegativeInt = Field( - description='', # todo: to be clarified + description="", # todo: to be clarified default=20, ) @@ -130,45 +189,79 @@ class HttpConfig(BaseSettings): """ HTTP configs """ + API_COMPRESSION_ENABLED: bool = Field( - description='whether to enable HTTP response compression of gzip', + description="whether to enable HTTP response compression of gzip", default=False, ) inner_CONSOLE_CORS_ALLOW_ORIGINS: str = Field( - description='', - validation_alias=AliasChoices('CONSOLE_CORS_ALLOW_ORIGINS', 'CONSOLE_WEB_URL'), - default='', + description="", + validation_alias=AliasChoices("CONSOLE_CORS_ALLOW_ORIGINS", "CONSOLE_WEB_URL"), + default="", ) @computed_field @property def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]: - return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(',') + return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(",") inner_WEB_API_CORS_ALLOW_ORIGINS: str = Field( - description='', - validation_alias=AliasChoices('WEB_API_CORS_ALLOW_ORIGINS'), - default='*', + description="", + validation_alias=AliasChoices("WEB_API_CORS_ALLOW_ORIGINS"), + default="*", ) @computed_field @property def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]: - return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(',') + return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",") + + HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[ + PositiveInt, Field(ge=10, description="connect timeout in seconds for HTTP request") + ] = 10 + + HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[ + PositiveInt, Field(ge=60, description="read timeout in seconds for HTTP request") + ] = 60 + + HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[ + PositiveInt, Field(ge=10, description="read timeout in seconds for HTTP request") + ] = 20 + + HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field( + description="", + default=10 * 1024 * 1024, + ) + + HTTP_REQUEST_NODE_MAX_TEXT_SIZE: PositiveInt = Field( + description="", + default=1 * 1024 * 1024, + ) + + SSRF_PROXY_HTTP_URL: Optional[str] = Field( + description="HTTP URL for SSRF proxy", + default=None, + ) + + SSRF_PROXY_HTTPS_URL: Optional[str] = Field( + description="HTTPS URL for SSRF proxy", + default=None, + ) class InnerAPIConfig(BaseSettings): """ Inner API configs """ + INNER_API: bool = Field( - description='whether to enable the inner API', + description="whether to enable the inner API", default=False, ) INNER_API_KEY: Optional[str] = Field( - description='The inner API key is used to authenticate the inner API', + description="The inner API key is used to authenticate the inner API", default=None, ) @@ -179,28 +272,27 @@ class LoggingConfig(BaseSettings): """ LOG_LEVEL: str = Field( - description='Log output level, default to INFO.' - 'It is recommended to set it to ERROR for production.', - default='INFO', + description="Log output level, default to INFO." "It is recommended to set it to ERROR for production.", + default="INFO", ) LOG_FILE: Optional[str] = Field( - description='logging output file path', + description="logging output file path", default=None, ) LOG_FORMAT: str = Field( - description='log format', - default='%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s', + description="log format", + default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s", ) LOG_DATEFORMAT: Optional[str] = Field( - description='log date format', + description="log date format", default=None, ) LOG_TZ: Optional[str] = Field( - description='specify log timezone, eg: America/New_York', + description="specify log timezone, eg: America/New_York", default=None, ) @@ -209,8 +301,9 @@ class ModelLoadBalanceConfig(BaseSettings): """ Model load balance configs """ + MODEL_LB_ENABLED: bool = Field( - description='whether to enable model load balancing', + description="whether to enable model load balancing", default=False, ) @@ -219,8 +312,9 @@ class BillingConfig(BaseSettings): """ Platform Billing Configurations """ + BILLING_ENABLED: bool = Field( - description='whether to enable billing', + description="whether to enable billing", default=False, ) @@ -229,9 +323,10 @@ class UpdateConfig(BaseSettings): """ Update configs """ + CHECK_UPDATE_URL: str = Field( - description='url for checking updates', - default='https://updates.dify.ai', + description="url for checking updates", + default="https://updates.dify.ai", ) @@ -241,47 +336,53 @@ class WorkflowConfig(BaseSettings): """ WORKFLOW_MAX_EXECUTION_STEPS: PositiveInt = Field( - description='max execution steps in single workflow execution', + description="max execution steps in single workflow execution", default=500, ) WORKFLOW_MAX_EXECUTION_TIME: PositiveInt = Field( - description='max execution time in seconds in single workflow execution', + description="max execution time in seconds in single workflow execution", default=1200, ) WORKFLOW_CALL_MAX_DEPTH: PositiveInt = Field( - description='max depth of calling in single workflow execution', + description="max depth of calling in single workflow execution", default=5, ) + MAX_VARIABLE_SIZE: PositiveInt = Field( + description="The maximum size in bytes of a variable. default to 5KB.", + default=5 * 1024, + ) + class OAuthConfig(BaseSettings): """ oauth configs """ + OAUTH_REDIRECT_PATH: str = Field( - description='redirect path for OAuth', - default='/console/api/oauth/authorize', + description="redirect path for OAuth", + default="/console/api/oauth/authorize", ) GITHUB_CLIENT_ID: Optional[str] = Field( - description='GitHub client id for OAuth', + description="GitHub client id for OAuth", default=None, ) GITHUB_CLIENT_SECRET: Optional[str] = Field( - description='GitHub client secret key for OAuth', + description="GitHub client secret key for OAuth", default=None, ) GOOGLE_CLIENT_ID: Optional[str] = Field( - description='Google client id for OAuth', + description="Google client id for OAuth", default=None, ) GOOGLE_CLIENT_SECRET: Optional[str] = Field( - description='Google client secret key for OAuth', + description="Google client secret key for OAuth", default=None, ) @@ -291,9 +392,8 @@ class ModerationConfig(BaseSettings): Moderation in app configs. """ - # todo: to be clarified in usage and unit - OUTPUT_MODERATION_BUFFER_SIZE: PositiveInt = Field( - description='buffer size for moderation', + MODERATION_BUFFER_SIZE: PositiveInt = Field( + description="buffer size for moderation", default=300, ) @@ -304,7 +404,7 @@ class ToolConfig(BaseSettings): """ TOOL_ICON_CACHE_MAX_AGE: PositiveInt = Field( - description='max age in seconds for tool icon caching', + description="max age in seconds for tool icon caching", default=3600, ) @@ -315,52 +415,52 @@ class MailConfig(BaseSettings): """ MAIL_TYPE: Optional[str] = Field( - description='Mail provider type name, default to None, availabile values are `smtp` and `resend`.', + description="Mail provider type name, default to None, availabile values are `smtp` and `resend`.", default=None, ) MAIL_DEFAULT_SEND_FROM: Optional[str] = Field( - description='default email address for sending from ', + description="default email address for sending from ", default=None, ) RESEND_API_KEY: Optional[str] = Field( - description='API key for Resend', + description="API key for Resend", default=None, ) RESEND_API_URL: Optional[str] = Field( - description='API URL for Resend', + description="API URL for Resend", default=None, ) SMTP_SERVER: Optional[str] = Field( - description='smtp server host', + description="smtp server host", default=None, ) SMTP_PORT: Optional[int] = Field( - description='smtp server port', + description="smtp server port", default=465, ) SMTP_USERNAME: Optional[str] = Field( - description='smtp server username', + description="smtp server username", default=None, ) SMTP_PASSWORD: Optional[str] = Field( - description='smtp server password', + description="smtp server password", default=None, ) SMTP_USE_TLS: bool = Field( - description='whether to use TLS connection to smtp server', + description="whether to use TLS connection to smtp server", default=False, ) SMTP_OPPORTUNISTIC_TLS: bool = Field( - description='whether to use opportunistic TLS connection to smtp server', + description="whether to use opportunistic TLS connection to smtp server", default=False, ) @@ -371,22 +471,22 @@ class RagEtlConfig(BaseSettings): """ ETL_TYPE: str = Field( - description='RAG ETL type name, default to `dify`, available values are `dify` and `Unstructured`. ', - default='dify', + description="RAG ETL type name, default to `dify`, available values are `dify` and `Unstructured`. ", + default="dify", ) KEYWORD_DATA_SOURCE_TYPE: str = Field( - description='source type for keyword data, default to `database`, available values are `database` .', - default='database', + description="source type for keyword data, default to `database`, available values are `database` .", + default="database", ) UNSTRUCTURED_API_URL: Optional[str] = Field( - description='API URL for Unstructured', + description="API URL for Unstructured", default=None, ) UNSTRUCTURED_API_KEY: Optional[str] = Field( - description='API key for Unstructured', + description="API key for Unstructured", default=None, ) @@ -397,22 +497,23 @@ class DataSetConfig(BaseSettings): """ CLEAN_DAY_SETTING: PositiveInt = Field( - description='interval in days for cleaning up dataset', + description="interval in days for cleaning up dataset", default=30, ) DATASET_OPERATOR_ENABLED: bool = Field( - description='whether to enable dataset operator', + description="whether to enable dataset operator", default=False, ) + class WorkspaceConfig(BaseSettings): """ Workspace configs """ INVITE_EXPIRY_HOURS: PositiveInt = Field( - description='workspaces invitation expiration in hours', + description="workspaces invitation expiration in hours", default=72, ) @@ -423,25 +524,81 @@ class IndexingConfig(BaseSettings): """ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: PositiveInt = Field( - description='max segmentation token length for indexing', + description="max segmentation token length for indexing", default=1000, ) class ImageFormatConfig(BaseSettings): MULTIMODAL_SEND_IMAGE_FORMAT: str = Field( - description='multi model send image format, support base64, url, default is base64', - default='base64', + description="multi model send image format, support base64, url, default is base64", + default="base64", ) class CeleryBeatConfig(BaseSettings): CELERY_BEAT_SCHEDULER_TIME: int = Field( - description='the time of the celery scheduler, default to 1 day', + description="the time of the celery scheduler, default to 1 day", default=1, ) +class PositionConfig(BaseSettings): + POSITION_PROVIDER_PINS: str = Field( + description="The heads of model providers", + default="", + ) + + POSITION_PROVIDER_INCLUDES: str = Field( + description="The included model providers", + default="", + ) + + POSITION_PROVIDER_EXCLUDES: str = Field( + description="The excluded model providers", + default="", + ) + + POSITION_TOOL_PINS: str = Field( + description="The heads of tools", + default="", + ) + + POSITION_TOOL_INCLUDES: str = Field( + description="The included tools", + default="", + ) + + POSITION_TOOL_EXCLUDES: str = Field( + description="The excluded tools", + default="", + ) + + @computed_field + def POSITION_PROVIDER_PINS_LIST(self) -> list[str]: + return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(",") if item.strip() != ""] + + @computed_field + def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]: + return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(",") if item.strip() != ""} + + @computed_field + def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]: + return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(",") if item.strip() != ""} + + @computed_field + def POSITION_TOOL_PINS_LIST(self) -> list[str]: + return [item.strip() for item in self.POSITION_TOOL_PINS.split(",") if item.strip() != ""] + + @computed_field + def POSITION_TOOL_INCLUDES_SET(self) -> set[str]: + return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(",") if item.strip() != ""} + + @computed_field + def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]: + return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""} + + class FeatureConfig( # place the configs in alphabet order AppExecutionConfig, @@ -466,7 +623,7 @@ class FeatureConfig( UpdateConfig, WorkflowConfig, WorkspaceConfig, - + PositionConfig, # hosted services config HostedServiceConfig, CeleryBeatConfig, diff --git a/api/configs/feature/hosted_service/__init__.py b/api/configs/feature/hosted_service/__init__.py index 88fe188587e7a1..f269d0ab9c89cd 100644 --- a/api/configs/feature/hosted_service/__init__.py +++ b/api/configs/feature/hosted_service/__init__.py @@ -10,62 +10,62 @@ class HostedOpenAiConfig(BaseSettings): """ HOSTED_OPENAI_API_KEY: Optional[str] = Field( - description='', + description="", default=None, ) HOSTED_OPENAI_API_BASE: Optional[str] = Field( - description='', + description="", default=None, ) HOSTED_OPENAI_API_ORGANIZATION: Optional[str] = Field( - description='', + description="", default=None, ) HOSTED_OPENAI_TRIAL_ENABLED: bool = Field( - description='', + description="", default=False, ) HOSTED_OPENAI_TRIAL_MODELS: str = Field( - description='', - default='gpt-3.5-turbo,' - 'gpt-3.5-turbo-1106,' - 'gpt-3.5-turbo-instruct,' - 'gpt-3.5-turbo-16k,' - 'gpt-3.5-turbo-16k-0613,' - 'gpt-3.5-turbo-0613,' - 'gpt-3.5-turbo-0125,' - 'text-davinci-003', + description="", + default="gpt-3.5-turbo," + "gpt-3.5-turbo-1106," + "gpt-3.5-turbo-instruct," + "gpt-3.5-turbo-16k," + "gpt-3.5-turbo-16k-0613," + "gpt-3.5-turbo-0613," + "gpt-3.5-turbo-0125," + "text-davinci-003", ) HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field( - description='', + description="", default=200, ) HOSTED_OPENAI_PAID_ENABLED: bool = Field( - description='', + description="", default=False, ) HOSTED_OPENAI_PAID_MODELS: str = Field( - description='', - default='gpt-4,' - 'gpt-4-turbo-preview,' - 'gpt-4-turbo-2024-04-09,' - 'gpt-4-1106-preview,' - 'gpt-4-0125-preview,' - 'gpt-3.5-turbo,' - 'gpt-3.5-turbo-16k,' - 'gpt-3.5-turbo-16k-0613,' - 'gpt-3.5-turbo-1106,' - 'gpt-3.5-turbo-0613,' - 'gpt-3.5-turbo-0125,' - 'gpt-3.5-turbo-instruct,' - 'text-davinci-003', + description="", + default="gpt-4," + "gpt-4-turbo-preview," + "gpt-4-turbo-2024-04-09," + "gpt-4-1106-preview," + "gpt-4-0125-preview," + "gpt-3.5-turbo," + "gpt-3.5-turbo-16k," + "gpt-3.5-turbo-16k-0613," + "gpt-3.5-turbo-1106," + "gpt-3.5-turbo-0613," + "gpt-3.5-turbo-0125," + "gpt-3.5-turbo-instruct," + "text-davinci-003", ) @@ -75,22 +75,22 @@ class HostedAzureOpenAiConfig(BaseSettings): """ HOSTED_AZURE_OPENAI_ENABLED: bool = Field( - description='', + description="", default=False, ) HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field( - description='', + description="", default=None, ) HOSTED_AZURE_OPENAI_API_BASE: Optional[str] = Field( - description='', + description="", default=None, ) HOSTED_AZURE_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field( - description='', + description="", default=200, ) @@ -101,27 +101,27 @@ class HostedAnthropicConfig(BaseSettings): """ HOSTED_ANTHROPIC_API_BASE: Optional[str] = Field( - description='', + description="", default=None, ) HOSTED_ANTHROPIC_API_KEY: Optional[str] = Field( - description='', + description="", default=None, ) HOSTED_ANTHROPIC_TRIAL_ENABLED: bool = Field( - description='', + description="", default=False, ) HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field( - description='', + description="", default=600000, ) HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field( - description='', + description="", default=False, ) @@ -132,7 +132,7 @@ class HostedMinmaxConfig(BaseSettings): """ HOSTED_MINIMAX_ENABLED: bool = Field( - description='', + description="", default=False, ) @@ -143,7 +143,7 @@ class HostedSparkConfig(BaseSettings): """ HOSTED_SPARK_ENABLED: bool = Field( - description='', + description="", default=False, ) @@ -154,7 +154,7 @@ class HostedZhipuAIConfig(BaseSettings): """ HOSTED_ZHIPUAI_ENABLED: bool = Field( - description='', + description="", default=False, ) @@ -165,13 +165,13 @@ class HostedModerationConfig(BaseSettings): """ HOSTED_MODERATION_ENABLED: bool = Field( - description='', + description="", default=False, ) HOSTED_MODERATION_PROVIDERS: str = Field( - description='', - default='', + description="", + default="", ) @@ -181,15 +181,15 @@ class HostedFetchAppTemplateConfig(BaseSettings): """ HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field( - description='the mode for fetching app templates,' - ' default to remote,' - ' available values: remote, db, builtin', - default='remote', + description="the mode for fetching app templates," + " default to remote," + " available values: remote, db, builtin", + default="remote", ) HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN: str = Field( - description='the domain for fetching remote app templates', - default='https://tmpl.dify.ai', + description="the domain for fetching remote app templates", + default="https://tmpl.dify.ai", ) @@ -202,7 +202,6 @@ class HostedServiceConfig( HostedOpenAiConfig, HostedSparkConfig, HostedZhipuAIConfig, - # moderation HostedModerationConfig, ): diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 07688e9aebfdc4..f25979e5d8f775 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -13,6 +13,7 @@ from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig from configs.middleware.vdb.chroma_config import ChromaConfig +from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig from configs.middleware.vdb.milvus_config import MilvusConfig from configs.middleware.vdb.myscale_config import MyScaleConfig from configs.middleware.vdb.opensearch_config import OpenSearchConfig @@ -28,108 +29,108 @@ class StorageConfig(BaseSettings): STORAGE_TYPE: str = Field( - description='storage type,' - ' default to `local`,' - ' available values are `local`, `s3`, `azure-blob`, `aliyun-oss`, `google-storage`.', - default='local', + description="storage type," + " default to `local`," + " available values are `local`, `s3`, `azure-blob`, `aliyun-oss`, `google-storage`.", + default="local", ) STORAGE_LOCAL_PATH: str = Field( - description='local storage path', - default='storage', + description="local storage path", + default="storage", ) class VectorStoreConfig(BaseSettings): VECTOR_STORE: Optional[str] = Field( - description='vector store type', + description="vector store type", default=None, ) class KeywordStoreConfig(BaseSettings): KEYWORD_STORE: str = Field( - description='keyword store type', - default='jieba', + description="keyword store type", + default="jieba", ) class DatabaseConfig: DB_HOST: str = Field( - description='db host', - default='localhost', + description="db host", + default="localhost", ) DB_PORT: PositiveInt = Field( - description='db port', + description="db port", default=5432, ) DB_USERNAME: str = Field( - description='db username', - default='postgres', + description="db username", + default="postgres", ) DB_PASSWORD: str = Field( - description='db password', - default='', + description="db password", + default="", ) DB_DATABASE: str = Field( - description='db database', - default='dify', + description="db database", + default="dify", ) DB_CHARSET: str = Field( - description='db charset', - default='', + description="db charset", + default="", ) DB_EXTRAS: str = Field( - description='db extras options. Example: keepalives_idle=60&keepalives=1', - default='', + description="db extras options. Example: keepalives_idle=60&keepalives=1", + default="", ) SQLALCHEMY_DATABASE_URI_SCHEME: str = Field( - description='db uri scheme', - default='postgresql', + description="db uri scheme", + default="postgresql", ) @computed_field @property def SQLALCHEMY_DATABASE_URI(self) -> str: db_extras = ( - f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" - if self.DB_CHARSET - else self.DB_EXTRAS + f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else self.DB_EXTRAS ).strip("&") db_extras = f"?{db_extras}" if db_extras else "" - return (f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://" - f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}" - f"{db_extras}") + return ( + f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://" + f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}" + f"{db_extras}" + ) SQLALCHEMY_POOL_SIZE: NonNegativeInt = Field( - description='pool size of SqlAlchemy', + description="pool size of SqlAlchemy", default=30, ) SQLALCHEMY_MAX_OVERFLOW: NonNegativeInt = Field( - description='max overflows for SqlAlchemy', + description="max overflows for SqlAlchemy", default=10, ) SQLALCHEMY_POOL_RECYCLE: NonNegativeInt = Field( - description='SqlAlchemy pool recycle', + description="SqlAlchemy pool recycle", default=3600, ) SQLALCHEMY_POOL_PRE_PING: bool = Field( - description='whether to enable pool pre-ping in SqlAlchemy', + description="whether to enable pool pre-ping in SqlAlchemy", default=False, ) SQLALCHEMY_ECHO: bool | str = Field( - description='whether to enable SqlAlchemy echo', + description="whether to enable SqlAlchemy echo", default=False, ) @@ -137,35 +138,38 @@ def SQLALCHEMY_DATABASE_URI(self) -> str: @property def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]: return { - 'pool_size': self.SQLALCHEMY_POOL_SIZE, - 'max_overflow': self.SQLALCHEMY_MAX_OVERFLOW, - 'pool_recycle': self.SQLALCHEMY_POOL_RECYCLE, - 'pool_pre_ping': self.SQLALCHEMY_POOL_PRE_PING, - 'connect_args': {'options': '-c timezone=UTC'}, + "pool_size": self.SQLALCHEMY_POOL_SIZE, + "max_overflow": self.SQLALCHEMY_MAX_OVERFLOW, + "pool_recycle": self.SQLALCHEMY_POOL_RECYCLE, + "pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING, + "connect_args": {"options": "-c timezone=UTC"}, } class CeleryConfig(DatabaseConfig): CELERY_BACKEND: str = Field( - description='Celery backend, available values are `database`, `redis`', - default='database', + description="Celery backend, available values are `database`, `redis`", + default="database", ) CELERY_BROKER_URL: Optional[str] = Field( - description='CELERY_BROKER_URL', + description="CELERY_BROKER_URL", default=None, ) @computed_field @property def CELERY_RESULT_BACKEND(self) -> str | None: - return 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \ - if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL + return ( + "db+{}".format(self.SQLALCHEMY_DATABASE_URI) + if self.CELERY_BACKEND == "database" + else self.CELERY_BROKER_URL + ) @computed_field @property def BROKER_USE_SSL(self) -> bool: - return self.CELERY_BROKER_URL.startswith('rediss://') if self.CELERY_BROKER_URL else False + return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False class MiddlewareConfig( @@ -174,7 +178,6 @@ class MiddlewareConfig( DatabaseConfig, KeywordStoreConfig, RedisConfig, - # configs of storage and storage providers StorageConfig, AliyunOSSStorageConfig, @@ -183,7 +186,6 @@ class MiddlewareConfig( TencentCloudCOSStorageConfig, S3StorageConfig, OCIStorageConfig, - # configs of vdb and vdb providers VectorStoreConfig, AnalyticdbConfig, @@ -199,5 +201,6 @@ class MiddlewareConfig( TencentVectorDBConfig, TiDBVectorConfig, WeaviateConfig, + ElasticsearchConfig, ): pass diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py index 436ba5d4c01f5c..cacdaf6fb6df2a 100644 --- a/api/configs/middleware/cache/redis_config.py +++ b/api/configs/middleware/cache/redis_config.py @@ -8,32 +8,33 @@ class RedisConfig(BaseSettings): """ Redis configs """ + REDIS_HOST: str = Field( - description='Redis host', - default='localhost', + description="Redis host", + default="localhost", ) REDIS_PORT: PositiveInt = Field( - description='Redis port', + description="Redis port", default=6379, ) REDIS_USERNAME: Optional[str] = Field( - description='Redis username', + description="Redis username", default=None, ) REDIS_PASSWORD: Optional[str] = Field( - description='Redis password', + description="Redis password", default=None, ) REDIS_DB: NonNegativeInt = Field( - description='Redis database id, default to 0', + description="Redis database id, default to 0", default=0, ) REDIS_USE_SSL: bool = Field( - description='whether to use SSL for Redis connection', + description="whether to use SSL for Redis connection", default=False, ) diff --git a/api/configs/middleware/storage/aliyun_oss_storage_config.py b/api/configs/middleware/storage/aliyun_oss_storage_config.py index 19e6cafb1282b1..c1843dc26cee00 100644 --- a/api/configs/middleware/storage/aliyun_oss_storage_config.py +++ b/api/configs/middleware/storage/aliyun_oss_storage_config.py @@ -10,31 +10,36 @@ class AliyunOSSStorageConfig(BaseSettings): """ ALIYUN_OSS_BUCKET_NAME: Optional[str] = Field( - description='Aliyun OSS bucket name', + description="Aliyun OSS bucket name", default=None, ) ALIYUN_OSS_ACCESS_KEY: Optional[str] = Field( - description='Aliyun OSS access key', + description="Aliyun OSS access key", default=None, ) ALIYUN_OSS_SECRET_KEY: Optional[str] = Field( - description='Aliyun OSS secret key', + description="Aliyun OSS secret key", default=None, ) ALIYUN_OSS_ENDPOINT: Optional[str] = Field( - description='Aliyun OSS endpoint URL', + description="Aliyun OSS endpoint URL", default=None, ) ALIYUN_OSS_REGION: Optional[str] = Field( - description='Aliyun OSS region', + description="Aliyun OSS region", default=None, ) ALIYUN_OSS_AUTH_VERSION: Optional[str] = Field( - description='Aliyun OSS authentication version', + description="Aliyun OSS authentication version", + default=None, + ) + + ALIYUN_OSS_PATH: Optional[str] = Field( + description="Aliyun OSS path", default=None, ) diff --git a/api/configs/middleware/storage/amazon_s3_storage_config.py b/api/configs/middleware/storage/amazon_s3_storage_config.py index 2566fbd5da6b96..bef93261087092 100644 --- a/api/configs/middleware/storage/amazon_s3_storage_config.py +++ b/api/configs/middleware/storage/amazon_s3_storage_config.py @@ -10,36 +10,36 @@ class S3StorageConfig(BaseSettings): """ S3_ENDPOINT: Optional[str] = Field( - description='S3 storage endpoint', + description="S3 storage endpoint", default=None, ) S3_REGION: Optional[str] = Field( - description='S3 storage region', + description="S3 storage region", default=None, ) S3_BUCKET_NAME: Optional[str] = Field( - description='S3 storage bucket name', + description="S3 storage bucket name", default=None, ) S3_ACCESS_KEY: Optional[str] = Field( - description='S3 storage access key', + description="S3 storage access key", default=None, ) S3_SECRET_KEY: Optional[str] = Field( - description='S3 storage secret key', + description="S3 storage secret key", default=None, ) S3_ADDRESS_STYLE: str = Field( - description='S3 storage address style', - default='auto', + description="S3 storage address style", + default="auto", ) S3_USE_AWS_MANAGED_IAM: bool = Field( - description='whether to use aws managed IAM for S3', + description="whether to use aws managed IAM for S3", default=False, ) diff --git a/api/configs/middleware/storage/azure_blob_storage_config.py b/api/configs/middleware/storage/azure_blob_storage_config.py index 26e441c89bd4e4..10944b58eddc61 100644 --- a/api/configs/middleware/storage/azure_blob_storage_config.py +++ b/api/configs/middleware/storage/azure_blob_storage_config.py @@ -10,21 +10,21 @@ class AzureBlobStorageConfig(BaseSettings): """ AZURE_BLOB_ACCOUNT_NAME: Optional[str] = Field( - description='Azure Blob account name', + description="Azure Blob account name", default=None, ) AZURE_BLOB_ACCOUNT_KEY: Optional[str] = Field( - description='Azure Blob account key', + description="Azure Blob account key", default=None, ) AZURE_BLOB_CONTAINER_NAME: Optional[str] = Field( - description='Azure Blob container name', + description="Azure Blob container name", default=None, ) AZURE_BLOB_ACCOUNT_URL: Optional[str] = Field( - description='Azure Blob account URL', + description="Azure Blob account URL", default=None, ) diff --git a/api/configs/middleware/storage/google_cloud_storage_config.py b/api/configs/middleware/storage/google_cloud_storage_config.py index e1b0e34e0c32fd..10a2d97e8dbcdf 100644 --- a/api/configs/middleware/storage/google_cloud_storage_config.py +++ b/api/configs/middleware/storage/google_cloud_storage_config.py @@ -10,11 +10,11 @@ class GoogleCloudStorageConfig(BaseSettings): """ GOOGLE_STORAGE_BUCKET_NAME: Optional[str] = Field( - description='Google Cloud storage bucket name', + description="Google Cloud storage bucket name", default=None, ) GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: Optional[str] = Field( - description='Google Cloud storage service account json base64', + description="Google Cloud storage service account json base64", default=None, ) diff --git a/api/configs/middleware/storage/oci_storage_config.py b/api/configs/middleware/storage/oci_storage_config.py index 6c0c06746954f0..f8993496c952b8 100644 --- a/api/configs/middleware/storage/oci_storage_config.py +++ b/api/configs/middleware/storage/oci_storage_config.py @@ -10,27 +10,26 @@ class OCIStorageConfig(BaseSettings): """ OCI_ENDPOINT: Optional[str] = Field( - description='OCI storage endpoint', + description="OCI storage endpoint", default=None, ) OCI_REGION: Optional[str] = Field( - description='OCI storage region', + description="OCI storage region", default=None, ) OCI_BUCKET_NAME: Optional[str] = Field( - description='OCI storage bucket name', + description="OCI storage bucket name", default=None, ) OCI_ACCESS_KEY: Optional[str] = Field( - description='OCI storage access key', + description="OCI storage access key", default=None, ) OCI_SECRET_KEY: Optional[str] = Field( - description='OCI storage secret key', + description="OCI storage secret key", default=None, ) - diff --git a/api/configs/middleware/storage/tencent_cos_storage_config.py b/api/configs/middleware/storage/tencent_cos_storage_config.py index 1060c7b93e0bf1..765ac08f3e4047 100644 --- a/api/configs/middleware/storage/tencent_cos_storage_config.py +++ b/api/configs/middleware/storage/tencent_cos_storage_config.py @@ -10,26 +10,26 @@ class TencentCloudCOSStorageConfig(BaseSettings): """ TENCENT_COS_BUCKET_NAME: Optional[str] = Field( - description='Tencent Cloud COS bucket name', + description="Tencent Cloud COS bucket name", default=None, ) TENCENT_COS_REGION: Optional[str] = Field( - description='Tencent Cloud COS region', + description="Tencent Cloud COS region", default=None, ) TENCENT_COS_SECRET_ID: Optional[str] = Field( - description='Tencent Cloud COS secret id', + description="Tencent Cloud COS secret id", default=None, ) TENCENT_COS_SECRET_KEY: Optional[str] = Field( - description='Tencent Cloud COS secret key', + description="Tencent Cloud COS secret key", default=None, ) TENCENT_COS_SCHEME: Optional[str] = Field( - description='Tencent Cloud COS scheme', + description="Tencent Cloud COS scheme", default=None, ) diff --git a/api/configs/middleware/vdb/analyticdb_config.py b/api/configs/middleware/vdb/analyticdb_config.py index db2899265e204f..04f5b0e5bf7697 100644 --- a/api/configs/middleware/vdb/analyticdb_config.py +++ b/api/configs/middleware/vdb/analyticdb_config.py @@ -10,35 +10,28 @@ class AnalyticdbConfig(BaseModel): https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled """ - ANALYTICDB_KEY_ID : Optional[str] = Field( - default=None, - description="The Access Key ID provided by Alibaba Cloud for authentication." + ANALYTICDB_KEY_ID: Optional[str] = Field( + default=None, description="The Access Key ID provided by Alibaba Cloud for authentication." ) - ANALYTICDB_KEY_SECRET : Optional[str] = Field( - default=None, - description="The Secret Access Key corresponding to the Access Key ID for secure access." + ANALYTICDB_KEY_SECRET: Optional[str] = Field( + default=None, description="The Secret Access Key corresponding to the Access Key ID for secure access." ) - ANALYTICDB_REGION_ID : Optional[str] = Field( - default=None, - description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')." + ANALYTICDB_REGION_ID: Optional[str] = Field( + default=None, description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')." ) - ANALYTICDB_INSTANCE_ID : Optional[str] = Field( + ANALYTICDB_INSTANCE_ID: Optional[str] = Field( default=None, - description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456').." + description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456')..", ) - ANALYTICDB_ACCOUNT : Optional[str] = Field( - default=None, - description="The account name used to log in to the AnalyticDB instance." + ANALYTICDB_ACCOUNT: Optional[str] = Field( + default=None, description="The account name used to log in to the AnalyticDB instance." ) - ANALYTICDB_PASSWORD : Optional[str] = Field( - default=None, - description="The password associated with the AnalyticDB account for authentication." + ANALYTICDB_PASSWORD: Optional[str] = Field( + default=None, description="The password associated with the AnalyticDB account for authentication." ) - ANALYTICDB_NAMESPACE : Optional[str] = Field( - default=None, - description="The namespace within AnalyticDB for schema isolation." + ANALYTICDB_NAMESPACE: Optional[str] = Field( + default=None, description="The namespace within AnalyticDB for schema isolation." ) - ANALYTICDB_NAMESPACE_PASSWORD : Optional[str] = Field( - default=None, - description="The password for accessing the specified namespace within the AnalyticDB instance." + ANALYTICDB_NAMESPACE_PASSWORD: Optional[str] = Field( + default=None, description="The password for accessing the specified namespace within the AnalyticDB instance." ) diff --git a/api/configs/middleware/vdb/chroma_config.py b/api/configs/middleware/vdb/chroma_config.py index f365879efb1a19..d386623a569d38 100644 --- a/api/configs/middleware/vdb/chroma_config.py +++ b/api/configs/middleware/vdb/chroma_config.py @@ -10,31 +10,31 @@ class ChromaConfig(BaseSettings): """ CHROMA_HOST: Optional[str] = Field( - description='Chroma host', + description="Chroma host", default=None, ) CHROMA_PORT: PositiveInt = Field( - description='Chroma port', + description="Chroma port", default=8000, ) CHROMA_TENANT: Optional[str] = Field( - description='Chroma database', + description="Chroma database", default=None, ) CHROMA_DATABASE: Optional[str] = Field( - description='Chroma database', + description="Chroma database", default=None, ) CHROMA_AUTH_PROVIDER: Optional[str] = Field( - description='Chroma authentication provider', + description="Chroma authentication provider", default=None, ) CHROMA_AUTH_CREDENTIALS: Optional[str] = Field( - description='Chroma authentication credentials', + description="Chroma authentication credentials", default=None, ) diff --git a/api/configs/middleware/vdb/elasticsearch_config.py b/api/configs/middleware/vdb/elasticsearch_config.py new file mode 100644 index 00000000000000..5b6a8fd939c292 --- /dev/null +++ b/api/configs/middleware/vdb/elasticsearch_config.py @@ -0,0 +1,30 @@ +from typing import Optional + +from pydantic import Field, PositiveInt +from pydantic_settings import BaseSettings + + +class ElasticsearchConfig(BaseSettings): + """ + Elasticsearch configs + """ + + ELASTICSEARCH_HOST: Optional[str] = Field( + description="Elasticsearch host", + default="127.0.0.1", + ) + + ELASTICSEARCH_PORT: PositiveInt = Field( + description="Elasticsearch port", + default=9200, + ) + + ELASTICSEARCH_USERNAME: Optional[str] = Field( + description="Elasticsearch username", + default="elastic", + ) + + ELASTICSEARCH_PASSWORD: Optional[str] = Field( + description="Elasticsearch password", + default="elastic", + ) diff --git a/api/configs/middleware/vdb/milvus_config.py b/api/configs/middleware/vdb/milvus_config.py index 01502d45901764..85466cd5cc0dab 100644 --- a/api/configs/middleware/vdb/milvus_config.py +++ b/api/configs/middleware/vdb/milvus_config.py @@ -10,31 +10,31 @@ class MilvusConfig(BaseSettings): """ MILVUS_HOST: Optional[str] = Field( - description='Milvus host', + description="Milvus host", default=None, ) MILVUS_PORT: PositiveInt = Field( - description='Milvus RestFul API port', + description="Milvus RestFul API port", default=9091, ) MILVUS_USER: Optional[str] = Field( - description='Milvus user', + description="Milvus user", default=None, ) MILVUS_PASSWORD: Optional[str] = Field( - description='Milvus password', + description="Milvus password", default=None, ) MILVUS_SECURE: bool = Field( - description='whether to use SSL connection for Milvus', + description="whether to use SSL connection for Milvus", default=False, ) MILVUS_DATABASE: str = Field( - description='Milvus database, default to `default`', - default='default', + description="Milvus database, default to `default`", + default="default", ) diff --git a/api/configs/middleware/vdb/myscale_config.py b/api/configs/middleware/vdb/myscale_config.py index 895cd6f1769d6e..6451d26e1c65da 100644 --- a/api/configs/middleware/vdb/myscale_config.py +++ b/api/configs/middleware/vdb/myscale_config.py @@ -1,4 +1,3 @@ - from pydantic import BaseModel, Field, PositiveInt @@ -8,31 +7,31 @@ class MyScaleConfig(BaseModel): """ MYSCALE_HOST: str = Field( - description='MyScale host', - default='localhost', + description="MyScale host", + default="localhost", ) MYSCALE_PORT: PositiveInt = Field( - description='MyScale port', + description="MyScale port", default=8123, ) MYSCALE_USER: str = Field( - description='MyScale user', - default='default', + description="MyScale user", + default="default", ) MYSCALE_PASSWORD: str = Field( - description='MyScale password', - default='', + description="MyScale password", + default="", ) MYSCALE_DATABASE: str = Field( - description='MyScale database name', - default='default', + description="MyScale database name", + default="default", ) MYSCALE_FTS_PARAMS: str = Field( - description='MyScale fts index parameters', - default='', + description="MyScale fts index parameters", + default="", ) diff --git a/api/configs/middleware/vdb/opensearch_config.py b/api/configs/middleware/vdb/opensearch_config.py index 15d6f5b6a97126..5823dc1433d83c 100644 --- a/api/configs/middleware/vdb/opensearch_config.py +++ b/api/configs/middleware/vdb/opensearch_config.py @@ -10,26 +10,26 @@ class OpenSearchConfig(BaseSettings): """ OPENSEARCH_HOST: Optional[str] = Field( - description='OpenSearch host', + description="OpenSearch host", default=None, ) OPENSEARCH_PORT: PositiveInt = Field( - description='OpenSearch port', + description="OpenSearch port", default=9200, ) OPENSEARCH_USER: Optional[str] = Field( - description='OpenSearch user', + description="OpenSearch user", default=None, ) OPENSEARCH_PASSWORD: Optional[str] = Field( - description='OpenSearch password', + description="OpenSearch password", default=None, ) OPENSEARCH_SECURE: bool = Field( - description='whether to use SSL connection for OpenSearch', + description="whether to use SSL connection for OpenSearch", default=False, ) diff --git a/api/configs/middleware/vdb/oracle_config.py b/api/configs/middleware/vdb/oracle_config.py index 888fc19492d67e..62614ae870cc89 100644 --- a/api/configs/middleware/vdb/oracle_config.py +++ b/api/configs/middleware/vdb/oracle_config.py @@ -10,26 +10,26 @@ class OracleConfig(BaseSettings): """ ORACLE_HOST: Optional[str] = Field( - description='ORACLE host', + description="ORACLE host", default=None, ) ORACLE_PORT: Optional[PositiveInt] = Field( - description='ORACLE port', + description="ORACLE port", default=1521, ) ORACLE_USER: Optional[str] = Field( - description='ORACLE user', + description="ORACLE user", default=None, ) ORACLE_PASSWORD: Optional[str] = Field( - description='ORACLE password', + description="ORACLE password", default=None, ) ORACLE_DATABASE: Optional[str] = Field( - description='ORACLE database', + description="ORACLE database", default=None, ) diff --git a/api/configs/middleware/vdb/pgvector_config.py b/api/configs/middleware/vdb/pgvector_config.py index 8a677f60a3a851..39a7c1d8d5e4ad 100644 --- a/api/configs/middleware/vdb/pgvector_config.py +++ b/api/configs/middleware/vdb/pgvector_config.py @@ -10,26 +10,26 @@ class PGVectorConfig(BaseSettings): """ PGVECTOR_HOST: Optional[str] = Field( - description='PGVector host', + description="PGVector host", default=None, ) PGVECTOR_PORT: Optional[PositiveInt] = Field( - description='PGVector port', + description="PGVector port", default=5433, ) PGVECTOR_USER: Optional[str] = Field( - description='PGVector user', + description="PGVector user", default=None, ) PGVECTOR_PASSWORD: Optional[str] = Field( - description='PGVector password', + description="PGVector password", default=None, ) PGVECTOR_DATABASE: Optional[str] = Field( - description='PGVector database', + description="PGVector database", default=None, ) diff --git a/api/configs/middleware/vdb/pgvectors_config.py b/api/configs/middleware/vdb/pgvectors_config.py index 39f52f22ff6c95..c40e5ff92103fd 100644 --- a/api/configs/middleware/vdb/pgvectors_config.py +++ b/api/configs/middleware/vdb/pgvectors_config.py @@ -10,26 +10,26 @@ class PGVectoRSConfig(BaseSettings): """ PGVECTO_RS_HOST: Optional[str] = Field( - description='PGVectoRS host', + description="PGVectoRS host", default=None, ) PGVECTO_RS_PORT: Optional[PositiveInt] = Field( - description='PGVectoRS port', + description="PGVectoRS port", default=5431, ) PGVECTO_RS_USER: Optional[str] = Field( - description='PGVectoRS user', + description="PGVectoRS user", default=None, ) PGVECTO_RS_PASSWORD: Optional[str] = Field( - description='PGVectoRS password', + description="PGVectoRS password", default=None, ) PGVECTO_RS_DATABASE: Optional[str] = Field( - description='PGVectoRS database', + description="PGVectoRS database", default=None, ) diff --git a/api/configs/middleware/vdb/qdrant_config.py b/api/configs/middleware/vdb/qdrant_config.py index c85bf9c7dc6047..27f75491c98767 100644 --- a/api/configs/middleware/vdb/qdrant_config.py +++ b/api/configs/middleware/vdb/qdrant_config.py @@ -10,26 +10,26 @@ class QdrantConfig(BaseSettings): """ QDRANT_URL: Optional[str] = Field( - description='Qdrant url', + description="Qdrant url", default=None, ) QDRANT_API_KEY: Optional[str] = Field( - description='Qdrant api key', + description="Qdrant api key", default=None, ) QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field( - description='Qdrant client timeout in seconds', + description="Qdrant client timeout in seconds", default=20, ) QDRANT_GRPC_ENABLED: bool = Field( - description='whether enable grpc support for Qdrant connection', + description="whether enable grpc support for Qdrant connection", default=False, ) QDRANT_GRPC_PORT: PositiveInt = Field( - description='Qdrant grpc port', + description="Qdrant grpc port", default=6334, ) diff --git a/api/configs/middleware/vdb/relyt_config.py b/api/configs/middleware/vdb/relyt_config.py index be93185f3ccab1..66b9ecc03f6aa6 100644 --- a/api/configs/middleware/vdb/relyt_config.py +++ b/api/configs/middleware/vdb/relyt_config.py @@ -10,26 +10,26 @@ class RelytConfig(BaseSettings): """ RELYT_HOST: Optional[str] = Field( - description='Relyt host', + description="Relyt host", default=None, ) RELYT_PORT: PositiveInt = Field( - description='Relyt port', + description="Relyt port", default=9200, ) RELYT_USER: Optional[str] = Field( - description='Relyt user', + description="Relyt user", default=None, ) RELYT_PASSWORD: Optional[str] = Field( - description='Relyt password', + description="Relyt password", default=None, ) RELYT_DATABASE: Optional[str] = Field( - description='Relyt database', - default='default', + description="Relyt database", + default="default", ) diff --git a/api/configs/middleware/vdb/tencent_vector_config.py b/api/configs/middleware/vdb/tencent_vector_config.py index 531ec840686eea..46b4cb6a24ff88 100644 --- a/api/configs/middleware/vdb/tencent_vector_config.py +++ b/api/configs/middleware/vdb/tencent_vector_config.py @@ -10,41 +10,41 @@ class TencentVectorDBConfig(BaseSettings): """ TENCENT_VECTOR_DB_URL: Optional[str] = Field( - description='Tencent Vector URL', + description="Tencent Vector URL", default=None, ) TENCENT_VECTOR_DB_API_KEY: Optional[str] = Field( - description='Tencent Vector API key', + description="Tencent Vector API key", default=None, ) TENCENT_VECTOR_DB_TIMEOUT: PositiveInt = Field( - description='Tencent Vector timeout in seconds', + description="Tencent Vector timeout in seconds", default=30, ) TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field( - description='Tencent Vector username', + description="Tencent Vector username", default=None, ) TENCENT_VECTOR_DB_PASSWORD: Optional[str] = Field( - description='Tencent Vector password', + description="Tencent Vector password", default=None, ) TENCENT_VECTOR_DB_SHARD: PositiveInt = Field( - description='Tencent Vector sharding number', + description="Tencent Vector sharding number", default=1, ) TENCENT_VECTOR_DB_REPLICAS: NonNegativeInt = Field( - description='Tencent Vector replicas', + description="Tencent Vector replicas", default=2, ) TENCENT_VECTOR_DB_DATABASE: Optional[str] = Field( - description='Tencent Vector Database', + description="Tencent Vector Database", default=None, ) diff --git a/api/configs/middleware/vdb/tidb_vector_config.py b/api/configs/middleware/vdb/tidb_vector_config.py index 8d459691a895bd..dbcb276c01a1f1 100644 --- a/api/configs/middleware/vdb/tidb_vector_config.py +++ b/api/configs/middleware/vdb/tidb_vector_config.py @@ -10,26 +10,26 @@ class TiDBVectorConfig(BaseSettings): """ TIDB_VECTOR_HOST: Optional[str] = Field( - description='TiDB Vector host', + description="TiDB Vector host", default=None, ) TIDB_VECTOR_PORT: Optional[PositiveInt] = Field( - description='TiDB Vector port', + description="TiDB Vector port", default=4000, ) TIDB_VECTOR_USER: Optional[str] = Field( - description='TiDB Vector user', + description="TiDB Vector user", default=None, ) TIDB_VECTOR_PASSWORD: Optional[str] = Field( - description='TiDB Vector password', + description="TiDB Vector password", default=None, ) TIDB_VECTOR_DATABASE: Optional[str] = Field( - description='TiDB Vector database', + description="TiDB Vector database", default=None, ) diff --git a/api/configs/middleware/vdb/weaviate_config.py b/api/configs/middleware/vdb/weaviate_config.py index b985ecea121f9e..63d1022f6a4516 100644 --- a/api/configs/middleware/vdb/weaviate_config.py +++ b/api/configs/middleware/vdb/weaviate_config.py @@ -10,21 +10,21 @@ class WeaviateConfig(BaseSettings): """ WEAVIATE_ENDPOINT: Optional[str] = Field( - description='Weaviate endpoint URL', + description="Weaviate endpoint URL", default=None, ) WEAVIATE_API_KEY: Optional[str] = Field( - description='Weaviate API key', + description="Weaviate API key", default=None, ) WEAVIATE_GRPC_ENABLED: bool = Field( - description='whether to enable gRPC for Weaviate connection', + description="whether to enable gRPC for Weaviate connection", default=True, ) WEAVIATE_BATCH_SIZE: PositiveInt = Field( - description='Weaviate batch size', + description="Weaviate batch size", default=100, ) diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py index 13c55ca4251118..dd096716120c69 100644 --- a/api/configs/packaging/__init__.py +++ b/api/configs/packaging/__init__.py @@ -8,11 +8,11 @@ class PackagingInfo(BaseSettings): """ CURRENT_VERSION: str = Field( - description='Dify version', - default='0.6.15', + description="Dify version", + default="0.7.2", ) COMMIT_SHA: str = Field( description="SHA-1 checksum of the git commit used to build the app", - default='', + default="", ) diff --git a/api/constants/__init__.py b/api/constants/__init__.py index 08a27869948c95..e22c3268ef428b 100644 --- a/api/constants/__init__.py +++ b/api/constants/__init__.py @@ -1,2 +1 @@ -# TODO: Update all string in code to use this constant -HIDDEN_VALUE = '[__HIDDEN__]' \ No newline at end of file +HIDDEN_VALUE = "[__HIDDEN__]" diff --git a/api/constants/languages.py b/api/constants/languages.py index 023d2f18a6f08a..524dc61b5790a4 100644 --- a/api/constants/languages.py +++ b/api/constants/languages.py @@ -1,21 +1,22 @@ language_timezone_mapping = { - 'en-US': 'America/New_York', - 'zh-Hans': 'Asia/Shanghai', - 'zh-Hant': 'Asia/Taipei', - 'pt-BR': 'America/Sao_Paulo', - 'es-ES': 'Europe/Madrid', - 'fr-FR': 'Europe/Paris', - 'de-DE': 'Europe/Berlin', - 'ja-JP': 'Asia/Tokyo', - 'ko-KR': 'Asia/Seoul', - 'ru-RU': 'Europe/Moscow', - 'it-IT': 'Europe/Rome', - 'uk-UA': 'Europe/Kyiv', - 'vi-VN': 'Asia/Ho_Chi_Minh', - 'ro-RO': 'Europe/Bucharest', - 'pl-PL': 'Europe/Warsaw', - 'hi-IN': 'Asia/Kolkata', - 'tr-TR': 'Europe/Istanbul', + "en-US": "America/New_York", + "zh-Hans": "Asia/Shanghai", + "zh-Hant": "Asia/Taipei", + "pt-BR": "America/Sao_Paulo", + "es-ES": "Europe/Madrid", + "fr-FR": "Europe/Paris", + "de-DE": "Europe/Berlin", + "ja-JP": "Asia/Tokyo", + "ko-KR": "Asia/Seoul", + "ru-RU": "Europe/Moscow", + "it-IT": "Europe/Rome", + "uk-UA": "Europe/Kyiv", + "vi-VN": "Asia/Ho_Chi_Minh", + "ro-RO": "Europe/Bucharest", + "pl-PL": "Europe/Warsaw", + "hi-IN": "Asia/Kolkata", + "tr-TR": "Europe/Istanbul", + "fa-IR": "Asia/Tehran", } languages = list(language_timezone_mapping.keys()) @@ -25,6 +26,5 @@ def supported_language(lang): if lang in languages: return lang - error = ('{lang} is not a valid language.' - .format(lang=lang)) + error = "{lang} is not a valid language.".format(lang=lang) raise ValueError(error) diff --git a/api/constants/model_template.py b/api/constants/model_template.py index cc5a37025479fd..7e1a196356c4e2 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -5,82 +5,79 @@ default_app_templates = { # workflow default mode AppMode.WORKFLOW: { - 'app': { - 'mode': AppMode.WORKFLOW.value, - 'enable_site': True, - 'enable_api': True + "app": { + "mode": AppMode.WORKFLOW.value, + "enable_site": True, + "enable_api": True, } }, - # completion default mode AppMode.COMPLETION: { - 'app': { - 'mode': AppMode.COMPLETION.value, - 'enable_site': True, - 'enable_api': True + "app": { + "mode": AppMode.COMPLETION.value, + "enable_site": True, + "enable_api": True, }, - 'model_config': { - 'model': { + "model_config": { + "model": { "provider": "openai", "name": "gpt-4o", "mode": "chat", - "completion_params": {} + "completion_params": {}, }, - 'user_input_form': json.dumps([ - { - "paragraph": { - "label": "Query", - "variable": "query", - "required": True, - "default": "" - } - } - ]), - 'pre_prompt': '{{query}}' + "user_input_form": json.dumps( + [ + { + "paragraph": { + "label": "Query", + "variable": "query", + "required": True, + "default": "", + }, + }, + ] + ), + "pre_prompt": "{{query}}", }, - }, - # chat default mode AppMode.CHAT: { - 'app': { - 'mode': AppMode.CHAT.value, - 'enable_site': True, - 'enable_api': True + "app": { + "mode": AppMode.CHAT.value, + "enable_site": True, + "enable_api": True, }, - 'model_config': { - 'model': { + "model_config": { + "model": { "provider": "openai", "name": "gpt-4o", "mode": "chat", - "completion_params": {} - } - } + "completion_params": {}, + }, + }, }, - # advanced-chat default mode AppMode.ADVANCED_CHAT: { - 'app': { - 'mode': AppMode.ADVANCED_CHAT.value, - 'enable_site': True, - 'enable_api': True - } + "app": { + "mode": AppMode.ADVANCED_CHAT.value, + "enable_site": True, + "enable_api": True, + }, }, - # agent-chat default mode AppMode.AGENT_CHAT: { - 'app': { - 'mode': AppMode.AGENT_CHAT.value, - 'enable_site': True, - 'enable_api': True + "app": { + "mode": AppMode.AGENT_CHAT.value, + "enable_site": True, + "enable_api": True, }, - 'model_config': { - 'model': { + "model_config": { + "model": { "provider": "openai", "name": "gpt-4o", "mode": "chat", - "completion_params": {} - } - } - } + "completion_params": {}, + }, + }, + }, } diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index 306fac3a931298..623a1a28eb731e 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -1,3 +1,7 @@ from contextvars import ContextVar -tenant_id: ContextVar[str] = ContextVar('tenant_id') \ No newline at end of file +from core.workflow.entities.variable_pool import VariablePool + +tenant_id: ContextVar[str] = ContextVar("tenant_id") + +workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool") diff --git a/api/controllers/__init__.py b/api/controllers/__init__.py index b28b04f643122b..8b137891791fe9 100644 --- a/api/controllers/__init__.py +++ b/api/controllers/__init__.py @@ -1,3 +1 @@ - - diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index bef40bea7eb32e..eb7c1464d39722 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -2,7 +2,7 @@ from libs.external_api import ExternalApi -bp = Blueprint('console', __name__, url_prefix='/console/api') +bp = Blueprint("console", __name__, url_prefix="/console/api") api = ExternalApi(bp) # Import other controllers @@ -17,6 +17,7 @@ audio, completion, conversation, + conversation_variables, generator, message, model_config, diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 028be5de548b7d..a4ceec26620dc9 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -15,24 +15,24 @@ def admin_required(view): @wraps(view) def decorated(*args, **kwargs): - if not os.getenv('ADMIN_API_KEY'): - raise Unauthorized('API key is invalid.') + if not os.getenv("ADMIN_API_KEY"): + raise Unauthorized("API key is invalid.") - auth_header = request.headers.get('Authorization') + auth_header = request.headers.get("Authorization") if auth_header is None: - raise Unauthorized('Authorization header is missing.') + raise Unauthorized("Authorization header is missing.") - if ' ' not in auth_header: - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != 'bearer': - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - if os.getenv('ADMIN_API_KEY') != auth_token: - raise Unauthorized('API key is invalid.') + if os.getenv("ADMIN_API_KEY") != auth_token: + raise Unauthorized("API key is invalid.") return view(*args, **kwargs) @@ -44,37 +44,41 @@ class InsertExploreAppListApi(Resource): @admin_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('app_id', type=str, required=True, nullable=False, location='json') - parser.add_argument('desc', type=str, location='json') - parser.add_argument('copyright', type=str, location='json') - parser.add_argument('privacy_policy', type=str, location='json') - parser.add_argument('custom_disclaimer', type=str, location='json') - parser.add_argument('language', type=supported_language, required=True, nullable=False, location='json') - parser.add_argument('category', type=str, required=True, nullable=False, location='json') - parser.add_argument('position', type=int, required=True, nullable=False, location='json') + parser.add_argument("app_id", type=str, required=True, nullable=False, location="json") + parser.add_argument("desc", type=str, location="json") + parser.add_argument("copyright", type=str, location="json") + parser.add_argument("privacy_policy", type=str, location="json") + parser.add_argument("custom_disclaimer", type=str, location="json") + parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json") + parser.add_argument("category", type=str, required=True, nullable=False, location="json") + parser.add_argument("position", type=int, required=True, nullable=False, location="json") args = parser.parse_args() - app = App.query.filter(App.id == args['app_id']).first() + app = App.query.filter(App.id == args["app_id"]).first() if not app: raise NotFound(f'App \'{args["app_id"]}\' is not found') site = app.site if not site: - desc = args['desc'] if args['desc'] else '' - copy_right = args['copyright'] if args['copyright'] else '' - privacy_policy = args['privacy_policy'] if args['privacy_policy'] else '' - custom_disclaimer = args['custom_disclaimer'] if args['custom_disclaimer'] else '' + desc = args["desc"] if args["desc"] else "" + copy_right = args["copyright"] if args["copyright"] else "" + privacy_policy = args["privacy_policy"] if args["privacy_policy"] else "" + custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else "" else: - desc = site.description if site.description else \ - args['desc'] if args['desc'] else '' - copy_right = site.copyright if site.copyright else \ - args['copyright'] if args['copyright'] else '' - privacy_policy = site.privacy_policy if site.privacy_policy else \ - args['privacy_policy'] if args['privacy_policy'] else '' - custom_disclaimer = site.custom_disclaimer if site.custom_disclaimer else \ - args['custom_disclaimer'] if args['custom_disclaimer'] else '' + desc = site.description if site.description else args["desc"] if args["desc"] else "" + copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else "" + privacy_policy = ( + site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else "" + ) + custom_disclaimer = ( + site.custom_disclaimer + if site.custom_disclaimer + else args["custom_disclaimer"] + if args["custom_disclaimer"] + else "" + ) - recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first() + recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first() if not recommended_app: recommended_app = RecommendedApp( @@ -83,9 +87,9 @@ def post(self): copyright=copy_right, privacy_policy=privacy_policy, custom_disclaimer=custom_disclaimer, - language=args['language'], - category=args['category'], - position=args['position'] + language=args["language"], + category=args["category"], + position=args["position"], ) db.session.add(recommended_app) @@ -93,21 +97,21 @@ def post(self): app.is_public = True db.session.commit() - return {'result': 'success'}, 201 + return {"result": "success"}, 201 else: recommended_app.description = desc recommended_app.copyright = copy_right recommended_app.privacy_policy = privacy_policy recommended_app.custom_disclaimer = custom_disclaimer - recommended_app.language = args['language'] - recommended_app.category = args['category'] - recommended_app.position = args['position'] + recommended_app.language = args["language"] + recommended_app.category = args["category"] + recommended_app.position = args["position"] app.is_public = True db.session.commit() - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class InsertExploreAppApi(Resource): @@ -116,15 +120,14 @@ class InsertExploreAppApi(Resource): def delete(self, app_id): recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first() if not recommended_app: - return {'result': 'success'}, 204 + return {"result": "success"}, 204 app = App.query.filter(App.id == recommended_app.app_id).first() if app: app.is_public = False installed_apps = InstalledApp.query.filter( - InstalledApp.app_id == recommended_app.app_id, - InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id + InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id ).all() for installed_app in installed_apps: @@ -133,8 +136,8 @@ def delete(self, app_id): db.session.delete(recommended_app) db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 -api.add_resource(InsertExploreAppListApi, '/admin/insert-explore-apps') -api.add_resource(InsertExploreAppApi, '/admin/insert-explore-apps/') +api.add_resource(InsertExploreAppListApi, "/admin/insert-explore-apps") +api.add_resource(InsertExploreAppApi, "/admin/insert-explore-apps/") diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 324b8311752898..3f5e1adca23791 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -14,26 +14,21 @@ from .wraps import account_initialization_required api_key_fields = { - 'id': fields.String, - 'type': fields.String, - 'token': fields.String, - 'last_used_at': TimestampField, - 'created_at': TimestampField + "id": fields.String, + "type": fields.String, + "token": fields.String, + "last_used_at": TimestampField, + "created_at": TimestampField, } -api_key_list = { - 'data': fields.List(fields.Nested(api_key_fields), attribute="items") -} +api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")} def _get_resource(resource_id, tenant_id, resource_model): - resource = resource_model.query.filter_by( - id=resource_id, tenant_id=tenant_id - ).first() + resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first() if resource is None: - flask_restful.abort( - 404, message=f"{resource_model.__name__} not found.") + flask_restful.abort(404, message=f"{resource_model.__name__} not found.") return resource @@ -50,30 +45,32 @@ class BaseApiKeyListResource(Resource): @marshal_with(api_key_list) def get(self, resource_id): resource_id = str(resource_id) - _get_resource(resource_id, current_user.current_tenant_id, - self.resource_model) - keys = db.session.query(ApiToken). \ - filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \ - all() + _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) + keys = ( + db.session.query(ApiToken) + .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) + .all() + ) return {"items": keys} @marshal_with(api_key_fields) def post(self, resource_id): resource_id = str(resource_id) - _get_resource(resource_id, current_user.current_tenant_id, - self.resource_model) + _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) if not current_user.is_admin_or_owner: raise Forbidden() - current_key_count = db.session.query(ApiToken). \ - filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \ - count() + current_key_count = ( + db.session.query(ApiToken) + .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) + .count() + ) if current_key_count >= self.max_keys: flask_restful.abort( 400, message=f"Cannot create more than {self.max_keys} API keys for this resource type.", - code='max_keys_exceeded' + code="max_keys_exceeded", ) key = ApiToken.generate_api_key(self.token_prefix, 24) @@ -97,79 +94,78 @@ class BaseApiKeyResource(Resource): def delete(self, resource_id, api_key_id): resource_id = str(resource_id) api_key_id = str(api_key_id) - _get_resource(resource_id, current_user.current_tenant_id, - self.resource_model) + _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() - key = db.session.query(ApiToken). \ - filter(getattr(ApiToken, self.resource_id_field) == resource_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id). \ - first() + key = ( + db.session.query(ApiToken) + .filter( + getattr(ApiToken, self.resource_id_field) == resource_id, + ApiToken.type == self.resource_type, + ApiToken.id == api_key_id, + ) + .first() + ) if key is None: - flask_restful.abort(404, message='API key not found') + flask_restful.abort(404, message="API key not found") db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class AppApiKeyListResource(BaseApiKeyListResource): - def after_request(self, resp): - resp.headers['Access-Control-Allow-Origin'] = '*' - resp.headers['Access-Control-Allow-Credentials'] = 'true' + resp.headers["Access-Control-Allow-Origin"] = "*" + resp.headers["Access-Control-Allow-Credentials"] = "true" return resp - resource_type = 'app' + resource_type = "app" resource_model = App - resource_id_field = 'app_id' - token_prefix = 'app-' + resource_id_field = "app_id" + token_prefix = "app-" class AppApiKeyResource(BaseApiKeyResource): - def after_request(self, resp): - resp.headers['Access-Control-Allow-Origin'] = '*' - resp.headers['Access-Control-Allow-Credentials'] = 'true' + resp.headers["Access-Control-Allow-Origin"] = "*" + resp.headers["Access-Control-Allow-Credentials"] = "true" return resp - resource_type = 'app' + resource_type = "app" resource_model = App - resource_id_field = 'app_id' + resource_id_field = "app_id" class DatasetApiKeyListResource(BaseApiKeyListResource): - def after_request(self, resp): - resp.headers['Access-Control-Allow-Origin'] = '*' - resp.headers['Access-Control-Allow-Credentials'] = 'true' + resp.headers["Access-Control-Allow-Origin"] = "*" + resp.headers["Access-Control-Allow-Credentials"] = "true" return resp - resource_type = 'dataset' + resource_type = "dataset" resource_model = Dataset - resource_id_field = 'dataset_id' - token_prefix = 'ds-' + resource_id_field = "dataset_id" + token_prefix = "ds-" class DatasetApiKeyResource(BaseApiKeyResource): - def after_request(self, resp): - resp.headers['Access-Control-Allow-Origin'] = '*' - resp.headers['Access-Control-Allow-Credentials'] = 'true' + resp.headers["Access-Control-Allow-Origin"] = "*" + resp.headers["Access-Control-Allow-Credentials"] = "true" return resp - resource_type = 'dataset' + + resource_type = "dataset" resource_model = Dataset - resource_id_field = 'dataset_id' + resource_id_field = "dataset_id" -api.add_resource(AppApiKeyListResource, '/apps//api-keys') -api.add_resource(AppApiKeyResource, - '/apps//api-keys/') -api.add_resource(DatasetApiKeyListResource, - '/datasets//api-keys') -api.add_resource(DatasetApiKeyResource, - '/datasets//api-keys/') +api.add_resource(AppApiKeyListResource, "/apps//api-keys") +api.add_resource(AppApiKeyResource, "/apps//api-keys/") +api.add_resource(DatasetApiKeyListResource, "/datasets//api-keys") +api.add_resource(DatasetApiKeyResource, "/datasets//api-keys/") diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index fa2b3807e82778..e7346bdf1dd91b 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -8,19 +8,18 @@ class AdvancedPromptTemplateList(Resource): - @setup_required @login_required @account_initialization_required def get(self): - parser = reqparse.RequestParser() - parser.add_argument('app_mode', type=str, required=True, location='args') - parser.add_argument('model_mode', type=str, required=True, location='args') - parser.add_argument('has_context', type=str, required=False, default='true', location='args') - parser.add_argument('model_name', type=str, required=True, location='args') + parser.add_argument("app_mode", type=str, required=True, location="args") + parser.add_argument("model_mode", type=str, required=True, location="args") + parser.add_argument("has_context", type=str, required=False, default="true", location="args") + parser.add_argument("model_name", type=str, required=True, location="args") args = parser.parse_args() return AdvancedPromptTemplateService.get_prompt(args) -api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates') \ No newline at end of file + +api.add_resource(AdvancedPromptTemplateList, "/app/prompt-templates") diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py index aee367276c0777..51899da7052111 100644 --- a/api/controllers/console/app/agent.py +++ b/api/controllers/console/app/agent.py @@ -18,15 +18,12 @@ class AgentLogApi(Resource): def get(self, app_model): """Get agent logs""" parser = reqparse.RequestParser() - parser.add_argument('message_id', type=uuid_value, required=True, location='args') - parser.add_argument('conversation_id', type=uuid_value, required=True, location='args') + parser.add_argument("message_id", type=uuid_value, required=True, location="args") + parser.add_argument("conversation_id", type=uuid_value, required=True, location="args") args = parser.parse_args() - return AgentService.get_agent_logs( - app_model, - args['conversation_id'], - args['message_id'] - ) - -api.add_resource(AgentLogApi, '/apps//agent/logs') \ No newline at end of file + return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"]) + + +api.add_resource(AgentLogApi, "/apps//agent/logs") diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 1ac8e60dcd2613..1ea1c82679defe 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -21,24 +21,23 @@ class AnnotationReplyActionApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") def post(self, app_id, action): - # The role of the current user in the ta table must be admin or owner - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() app_id = str(app_id) parser = reqparse.RequestParser() - parser.add_argument('score_threshold', required=True, type=float, location='json') - parser.add_argument('embedding_provider_name', required=True, type=str, location='json') - parser.add_argument('embedding_model_name', required=True, type=str, location='json') + parser.add_argument("score_threshold", required=True, type=float, location="json") + parser.add_argument("embedding_provider_name", required=True, type=str, location="json") + parser.add_argument("embedding_model_name", required=True, type=str, location="json") args = parser.parse_args() - if action == 'enable': + if action == "enable": result = AppAnnotationService.enable_app_annotation(args, app_id) - elif action == 'disable': + elif action == "disable": result = AppAnnotationService.disable_app_annotation(app_id) else: - raise ValueError('Unsupported annotation reply action') + raise ValueError("Unsupported annotation reply action") return result, 200 @@ -47,8 +46,7 @@ class AppAnnotationSettingDetailApi(Resource): @login_required @account_initialization_required def get(self, app_id): - # The role of the current user in the ta table must be admin or owner - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() app_id = str(app_id) @@ -61,15 +59,14 @@ class AppAnnotationSettingUpdateApi(Resource): @login_required @account_initialization_required def post(self, app_id, annotation_setting_id): - # The role of the current user in the ta table must be admin or owner - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() app_id = str(app_id) annotation_setting_id = str(annotation_setting_id) parser = reqparse.RequestParser() - parser.add_argument('score_threshold', required=True, type=float, location='json') + parser.add_argument("score_threshold", required=True, type=float, location="json") args = parser.parse_args() result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args) @@ -80,29 +77,24 @@ class AnnotationReplyActionStatusApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") def get(self, app_id, job_id, action): - # The role of the current user in the ta table must be admin or owner - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() job_id = str(job_id) - app_annotation_job_key = '{}_app_annotation_job_{}'.format(action, str(job_id)) + app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id)) cache_result = redis_client.get(app_annotation_job_key) if cache_result is None: raise ValueError("The job is not exist.") job_status = cache_result.decode() - error_msg = '' - if job_status == 'error': - app_annotation_error_key = '{}_app_annotation_error_{}'.format(action, str(job_id)) + error_msg = "" + if job_status == "error": + app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id)) error_msg = redis_client.get(app_annotation_error_key).decode() - return { - 'job_id': job_id, - 'job_status': job_status, - 'error_msg': error_msg - }, 200 + return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 class AnnotationListApi(Resource): @@ -110,22 +102,21 @@ class AnnotationListApi(Resource): @login_required @account_initialization_required def get(self, app_id): - # The role of the current user in the ta table must be admin or owner - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - keyword = request.args.get('keyword', default=None, type=str) + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + keyword = request.args.get("keyword", default=None, type=str) app_id = str(app_id) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) response = { - 'data': marshal(annotation_list, annotation_fields), - 'has_more': len(annotation_list) == limit, - 'limit': limit, - 'total': total, - 'page': page + "data": marshal(annotation_list, annotation_fields), + "has_more": len(annotation_list) == limit, + "limit": limit, + "total": total, + "page": page, } return response, 200 @@ -135,15 +126,12 @@ class AnnotationExportApi(Resource): @login_required @account_initialization_required def get(self, app_id): - # The role of the current user in the ta table must be admin or owner - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() app_id = str(app_id) annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) - response = { - 'data': marshal(annotation_list, annotation_fields) - } + response = {"data": marshal(annotation_list, annotation_fields)} return response, 200 @@ -151,17 +139,16 @@ class AnnotationCreateApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") @marshal_with(annotation_fields) def post(self, app_id): - # The role of the current user in the ta table must be admin or owner - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() app_id = str(app_id) parser = reqparse.RequestParser() - parser.add_argument('question', required=True, type=str, location='json') - parser.add_argument('answer', required=True, type=str, location='json') + parser.add_argument("question", required=True, type=str, location="json") + parser.add_argument("answer", required=True, type=str, location="json") args = parser.parse_args() annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id) return annotation @@ -171,18 +158,17 @@ class AnnotationUpdateDeleteApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") @marshal_with(annotation_fields) def post(self, app_id, annotation_id): - # The role of the current user in the ta table must be admin or owner - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() app_id = str(app_id) annotation_id = str(annotation_id) parser = reqparse.RequestParser() - parser.add_argument('question', required=True, type=str, location='json') - parser.add_argument('answer', required=True, type=str, location='json') + parser.add_argument("question", required=True, type=str, location="json") + parser.add_argument("answer", required=True, type=str, location="json") args = parser.parse_args() annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id) return annotation @@ -191,37 +177,35 @@ def post(self, app_id, annotation_id): @login_required @account_initialization_required def delete(self, app_id, annotation_id): - # The role of the current user in the ta table must be admin or owner - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() app_id = str(app_id) annotation_id = str(annotation_id) AppAnnotationService.delete_app_annotation(app_id, annotation_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class AnnotationBatchImportApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") def post(self, app_id): - # The role of the current user in the ta table must be admin or owner - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() app_id = str(app_id) # get file from request - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() # check file type - if not file.filename.endswith('.csv'): + if not file.filename.endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") return AppAnnotationService.batch_import_app_annotations(app_id, file) @@ -230,28 +214,23 @@ class AnnotationBatchImportStatusApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") def get(self, app_id, job_id): - # The role of the current user in the ta table must be admin or owner - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() job_id = str(job_id) - indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id)) + indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) cache_result = redis_client.get(indexing_cache_key) if cache_result is None: raise ValueError("The job is not exist.") job_status = cache_result.decode() - error_msg = '' - if job_status == 'error': - indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id)) + error_msg = "" + if job_status == "error": + indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id)) error_msg = redis_client.get(indexing_error_msg_key).decode() - return { - 'job_id': job_id, - 'job_status': job_status, - 'error_msg': error_msg - }, 200 + return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 class AnnotationHitHistoryListApi(Resource): @@ -259,34 +238,35 @@ class AnnotationHitHistoryListApi(Resource): @login_required @account_initialization_required def get(self, app_id, annotation_id): - # The role of the current user in the table must be admin or owner - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) app_id = str(app_id) annotation_id = str(annotation_id) - annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(app_id, annotation_id, - page, limit) + annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories( + app_id, annotation_id, page, limit + ) response = { - 'data': marshal(annotation_hit_history_list, annotation_hit_history_fields), - 'has_more': len(annotation_hit_history_list) == limit, - 'limit': limit, - 'total': total, - 'page': page + "data": marshal(annotation_hit_history_list, annotation_hit_history_fields), + "has_more": len(annotation_hit_history_list) == limit, + "limit": limit, + "total": total, + "page": page, } return response -api.add_resource(AnnotationReplyActionApi, '/apps//annotation-reply/') -api.add_resource(AnnotationReplyActionStatusApi, - '/apps//annotation-reply//status/') -api.add_resource(AnnotationListApi, '/apps//annotations') -api.add_resource(AnnotationExportApi, '/apps//annotations/export') -api.add_resource(AnnotationUpdateDeleteApi, '/apps//annotations/') -api.add_resource(AnnotationBatchImportApi, '/apps//annotations/batch-import') -api.add_resource(AnnotationBatchImportStatusApi, '/apps//annotations/batch-import-status/') -api.add_resource(AnnotationHitHistoryListApi, '/apps//annotations//hit-histories') -api.add_resource(AppAnnotationSettingDetailApi, '/apps//annotation-setting') -api.add_resource(AppAnnotationSettingUpdateApi, '/apps//annotation-settings/') +api.add_resource(AnnotationReplyActionApi, "/apps//annotation-reply/") +api.add_resource( + AnnotationReplyActionStatusApi, "/apps//annotation-reply//status/" +) +api.add_resource(AnnotationListApi, "/apps//annotations") +api.add_resource(AnnotationExportApi, "/apps//annotations/export") +api.add_resource(AnnotationUpdateDeleteApi, "/apps//annotations/") +api.add_resource(AnnotationBatchImportApi, "/apps//annotations/batch-import") +api.add_resource(AnnotationBatchImportStatusApi, "/apps//annotations/batch-import-status/") +api.add_resource(AnnotationHitHistoryListApi, "/apps//annotations//hit-histories") +api.add_resource(AppAnnotationSettingDetailApi, "/apps//annotation-setting") +api.add_resource(AppAnnotationSettingUpdateApi, "/apps//annotation-settings/") diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 2f304b970c6050..cc9c8b31cb6be7 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -18,27 +18,35 @@ from services.app_dsl_service import AppDslService from services.app_service import AppService -ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion'] +ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] class AppListApi(Resource): - @setup_required @login_required @account_initialization_required def get(self): """Get app list""" + def uuid_list(value): try: - return [str(uuid.UUID(v)) for v in value.split(',')] + return [str(uuid.UUID(v)) for v in value.split(",")] except ValueError: abort(400, message="Invalid UUID format in tag_ids.") + parser = reqparse.RequestParser() - parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args') - parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent-chat', 'channel', 'all'], default='all', location='args', required=False) - parser.add_argument('name', type=str, location='args', required=False) - parser.add_argument('tag_ids', type=uuid_list, location='args', required=False) + parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") + parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") + parser.add_argument( + "mode", + type=str, + choices=["chat", "workflow", "agent-chat", "channel", "all"], + default="all", + location="args", + required=False, + ) + parser.add_argument("name", type=str, location="args", required=False) + parser.add_argument("tag_ids", type=uuid_list, location="args", required=False) args = parser.parse_args() @@ -46,7 +54,7 @@ def uuid_list(value): app_service = AppService() app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args) if not app_pagination: - return {'data': [], 'total': 0, 'page': 1, 'limit': 20, 'has_more': False} + return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} return marshal(app_pagination, app_pagination_fields) @@ -54,22 +62,23 @@ def uuid_list(value): @login_required @account_initialization_required @marshal_with(app_detail_fields) - @cloud_edition_billing_resource_check('apps') + @cloud_edition_billing_resource_check("apps") def post(self): """Create app""" parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json') - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("name", type=str, required=True, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json") + parser.add_argument("icon_type", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - if 'mode' not in args or args['mode'] is None: + if "mode" not in args or args["mode"] is None: raise BadRequest("mode is required") app_service = AppService() @@ -83,7 +92,7 @@ class AppImportApi(Resource): @login_required @account_initialization_required @marshal_with(app_detail_fields_with_site) - @cloud_edition_billing_resource_check('apps') + @cloud_edition_billing_resource_check("apps") def post(self): """Import app""" # The role of the current user in the ta table must be admin, owner, or editor @@ -91,18 +100,16 @@ def post(self): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('data', type=str, required=True, nullable=False, location='json') - parser.add_argument('name', type=str, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("data", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("icon_type", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() app = AppDslService.import_and_create_new_app( - tenant_id=current_user.current_tenant_id, - data=args['data'], - args=args, - account=current_user + tenant_id=current_user.current_tenant_id, data=args["data"], args=args, account=current_user ) return app, 201 @@ -113,7 +120,7 @@ class AppImportFromUrlApi(Resource): @login_required @account_initialization_required @marshal_with(app_detail_fields_with_site) - @cloud_edition_billing_resource_check('apps') + @cloud_edition_billing_resource_check("apps") def post(self): """Import app from url""" # The role of the current user in the ta table must be admin, owner, or editor @@ -121,25 +128,21 @@ def post(self): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('url', type=str, required=True, nullable=False, location='json') - parser.add_argument('name', type=str, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("url", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() app = AppDslService.import_and_create_new_app_from_url( - tenant_id=current_user.current_tenant_id, - url=args['url'], - args=args, - account=current_user + tenant_id=current_user.current_tenant_id, url=args["url"], args=args, account=current_user ) return app, 201 class AppApi(Resource): - @setup_required @login_required @account_initialization_required @@ -163,13 +166,14 @@ def put(self, app_model): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, nullable=False, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') - parser.add_argument('max_active_requests', type=int, location='json') + parser.add_argument("name", type=str, required=True, nullable=False, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("icon_type", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") + parser.add_argument("max_active_requests", type=int, location="json") args = parser.parse_args() app_service = AppService() @@ -190,7 +194,7 @@ def delete(self, app_model): app_service = AppService() app_service.delete_app(app_model) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class AppCopyApi(Resource): @@ -206,18 +210,16 @@ def post(self, app_model): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('name', type=str, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("name", type=str, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("icon_type", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() data = AppDslService.export_dsl(app_model=app_model, include_secret=True) app = AppDslService.import_and_create_new_app( - tenant_id=current_user.current_tenant_id, - data=data, - args=args, - account=current_user + tenant_id=current_user.current_tenant_id, data=data, args=args, account=current_user ) return app, 201 @@ -236,12 +238,10 @@ def get(self, app_model): # Add include_secret params parser = reqparse.RequestParser() - parser.add_argument('include_secret', type=inputs.boolean, default=False, location='args') + parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args") args = parser.parse_args() - return { - "data": AppDslService.export_dsl(app_model=app_model, include_secret=args['include_secret']) - } + return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])} class AppNameApi(Resource): @@ -254,13 +254,13 @@ def post(self, app_model): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_name(app_model, args.get('name')) + app_model = app_service.update_app_name(app_model, args.get("name")) return app_model @@ -275,14 +275,14 @@ def post(self, app_model): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_icon(app_model, args.get('icon'), args.get('icon_background')) + app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background")) return app_model @@ -297,13 +297,13 @@ def post(self, app_model): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('enable_site', type=bool, required=True, location='json') + parser.add_argument("enable_site", type=bool, required=True, location="json") args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_site_status(app_model, args.get('enable_site')) + app_model = app_service.update_app_site_status(app_model, args.get("enable_site")) return app_model @@ -318,13 +318,13 @@ def post(self, app_model): # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('enable_api', type=bool, required=True, location='json') + parser.add_argument("enable_api", type=bool, required=True, location="json") args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_api_status(app_model, args.get('enable_api')) + app_model = app_service.update_app_api_status(app_model, args.get("enable_api")) return app_model @@ -335,9 +335,7 @@ class AppTraceApi(Resource): @account_initialization_required def get(self, app_id): """Get app trace""" - app_trace_config = OpsTraceManager.get_app_tracing_config( - app_id=app_id - ) + app_trace_config = OpsTraceManager.get_app_tracing_config(app_id=app_id) return app_trace_config @@ -349,27 +347,27 @@ def post(self, app_id): if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('enabled', type=bool, required=True, location='json') - parser.add_argument('tracing_provider', type=str, required=True, location='json') + parser.add_argument("enabled", type=bool, required=True, location="json") + parser.add_argument("tracing_provider", type=str, required=True, location="json") args = parser.parse_args() OpsTraceManager.update_app_tracing_config( app_id=app_id, - enabled=args['enabled'], - tracing_provider=args['tracing_provider'], + enabled=args["enabled"], + tracing_provider=args["tracing_provider"], ) return {"result": "success"} -api.add_resource(AppListApi, '/apps') -api.add_resource(AppImportApi, '/apps/import') -api.add_resource(AppImportFromUrlApi, '/apps/import/url') -api.add_resource(AppApi, '/apps/') -api.add_resource(AppCopyApi, '/apps//copy') -api.add_resource(AppExportApi, '/apps//export') -api.add_resource(AppNameApi, '/apps//name') -api.add_resource(AppIconApi, '/apps//icon') -api.add_resource(AppSiteStatus, '/apps//site-enable') -api.add_resource(AppApiStatus, '/apps//api-enable') -api.add_resource(AppTraceApi, '/apps//trace') +api.add_resource(AppListApi, "/apps") +api.add_resource(AppImportApi, "/apps/import") +api.add_resource(AppImportFromUrlApi, "/apps/import/url") +api.add_resource(AppApi, "/apps/") +api.add_resource(AppCopyApi, "/apps//copy") +api.add_resource(AppExportApi, "/apps//export") +api.add_resource(AppNameApi, "/apps//name") +api.add_resource(AppIconApi, "/apps//icon") +api.add_resource(AppSiteStatus, "/apps//site-enable") +api.add_resource(AppApiStatus, "/apps//api-enable") +api.add_resource(AppTraceApi, "/apps//trace") diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 1de08afa4e08b9..437a6a7b3865b6 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -39,7 +39,7 @@ class ChatMessageAudioApi(Resource): @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) def post(self, app_model): - file = request.files['file'] + file = request.files["file"] try: response = AudioService.transcript_asr( @@ -85,31 +85,31 @@ def post(self, app_model): try: parser = reqparse.RequestParser() - parser.add_argument('message_id', type=str, location='json') - parser.add_argument('text', type=str, location='json') - parser.add_argument('voice', type=str, location='json') - parser.add_argument('streaming', type=bool, location='json') + parser.add_argument("message_id", type=str, location="json") + parser.add_argument("text", type=str, location="json") + parser.add_argument("voice", type=str, location="json") + parser.add_argument("streaming", type=bool, location="json") args = parser.parse_args() - message_id = args.get('message_id', None) - text = args.get('text', None) - if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] - and app_model.workflow - and app_model.workflow.features_dict): - text_to_speech = app_model.workflow.features_dict.get('text_to_speech') - voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') + message_id = args.get("message_id", None) + text = args.get("text", None) + if ( + app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + and app_model.workflow + and app_model.workflow.features_dict + ): + text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") else: try: - voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get( - 'voice') + voice = ( + args.get("voice") + if args.get("voice") + else app_model.app_model_config.text_to_speech_dict.get("voice") + ) except Exception: voice = None - response = AudioService.transcript_tts( - app_model=app_model, - text=text, - message_id=message_id, - voice=voice - ) + response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice) return response except services.errors.app_model_config.AppModelConfigBrokenError: logging.exception("App model config broken.") @@ -145,12 +145,12 @@ class TextModesApi(Resource): def get(self, app_model): try: parser = reqparse.RequestParser() - parser.add_argument('language', type=str, required=True, location='args') + parser.add_argument("language", type=str, required=True, location="args") args = parser.parse_args() response = AudioService.transcript_tts_voices( tenant_id=app_model.tenant_id, - language=args['language'], + language=args["language"], ) return response @@ -179,6 +179,6 @@ def get(self, app_model): raise InternalServerError() -api.add_resource(ChatMessageAudioApi, '/apps//audio-to-text') -api.add_resource(ChatMessageTextApi, '/apps//text-to-audio') -api.add_resource(TextModesApi, '/apps//text-to-audio/voices') +api.add_resource(ChatMessageAudioApi, "/apps//audio-to-text") +api.add_resource(ChatMessageTextApi, "/apps//text-to-audio") +api.add_resource(TextModesApi, "/apps//text-to-audio/voices") diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 61582536fdbe1d..53de51c24d798a 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -17,6 +17,7 @@ from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required +from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ( @@ -31,37 +32,33 @@ from libs.login import login_required from models.model import AppMode from services.app_generate_service import AppGenerateService +from services.errors.llm import InvokeRateLimitError # define completion message api for user class CompletionMessageApi(Resource): - @setup_required @login_required @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) def post(self, app_model): parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, location='json', default='') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('model_config', type=dict, required=True, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, location="json", default="") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("model_config", type=dict, required=True, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") args = parser.parse_args() - streaming = args['response_mode'] != 'blocking' - args['auto_generate_name'] = False + streaming = args["response_mode"] != "blocking" + args["auto_generate_name"] = False account = flask_login.current_user try: response = AppGenerateService.generate( - app_model=app_model, - user=account, - args=args, - invoke_from=InvokeFrom.DEBUGGER, - streaming=streaming + app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming ) return helper.compact_generate_response(response) @@ -97,7 +94,7 @@ def post(self, app_model, task_id): AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ChatMessageApi(Resource): @@ -107,27 +104,23 @@ class ChatMessageApi(Resource): @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) def post(self, app_model): parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, required=True, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('model_config', type=dict, required=True, location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, required=True, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("model_config", type=dict, required=True, location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") args = parser.parse_args() - streaming = args['response_mode'] != 'blocking' - args['auto_generate_name'] = False + streaming = args["response_mode"] != "blocking" + args["auto_generate_name"] = False account = flask_login.current_user try: response = AppGenerateService.generate( - app_model=app_model, - user=account, - args=args, - invoke_from=InvokeFrom.DEBUGGER, - streaming=streaming + app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming ) return helper.compact_generate_response(response) @@ -144,6 +137,8 @@ def post(self, app_model): raise ProviderQuotaExceededError() except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() + except InvokeRateLimitError as ex: + raise InvokeRateLimitHttpError(ex.description) except InvokeError as e: raise CompletionRequestError(e.description) except (ValueError, AppInvokeQuotaExceededError) as e: @@ -163,10 +158,10 @@ def post(self, app_model, task_id): AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(CompletionMessageApi, '/apps//completion-messages') -api.add_resource(CompletionMessageStopApi, '/apps//completion-messages//stop') -api.add_resource(ChatMessageApi, '/apps//chat-messages') -api.add_resource(ChatMessageStopApi, '/apps//chat-messages//stop') +api.add_resource(CompletionMessageApi, "/apps//completion-messages") +api.add_resource(CompletionMessageStopApi, "/apps//completion-messages//stop") +api.add_resource(ChatMessageApi, "/apps//chat-messages") +api.add_resource(ChatMessageStopApi, "/apps//chat-messages//stop") diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 96cd9a6ea141eb..c3aac6690e4c3d 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -26,34 +26,32 @@ class CompletionConversationApi(Resource): - @setup_required @login_required @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) @marshal_with(conversation_pagination_fields) def get(self, app_model): - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('keyword', type=str, location='args') - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('annotation_status', type=str, - choices=['annotated', 'not_annotated', 'all'], default='all', location='args') - parser.add_argument('page', type=int_range(1, 99999), default=1, location='args') - parser.add_argument('limit', type=int_range(1, 100), default=20, location='args') + parser.add_argument("keyword", type=str, location="args") + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument( + "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" + ) + parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") + parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") args = parser.parse_args() - query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == 'completion') + query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion") - if args['keyword']: - query = query.join( - Message, Message.conversation_id == Conversation.id - ).filter( + if args["keyword"]: + query = query.join(Message, Message.conversation_id == Conversation.id).filter( or_( - Message.query.ilike('%{}%'.format(args['keyword'])), - Message.answer.ilike('%{}%'.format(args['keyword'])) + Message.query.ilike("%{}%".format(args["keyword"])), + Message.answer.ilike("%{}%".format(args["keyword"])), ) ) @@ -61,8 +59,8 @@ def get(self, app_model): timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) @@ -70,8 +68,8 @@ def get(self, app_model): query = query.where(Conversation.created_at >= start_datetime_utc) - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=59) end_datetime_timezone = timezone.localize(end_datetime) @@ -79,36 +77,32 @@ def get(self, app_model): query = query.where(Conversation.created_at < end_datetime_utc) - if args['annotation_status'] == "annotated": + if args["annotation_status"] == "annotated": query = query.options(joinedload(Conversation.message_annotations)).join( MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id ) - elif args['annotation_status'] == "not_annotated": - query = query.outerjoin( - MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id - ).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0) + elif args["annotation_status"] == "not_annotated": + query = ( + query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) + .group_by(Conversation.id) + .having(func.count(MessageAnnotation.id) == 0) + ) query = query.order_by(Conversation.created_at.desc()) - conversations = db.paginate( - query, - page=args['page'], - per_page=args['limit'], - error_out=False - ) + conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) return conversations class CompletionConversationDetailApi(Resource): - @setup_required @login_required @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) @marshal_with(conversation_message_detail_fields) def get(self, app_model, conversation_id): - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() conversation_id = str(conversation_id) @@ -119,12 +113,15 @@ def get(self, app_model, conversation_id): @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) def delete(self, app_model, conversation_id): - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() conversation_id = str(conversation_id) - conversation = db.session.query(Conversation) \ - .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) + .first() + ) if not conversation: raise NotFound("Conversation Not Exists.") @@ -132,34 +129,41 @@ def delete(self, app_model, conversation_id): conversation.is_deleted = True db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class ChatConversationApi(Resource): - @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @marshal_with(conversation_with_summary_pagination_fields) def get(self, app_model): - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('keyword', type=str, location='args') - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('annotation_status', type=str, - choices=['annotated', 'not_annotated', 'all'], default='all', location='args') - parser.add_argument('message_count_gte', type=int_range(1, 99999), required=False, location='args') - parser.add_argument('page', type=int_range(1, 99999), required=False, default=1, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("keyword", type=str, location="args") + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument( + "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" + ) + parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args") + parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser.add_argument( + "sort_by", + type=str, + choices=["created_at", "-created_at", "updated_at", "-updated_at"], + required=False, + default="-updated_at", + location="args", + ) args = parser.parse_args() subquery = ( db.session.query( - Conversation.id.label('conversation_id'), - EndUser.session_id.label('from_end_user_session_id') + Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id") ) .outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id) .subquery() @@ -167,19 +171,19 @@ def get(self, app_model): query = db.select(Conversation).where(Conversation.app_id == app_model.id) - if args['keyword']: - keyword_filter = '%{}%'.format(args['keyword']) - query = query.join( - Message, Message.conversation_id == Conversation.id, - ).join( - subquery, subquery.c.conversation_id == Conversation.id - ).filter( + if args["keyword"]: + keyword_filter = "%{}%".format(args["keyword"]) + message_subquery = ( + db.session.query(Message.conversation_id) + .filter(or_(Message.query.ilike(keyword_filter), Message.answer.ilike(keyword_filter))) + .subquery() + ) + query = query.join(subquery, subquery.c.conversation_id == Conversation.id).filter( or_( - Message.query.ilike(keyword_filter), - Message.answer.ilike(keyword_filter), + Conversation.id.in_(message_subquery), Conversation.name.ilike(keyword_filter), Conversation.introduction.ilike(keyword_filter), - subquery.c.from_end_user_session_id.ilike(keyword_filter) + subquery.c.from_end_user_session_id.ilike(keyword_filter), ), ) @@ -187,8 +191,8 @@ def get(self, app_model): timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) @@ -196,8 +200,8 @@ def get(self, app_model): query = query.where(Conversation.created_at >= start_datetime_utc) - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=59) end_datetime_timezone = timezone.localize(end_datetime) @@ -205,47 +209,53 @@ def get(self, app_model): query = query.where(Conversation.created_at < end_datetime_utc) - if args['annotation_status'] == "annotated": + if args["annotation_status"] == "annotated": query = query.options(joinedload(Conversation.message_annotations)).join( MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id ) - elif args['annotation_status'] == "not_annotated": - query = query.outerjoin( - MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id - ).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0) + elif args["annotation_status"] == "not_annotated": + query = ( + query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) + .group_by(Conversation.id) + .having(func.count(MessageAnnotation.id) == 0) + ) - if args['message_count_gte'] and args['message_count_gte'] >= 1: + if args["message_count_gte"] and args["message_count_gte"] >= 1: query = ( query.options(joinedload(Conversation.messages)) .join(Message, Message.conversation_id == Conversation.id) .group_by(Conversation.id) - .having(func.count(Message.id) >= args['message_count_gte']) + .having(func.count(Message.id) >= args["message_count_gte"]) ) if app_model.mode == AppMode.ADVANCED_CHAT.value: query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value) - query = query.order_by(Conversation.created_at.desc()) + match args["sort_by"]: + case "created_at": + query = query.order_by(Conversation.created_at.asc()) + case "-created_at": + query = query.order_by(Conversation.created_at.desc()) + case "updated_at": + query = query.order_by(Conversation.updated_at.asc()) + case "-updated_at": + query = query.order_by(Conversation.updated_at.desc()) + case _: + query = query.order_by(Conversation.created_at.desc()) - conversations = db.paginate( - query, - page=args['page'], - per_page=args['limit'], - error_out=False - ) + conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) return conversations class ChatConversationDetailApi(Resource): - @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @marshal_with(conversation_detail_fields) def get(self, app_model, conversation_id): - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() conversation_id = str(conversation_id) @@ -256,12 +266,15 @@ def get(self, app_model, conversation_id): @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @account_initialization_required def delete(self, app_model, conversation_id): - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() conversation_id = str(conversation_id) - conversation = db.session.query(Conversation) \ - .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) + .first() + ) if not conversation: raise NotFound("Conversation Not Exists.") @@ -269,18 +282,21 @@ def delete(self, app_model, conversation_id): conversation.is_deleted = True db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 -api.add_resource(CompletionConversationApi, '/apps//completion-conversations') -api.add_resource(CompletionConversationDetailApi, '/apps//completion-conversations/') -api.add_resource(ChatConversationApi, '/apps//chat-conversations') -api.add_resource(ChatConversationDetailApi, '/apps//chat-conversations/') +api.add_resource(CompletionConversationApi, "/apps//completion-conversations") +api.add_resource(CompletionConversationDetailApi, "/apps//completion-conversations/") +api.add_resource(ChatConversationApi, "/apps//chat-conversations") +api.add_resource(ChatConversationDetailApi, "/apps//chat-conversations/") def _get_conversation(app_model, conversation_id): - conversation = db.session.query(Conversation) \ - .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) + .first() + ) if not conversation: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py new file mode 100644 index 00000000000000..23b234dac9d397 --- /dev/null +++ b/api/controllers/console/app/conversation_variables.py @@ -0,0 +1,61 @@ +from flask_restful import Resource, marshal_with, reqparse +from sqlalchemy import select +from sqlalchemy.orm import Session + +from controllers.console import api +from controllers.console.app.wraps import get_app_model +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from extensions.ext_database import db +from fields.conversation_variable_fields import paginated_conversation_variable_fields +from libs.login import login_required +from models import ConversationVariable +from models.model import AppMode + + +class ConversationVariablesApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=AppMode.ADVANCED_CHAT) + @marshal_with(paginated_conversation_variable_fields) + def get(self, app_model): + parser = reqparse.RequestParser() + parser.add_argument("conversation_id", type=str, location="args") + args = parser.parse_args() + + stmt = ( + select(ConversationVariable) + .where(ConversationVariable.app_id == app_model.id) + .order_by(ConversationVariable.created_at) + ) + if args["conversation_id"]: + stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"]) + else: + raise ValueError("conversation_id is required") + + # NOTE: This is a temporary solution to avoid performance issues. + page = 1 + page_size = 100 + stmt = stmt.limit(page_size).offset((page - 1) * page_size) + + with Session(db.engine) as session: + rows = session.scalars(stmt).all() + + return { + "page": page, + "limit": page_size, + "total": len(rows), + "has_more": False, + "data": [ + { + "created_at": row.created_at, + "updated_at": row.updated_at, + **row.to_variable().model_dump(), + } + for row in rows + ], + } + + +api.add_resource(ConversationVariablesApi, "/apps//conversation-variables") diff --git a/api/controllers/console/app/error.py b/api/controllers/console/app/error.py index f6feed12217a85..1559f82d6ea142 100644 --- a/api/controllers/console/app/error.py +++ b/api/controllers/console/app/error.py @@ -2,116 +2,128 @@ class AppNotFoundError(BaseHTTPException): - error_code = 'app_not_found' + error_code = "app_not_found" description = "App not found." code = 404 class ProviderNotInitializeError(BaseHTTPException): - error_code = 'provider_not_initialize' - description = "No valid model provider credentials found. " \ - "Please go to Settings -> Model Provider to complete your provider credentials." + error_code = "provider_not_initialize" + description = ( + "No valid model provider credentials found. " + "Please go to Settings -> Model Provider to complete your provider credentials." + ) code = 400 class ProviderQuotaExceededError(BaseHTTPException): - error_code = 'provider_quota_exceeded' - description = "Your quota for Dify Hosted Model Provider has been exhausted. " \ - "Please go to Settings -> Model Provider to complete your own provider credentials." + error_code = "provider_quota_exceeded" + description = ( + "Your quota for Dify Hosted Model Provider has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials." + ) code = 400 class ProviderModelCurrentlyNotSupportError(BaseHTTPException): - error_code = 'model_currently_not_support' + error_code = "model_currently_not_support" description = "Dify Hosted OpenAI trial currently not support the GPT-4 model." code = 400 class ConversationCompletedError(BaseHTTPException): - error_code = 'conversation_completed' + error_code = "conversation_completed" description = "The conversation has ended. Please start a new conversation." code = 400 class AppUnavailableError(BaseHTTPException): - error_code = 'app_unavailable' + error_code = "app_unavailable" description = "App unavailable, please check your app configurations." code = 400 class CompletionRequestError(BaseHTTPException): - error_code = 'completion_request_error' + error_code = "completion_request_error" description = "Completion request failed." code = 400 class AppMoreLikeThisDisabledError(BaseHTTPException): - error_code = 'app_more_like_this_disabled' + error_code = "app_more_like_this_disabled" description = "The 'More like this' feature is disabled. Please refresh your page." code = 403 class NoAudioUploadedError(BaseHTTPException): - error_code = 'no_audio_uploaded' + error_code = "no_audio_uploaded" description = "Please upload your audio." code = 400 class AudioTooLargeError(BaseHTTPException): - error_code = 'audio_too_large' + error_code = "audio_too_large" description = "Audio size exceeded. {message}" code = 413 class UnsupportedAudioTypeError(BaseHTTPException): - error_code = 'unsupported_audio_type' + error_code = "unsupported_audio_type" description = "Audio type not allowed." code = 415 class ProviderNotSupportSpeechToTextError(BaseHTTPException): - error_code = 'provider_not_support_speech_to_text' + error_code = "provider_not_support_speech_to_text" description = "Provider not support speech to text." code = 400 class NoFileUploadedError(BaseHTTPException): - error_code = 'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." code = 400 class DraftWorkflowNotExist(BaseHTTPException): - error_code = 'draft_workflow_not_exist' + error_code = "draft_workflow_not_exist" description = "Draft workflow need to be initialized." code = 400 class DraftWorkflowNotSync(BaseHTTPException): - error_code = 'draft_workflow_not_sync' + error_code = "draft_workflow_not_sync" description = "Workflow graph might have been modified, please refresh and resubmit." code = 400 class TracingConfigNotExist(BaseHTTPException): - error_code = 'trace_config_not_exist' + error_code = "trace_config_not_exist" description = "Trace config not exist." code = 400 class TracingConfigIsExist(BaseHTTPException): - error_code = 'trace_config_is_exist' + error_code = "trace_config_is_exist" description = "Trace config is exist." code = 400 class TracingConfigCheckError(BaseHTTPException): - error_code = 'trace_config_check_error' + error_code = "trace_config_check_error" description = "Invalid Credentials." code = 400 + + +class InvokeRateLimitError(BaseHTTPException): + """Raised when the Invoke returns rate limit error.""" + + error_code = "rate_limit_error" + description = "Rate Limit Error" + code = 429 diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 6803775e20dfb8..3d1e6b7a37cacd 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -24,21 +24,21 @@ class RuleGenerateApi(Resource): @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('instruction', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_config', type=dict, required=True, nullable=False, location='json') - parser.add_argument('no_variable', type=bool, required=True, default=False, location='json') + parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") + parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") + parser.add_argument("no_variable", type=bool, required=True, default=False, location="json") args = parser.parse_args() account = current_user - PROMPT_GENERATION_MAX_TOKENS = int(os.getenv('PROMPT_GENERATION_MAX_TOKENS', '512')) + PROMPT_GENERATION_MAX_TOKENS = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512")) try: rules = LLMGenerator.generate_rule_config( tenant_id=account.current_tenant_id, - instruction=args['instruction'], - model_config=args['model_config'], - no_variable=args['no_variable'], - rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS + instruction=args["instruction"], + model_config=args["model_config"], + no_variable=args["no_variable"], + rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -52,4 +52,4 @@ def post(self): return rules -api.add_resource(RuleGenerateApi, '/rule-generate') +api.add_resource(RuleGenerateApi, "/rule-generate") diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 636c071795940b..fe06201982374a 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -33,9 +33,9 @@ class ChatMessageListApi(Resource): message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_detail_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_detail_fields)), } @setup_required @@ -45,55 +45,69 @@ class ChatMessageListApi(Resource): @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model): parser = reqparse.RequestParser() - parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') - parser.add_argument('first_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") + parser.add_argument("first_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - conversation = db.session.query(Conversation).filter( - Conversation.id == args['conversation_id'], - Conversation.app_id == app_model.id - ).first() + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id) + .first() + ) if not conversation: raise NotFound("Conversation Not Exists.") - if args['first_id']: - first_message = db.session.query(Message) \ - .filter(Message.conversation_id == conversation.id, Message.id == args['first_id']).first() + if args["first_id"]: + first_message = ( + db.session.query(Message) + .filter(Message.conversation_id == conversation.id, Message.id == args["first_id"]) + .first() + ) if not first_message: raise NotFound("First message not found") - history_messages = db.session.query(Message).filter( - Message.conversation_id == conversation.id, - Message.created_at < first_message.created_at, - Message.id != first_message.id - ) \ - .order_by(Message.created_at.desc()).limit(args['limit']).all() + history_messages = ( + db.session.query(Message) + .filter( + Message.conversation_id == conversation.id, + Message.created_at < first_message.created_at, + Message.id != first_message.id, + ) + .order_by(Message.created_at.desc()) + .limit(args["limit"]) + .all() + ) else: - history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \ - .order_by(Message.created_at.desc()).limit(args['limit']).all() + history_messages = ( + db.session.query(Message) + .filter(Message.conversation_id == conversation.id) + .order_by(Message.created_at.desc()) + .limit(args["limit"]) + .all() + ) has_more = False - if len(history_messages) == args['limit']: + if len(history_messages) == args["limit"]: current_page_first_message = history_messages[-1] - rest_count = db.session.query(Message).filter( - Message.conversation_id == conversation.id, - Message.created_at < current_page_first_message.created_at, - Message.id != current_page_first_message.id - ).count() + rest_count = ( + db.session.query(Message) + .filter( + Message.conversation_id == conversation.id, + Message.created_at < current_page_first_message.created_at, + Message.id != current_page_first_message.id, + ) + .count() + ) if rest_count > 0: has_more = True history_messages = list(reversed(history_messages)) - return InfiniteScrollPagination( - data=history_messages, - limit=args['limit'], - has_more=has_more - ) + return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more) class MessageFeedbackApi(Resource): @@ -103,61 +117,57 @@ class MessageFeedbackApi(Resource): @get_app_model def post(self, app_model): parser = reqparse.RequestParser() - parser.add_argument('message_id', required=True, type=uuid_value, location='json') - parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') + parser.add_argument("message_id", required=True, type=uuid_value, location="json") + parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") args = parser.parse_args() - message_id = str(args['message_id']) + message_id = str(args["message_id"]) - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id - ).first() + message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() if not message: raise NotFound("Message Not Exists.") feedback = message.admin_feedback - if not args['rating'] and feedback: + if not args["rating"] and feedback: db.session.delete(feedback) - elif args['rating'] and feedback: - feedback.rating = args['rating'] - elif not args['rating'] and not feedback: - raise ValueError('rating cannot be None when feedback not exists') + elif args["rating"] and feedback: + feedback.rating = args["rating"] + elif not args["rating"] and not feedback: + raise ValueError("rating cannot be None when feedback not exists") else: feedback = MessageFeedback( app_id=app_model.id, conversation_id=message.conversation_id, message_id=message.id, - rating=args['rating'], - from_source='admin', - from_account_id=current_user.id + rating=args["rating"], + from_source="admin", + from_account_id=current_user.id, ) db.session.add(feedback) db.session.commit() - return {'result': 'success'} + return {"result": "success"} class MessageAnnotationApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") @get_app_model @marshal_with(annotation_fields) def post(self, app_model): - # The role of the current user in the ta table must be admin or owner - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('message_id', required=False, type=uuid_value, location='json') - parser.add_argument('question', required=True, type=str, location='json') - parser.add_argument('answer', required=True, type=str, location='json') - parser.add_argument('annotation_reply', required=False, type=dict, location='json') + parser.add_argument("message_id", required=False, type=uuid_value, location="json") + parser.add_argument("question", required=True, type=str, location="json") + parser.add_argument("answer", required=True, type=str, location="json") + parser.add_argument("annotation_reply", required=False, type=dict, location="json") args = parser.parse_args() annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id) @@ -170,11 +180,9 @@ class MessageAnnotationCountApi(Resource): @account_initialization_required @get_app_model def get(self, app_model): - count = db.session.query(MessageAnnotation).filter( - MessageAnnotation.app_id == app_model.id - ).count() + count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_model.id).count() - return {'count': count} + return {"count": count} class MessageSuggestedQuestionApi(Resource): @@ -187,10 +195,7 @@ def get(self, app_model, message_id): try: questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, - message_id=message_id, - user=current_user, - invoke_from=InvokeFrom.DEBUGGER + app_model=app_model, message_id=message_id, user=current_user, invoke_from=InvokeFrom.DEBUGGER ) except MessageNotExistsError: raise NotFound("Message not found") @@ -210,7 +215,7 @@ def get(self, app_model, message_id): logging.exception("internal server error.") raise InternalServerError() - return {'data': questions} + return {"data": questions} class MessageApi(Resource): @@ -222,10 +227,7 @@ class MessageApi(Resource): def get(self, app_model, message_id): message_id = str(message_id) - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id - ).first() + message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() if not message: raise NotFound("Message Not Exists.") @@ -233,9 +235,9 @@ def get(self, app_model, message_id): return message -api.add_resource(MessageSuggestedQuestionApi, '/apps//chat-messages//suggested-questions') -api.add_resource(ChatMessageListApi, '/apps//chat-messages', endpoint='console_chat_messages') -api.add_resource(MessageFeedbackApi, '/apps//feedbacks') -api.add_resource(MessageAnnotationApi, '/apps//annotations') -api.add_resource(MessageAnnotationCountApi, '/apps//annotations/count') -api.add_resource(MessageApi, '/apps//messages/', endpoint='console_message') +api.add_resource(MessageSuggestedQuestionApi, "/apps//chat-messages//suggested-questions") +api.add_resource(ChatMessageListApi, "/apps//chat-messages", endpoint="console_chat_messages") +api.add_resource(MessageFeedbackApi, "/apps//feedbacks") +api.add_resource(MessageAnnotationApi, "/apps//annotations") +api.add_resource(MessageAnnotationCountApi, "/apps//annotations/count") +api.add_resource(MessageApi, "/apps//messages/", endpoint="console_message") diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index c8df879a29ca42..f5068a4cd8fcab 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -19,37 +19,35 @@ class ModelConfigResource(Resource): - @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]) def post(self, app_model): - """Modify app model config""" # validate config model_configuration = AppModelConfigService.validate_configuration( - tenant_id=current_user.current_tenant_id, - config=request.json, - app_mode=AppMode.value_of(app_model.mode) + tenant_id=current_user.current_tenant_id, config=request.json, app_mode=AppMode.value_of(app_model.mode) ) new_app_model_config = AppModelConfig( app_id=app_model.id, + created_by=current_user.id, + updated_by=current_user.id, ) new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: # get original app model config - original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter( - AppModelConfig.id == app_model.app_model_config_id - ).first() + original_app_model_config: AppModelConfig = ( + db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() + ) agent_mode = original_app_model_config.agent_mode_dict # decrypt agent tool parameters if it's secret-input parameter_map = {} masked_parameter_map = {} tool_map = {} - for tool in agent_mode.get('tools') or []: + for tool in agent_mode.get("tools") or []: if not isinstance(tool, dict) or len(tool.keys()) <= 3: continue @@ -66,7 +64,7 @@ def post(self, app_model): tool_runtime=tool_runtime, provider_name=agent_tool_entity.provider_id, provider_type=agent_tool_entity.provider_type, - identity_id=f'AGENT.{app_model.id}' + identity_id=f"AGENT.{app_model.id}", ) except Exception as e: continue @@ -79,18 +77,18 @@ def post(self, app_model): parameters = {} masked_parameter = {} - key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' + key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}" masked_parameter_map[key] = masked_parameter parameter_map[key] = parameters tool_map[key] = tool_runtime # encrypt agent tool parameters if it's secret-input agent_mode = new_app_model_config.agent_mode_dict - for tool in agent_mode.get('tools') or []: + for tool in agent_mode.get("tools") or []: agent_tool_entity = AgentToolEntity(**tool) # get tool - key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' + key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}" if key in tool_map: tool_runtime = tool_map[key] else: @@ -108,7 +106,7 @@ def post(self, app_model): tool_runtime=tool_runtime, provider_name=agent_tool_entity.provider_id, provider_type=agent_tool_entity.provider_type, - identity_id=f'AGENT.{app_model.id}' + identity_id=f"AGENT.{app_model.id}", ) manager.delete_tool_parameters_cache() @@ -116,15 +114,17 @@ def post(self, app_model): if agent_tool_entity.tool_parameters: if key not in masked_parameter_map: continue - + for masked_key, masked_value in masked_parameter_map[key].items(): - if masked_key in agent_tool_entity.tool_parameters and \ - agent_tool_entity.tool_parameters[masked_key] == masked_value: + if ( + masked_key in agent_tool_entity.tool_parameters + and agent_tool_entity.tool_parameters[masked_key] == masked_value + ): agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key) # encrypt parameters if agent_tool_entity.tool_parameters: - tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + tool["tool_parameters"] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) # update app model config new_app_model_config.agent_mode = json.dumps(agent_mode) @@ -135,12 +135,9 @@ def post(self, app_model): app_model.app_model_config_id = new_app_model_config.id db.session.commit() - app_model_config_was_updated.send( - app_model, - app_model_config=new_app_model_config - ) + app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config) - return {'result': 'success'} + return {"result": "success"} -api.add_resource(ModelConfigResource, '/apps//model-config') +api.add_resource(ModelConfigResource, "/apps//model-config") diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index c0cf7b9e33f32b..374bd2b8157027 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -18,13 +18,11 @@ class TraceAppConfigApi(Resource): @account_initialization_required def get(self, app_id): parser = reqparse.RequestParser() - parser.add_argument('tracing_provider', type=str, required=True, location='args') + parser.add_argument("tracing_provider", type=str, required=True, location="args") args = parser.parse_args() try: - trace_config = OpsService.get_tracing_app_config( - app_id=app_id, tracing_provider=args['tracing_provider'] - ) + trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) if not trace_config: return {"has_not_configured": True} return trace_config @@ -37,19 +35,17 @@ def get(self, app_id): def post(self, app_id): """Create a new trace app configuration""" parser = reqparse.RequestParser() - parser.add_argument('tracing_provider', type=str, required=True, location='json') - parser.add_argument('tracing_config', type=dict, required=True, location='json') + parser.add_argument("tracing_provider", type=str, required=True, location="json") + parser.add_argument("tracing_config", type=dict, required=True, location="json") args = parser.parse_args() try: result = OpsService.create_tracing_app_config( - app_id=app_id, - tracing_provider=args['tracing_provider'], - tracing_config=args['tracing_config'] + app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"] ) if not result: raise TracingConfigIsExist() - if result.get('error'): + if result.get("error"): raise TracingConfigCheckError() return result except Exception as e: @@ -61,15 +57,13 @@ def post(self, app_id): def patch(self, app_id): """Update an existing trace app configuration""" parser = reqparse.RequestParser() - parser.add_argument('tracing_provider', type=str, required=True, location='json') - parser.add_argument('tracing_config', type=dict, required=True, location='json') + parser.add_argument("tracing_provider", type=str, required=True, location="json") + parser.add_argument("tracing_config", type=dict, required=True, location="json") args = parser.parse_args() try: result = OpsService.update_tracing_app_config( - app_id=app_id, - tracing_provider=args['tracing_provider'], - tracing_config=args['tracing_config'] + app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"] ) if not result: raise TracingConfigNotExist() @@ -83,14 +77,11 @@ def patch(self, app_id): def delete(self, app_id): """Delete an existing trace app configuration""" parser = reqparse.RequestParser() - parser.add_argument('tracing_provider', type=str, required=True, location='args') + parser.add_argument("tracing_provider", type=str, required=True, location="args") args = parser.parse_args() try: - result = OpsService.delete_tracing_app_config( - app_id=app_id, - tracing_provider=args['tracing_provider'] - ) + result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) if not result: raise TracingConfigNotExist() return {"result": "success"} @@ -98,4 +89,4 @@ def delete(self, app_id): raise e -api.add_resource(TraceAppConfigApi, '/apps//trace-config') +api.add_resource(TraceAppConfigApi, "/apps//trace-config") diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 6aa9f0b475f161..f936642acd1d4e 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,3 +1,5 @@ +from datetime import datetime, timezone + from flask_login import current_user from flask_restful import Resource, marshal_with, reqparse from werkzeug.exceptions import Forbidden, NotFound @@ -15,22 +17,23 @@ def parse_app_site_args(): parser = reqparse.RequestParser() - parser.add_argument('title', type=str, required=False, location='json') - parser.add_argument('icon', type=str, required=False, location='json') - parser.add_argument('icon_background', type=str, required=False, location='json') - parser.add_argument('description', type=str, required=False, location='json') - parser.add_argument('default_language', type=supported_language, required=False, location='json') - parser.add_argument('chat_color_theme', type=str, required=False, location='json') - parser.add_argument('chat_color_theme_inverted', type=bool, required=False, location='json') - parser.add_argument('customize_domain', type=str, required=False, location='json') - parser.add_argument('copyright', type=str, required=False, location='json') - parser.add_argument('privacy_policy', type=str, required=False, location='json') - parser.add_argument('custom_disclaimer', type=str, required=False, location='json') - parser.add_argument('customize_token_strategy', type=str, choices=['must', 'allow', 'not_allow'], - required=False, - location='json') - parser.add_argument('prompt_public', type=bool, required=False, location='json') - parser.add_argument('show_workflow_steps', type=bool, required=False, location='json') + parser.add_argument("title", type=str, required=False, location="json") + parser.add_argument("icon_type", type=str, required=False, location="json") + parser.add_argument("icon", type=str, required=False, location="json") + parser.add_argument("icon_background", type=str, required=False, location="json") + parser.add_argument("description", type=str, required=False, location="json") + parser.add_argument("default_language", type=supported_language, required=False, location="json") + parser.add_argument("chat_color_theme", type=str, required=False, location="json") + parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json") + parser.add_argument("customize_domain", type=str, required=False, location="json") + parser.add_argument("copyright", type=str, required=False, location="json") + parser.add_argument("privacy_policy", type=str, required=False, location="json") + parser.add_argument("custom_disclaimer", type=str, required=False, location="json") + parser.add_argument( + "customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json" + ) + parser.add_argument("prompt_public", type=bool, required=False, location="json") + parser.add_argument("show_workflow_steps", type=bool, required=False, location="json") return parser.parse_args() @@ -47,37 +50,37 @@ def post(self, app_model): if not current_user.is_editor: raise Forbidden() - site = db.session.query(Site). \ - filter(Site.app_id == app_model.id). \ - one_or_404() + site = db.session.query(Site).filter(Site.app_id == app_model.id).one_or_404() for attr_name in [ - 'title', - 'icon', - 'icon_background', - 'description', - 'default_language', - 'chat_color_theme', - 'chat_color_theme_inverted', - 'customize_domain', - 'copyright', - 'privacy_policy', - 'custom_disclaimer', - 'customize_token_strategy', - 'prompt_public', - 'show_workflow_steps' + "title", + "icon_type", + "icon", + "icon_background", + "description", + "default_language", + "chat_color_theme", + "chat_color_theme_inverted", + "customize_domain", + "copyright", + "privacy_policy", + "custom_disclaimer", + "customize_token_strategy", + "prompt_public", + "show_workflow_steps", ]: value = args.get(attr_name) if value is not None: setattr(site, attr_name, value) + site.updated_by = current_user.id + site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() return site class AppSiteAccessTokenReset(Resource): - @setup_required @login_required @account_initialization_required @@ -94,10 +97,12 @@ def post(self, app_model): raise NotFound site.code = Site.generate_code(16) + site.updated_by = current_user.id + site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() return site -api.add_resource(AppSite, '/apps//site') -api.add_resource(AppSiteAccessTokenReset, '/apps//site/access-token-reset') +api.add_resource(AppSite, "/apps//site") +api.add_resource(AppSiteAccessTokenReset, "/apps//site/access-token-reset") diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index b882ffef34129e..81826a20d05544 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -16,8 +16,61 @@ from models.model import AppMode -class DailyConversationStatistic(Resource): +class DailyMessageStatistic(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model + def get(self, app_model): + account = current_user + + parser = reqparse.RequestParser() + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + args = parser.parse_args() + + sql_query = """ + SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(*) AS message_count + FROM messages where app_id = :app_id + """ + arg_dict = {"tz": account.timezone, "app_id": app_model.id} + + timezone = pytz.timezone(account.timezone) + utc_timezone = pytz.utc + + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") + start_datetime = start_datetime.replace(second=0) + + start_datetime_timezone = timezone.localize(start_datetime) + start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) + + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc + + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") + end_datetime = end_datetime.replace(second=0) + + end_datetime_timezone = timezone.localize(end_datetime) + end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) + + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc + + sql_query += " GROUP BY date order by date" + + response_data = [] + + with db.engine.begin() as conn: + rs = conn.execute(db.text(sql_query), arg_dict) + for i in rs: + response_data.append({"date": str(i.date), "message_count": i.message_count}) + + return jsonify({"data": response_data}) + +class DailyConversationStatistic(Resource): @setup_required @login_required @account_initialization_required @@ -26,58 +79,52 @@ def get(self, app_model): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count FROM messages where app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + """ + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'conversation_count': i.conversation_count - }) + response_data.append({"date": str(i.date), "conversation_count": i.conversation_count}) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class DailyTerminalsStatistic(Resource): - @setup_required @login_required @account_initialization_required @@ -86,54 +133,49 @@ def get(self, app_model): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count FROM messages where app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + """ + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'terminal_count': i.terminal_count - }) + response_data.append({"date": str(i.date), "terminal_count": i.terminal_count}) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class DailyTokenCostStatistic(Resource): @@ -145,58 +187,53 @@ def get(self, app_model): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, (sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count, sum(total_price) as total_price FROM messages where app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + """ + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'token_count': i.token_count, - 'total_price': i.total_price, - 'currency': 'USD' - }) + response_data.append( + {"date": str(i.date), "token_count": i.token_count, "total_price": i.total_price, "currency": "USD"} + ) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class AverageSessionInteractionStatistic(Resource): @@ -208,8 +245,8 @@ def get(self, app_model): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, @@ -218,30 +255,30 @@ def get(self, app_model): FROM conversations c JOIN messages m ON c.id = m.conversation_id WHERE c.override_model_configs IS NULL AND c.app_id = :app_id""" - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and c.created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and c.created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and c.created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and c.created_at < :end" + arg_dict["end"] = end_datetime_utc sql_query += """ GROUP BY m.conversation_id) subquery @@ -250,18 +287,15 @@ def get(self, app_model): ORDER BY date""" response_data = [] - + with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'interactions': float(i.interactions.quantize(Decimal('0.01'))) - }) + response_data.append( + {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))} + ) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class UserSatisfactionRateStatistic(Resource): @@ -273,57 +307,57 @@ def get(self, app_model): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count FROM messages m LEFT JOIN message_feedbacks mf on mf.message_id=m.id and mf.rating='like' WHERE m.app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + """ + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and m.created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and m.created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and m.created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and m.created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'rate': round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2), - }) + response_data.append( + { + "date": str(i.date), + "rate": round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2), + } + ) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class AverageResponseTimeStatistic(Resource): @@ -335,56 +369,51 @@ def get(self, app_model): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, AVG(provider_response_latency) as latency FROM messages WHERE app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + """ + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'latency': round(i.latency * 1000, 4) - }) + response_data.append({"date": str(i.date), "latency": round(i.latency * 1000, 4)}) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class TokensPerSecondStatistic(Resource): @@ -396,63 +425,59 @@ def get(self, app_model): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = '''SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, CASE WHEN SUM(provider_response_latency) = 0 THEN 0 ELSE (SUM(answer_tokens) / SUM(provider_response_latency)) END as tokens_per_second FROM messages -WHERE app_id = :app_id''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} +WHERE app_id = :app_id""" + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'tps': round(i.tokens_per_second, 4) - }) - - return jsonify({ - 'data': response_data - }) - - -api.add_resource(DailyConversationStatistic, '/apps//statistics/daily-conversations') -api.add_resource(DailyTerminalsStatistic, '/apps//statistics/daily-end-users') -api.add_resource(DailyTokenCostStatistic, '/apps//statistics/token-costs') -api.add_resource(AverageSessionInteractionStatistic, '/apps//statistics/average-session-interactions') -api.add_resource(UserSatisfactionRateStatistic, '/apps//statistics/user-satisfaction-rate') -api.add_resource(AverageResponseTimeStatistic, '/apps//statistics/average-response-time') -api.add_resource(TokensPerSecondStatistic, '/apps//statistics/tokens-per-second') + response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)}) + + return jsonify({"data": response_data}) + + +api.add_resource(DailyMessageStatistic, "/apps//statistics/daily-messages") +api.add_resource(DailyConversationStatistic, "/apps//statistics/daily-conversations") +api.add_resource(DailyTerminalsStatistic, "/apps//statistics/daily-end-users") +api.add_resource(DailyTokenCostStatistic, "/apps//statistics/token-costs") +api.add_resource(AverageSessionInteractionStatistic, "/apps//statistics/average-session-interactions") +api.add_resource(UserSatisfactionRateStatistic, "/apps//statistics/user-satisfaction-rate") +api.add_resource(AverageResponseTimeStatistic, "/apps//statistics/average-response-time") +api.add_resource(TokensPerSecondStatistic, "/apps//statistics/tokens-per-second") diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 686ef7b4bebaaa..e44820f6345c48 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -64,49 +64,54 @@ def post(self, app_model: App): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - - content_type = request.headers.get('Content-Type', '') - if 'application/json' in content_type: + content_type = request.headers.get("Content-Type", "") + + if "application/json" in content_type: parser = reqparse.RequestParser() - parser.add_argument('graph', type=dict, required=True, nullable=False, location='json') - parser.add_argument('features', type=dict, required=True, nullable=False, location='json') - parser.add_argument('hash', type=str, required=False, location='json') + parser.add_argument("graph", type=dict, required=True, nullable=False, location="json") + parser.add_argument("features", type=dict, required=True, nullable=False, location="json") + parser.add_argument("hash", type=str, required=False, location="json") # TODO: set this to required=True after frontend is updated - parser.add_argument('environment_variables', type=list, required=False, location='json') + parser.add_argument("environment_variables", type=list, required=False, location="json") + parser.add_argument("conversation_variables", type=list, required=False, location="json") args = parser.parse_args() - elif 'text/plain' in content_type: + elif "text/plain" in content_type: try: - data = json.loads(request.data.decode('utf-8')) - if 'graph' not in data or 'features' not in data: - raise ValueError('graph or features not found in data') + data = json.loads(request.data.decode("utf-8")) + if "graph" not in data or "features" not in data: + raise ValueError("graph or features not found in data") - if not isinstance(data.get('graph'), dict) or not isinstance(data.get('features'), dict): - raise ValueError('graph or features is not a dict') + if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict): + raise ValueError("graph or features is not a dict") args = { - 'graph': data.get('graph'), - 'features': data.get('features'), - 'hash': data.get('hash'), - 'environment_variables': data.get('environment_variables') + "graph": data.get("graph"), + "features": data.get("features"), + "hash": data.get("hash"), + "environment_variables": data.get("environment_variables"), + "conversation_variables": data.get("conversation_variables"), } except json.JSONDecodeError: - return {'message': 'Invalid JSON data'}, 400 + return {"message": "Invalid JSON data"}, 400 else: abort(415) workflow_service = WorkflowService() try: - environment_variables_list = args.get('environment_variables') or [] + environment_variables_list = args.get("environment_variables") or [] environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] + conversation_variables_list = args.get("conversation_variables") or [] + conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] workflow = workflow_service.sync_draft_workflow( app_model=app_model, - graph=args['graph'], - features=args['features'], - unique_hash=args.get('hash'), + graph=args["graph"], + features=args["features"], + unique_hash=args.get("hash"), account=current_user, environment_variables=environment_variables, + conversation_variables=conversation_variables, ) except WorkflowHashNotEqualError: raise DraftWorkflowNotSync() @@ -114,7 +119,7 @@ def post(self, app_model: App): return { "result": "success", "hash": workflow.unique_hash, - "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at) + "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at), } @@ -133,13 +138,11 @@ def post(self, app_model: App): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('data', type=str, required=True, nullable=False, location='json') + parser.add_argument("data", type=str, required=True, nullable=False, location="json") args = parser.parse_args() workflow = AppDslService.import_and_overwrite_workflow( - app_model=app_model, - data=args['data'], - account=current_user + app_model=app_model, data=args["data"], account=current_user ) return workflow @@ -157,21 +160,17 @@ def post(self, app_model: App): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, location='json') - parser.add_argument('query', type=str, required=True, location='json', default='') - parser.add_argument('files', type=list, location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') + parser.add_argument("inputs", type=dict, location="json") + parser.add_argument("query", type=str, required=True, location="json", default="") + parser.add_argument("files", type=list, location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") args = parser.parse_args() try: response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.DEBUGGER, - streaming=True + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True ) return helper.compact_generate_response(response) @@ -185,6 +184,7 @@ def post(self, app_model: App): logging.exception("internal server error.") raise InternalServerError() + class AdvancedChatDraftRunIterationNodeApi(Resource): @setup_required @login_required @@ -197,18 +197,14 @@ def post(self, app_model: App, node_id: str): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, location='json') + parser.add_argument("inputs", type=dict, location="json") args = parser.parse_args() try: response = AppGenerateService.generate_single_iteration( - app_model=app_model, - user=current_user, - node_id=node_id, - args=args, - streaming=True + app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True ) return helper.compact_generate_response(response) @@ -222,6 +218,7 @@ def post(self, app_model: App, node_id: str): logging.exception("internal server error.") raise InternalServerError() + class WorkflowDraftRunIterationNodeApi(Resource): @setup_required @login_required @@ -234,18 +231,14 @@ def post(self, app_model: App, node_id: str): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, location='json') + parser.add_argument("inputs", type=dict, location="json") args = parser.parse_args() try: response = AppGenerateService.generate_single_iteration( - app_model=app_model, - user=current_user, - node_id=node_id, - args=args, - streaming=True + app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True ) return helper.compact_generate_response(response) @@ -259,6 +252,7 @@ def post(self, app_model: App, node_id: str): logging.exception("internal server error.") raise InternalServerError() + class DraftWorkflowRunApi(Resource): @setup_required @login_required @@ -271,19 +265,15 @@ def post(self, app_model: App): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') - parser.add_argument('files', type=list, required=False, location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("files", type=list, required=False, location="json") args = parser.parse_args() try: response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.DEBUGGER, - streaming=True + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True ) return helper.compact_generate_response(response) @@ -306,12 +296,10 @@ def post(self, app_model: App, task_id: str): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) - return { - "result": "success" - } + return {"result": "success"} class DraftWorkflowNodeRunApi(Resource): @@ -327,24 +315,20 @@ def post(self, app_model: App, node_id: str): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() workflow_service = WorkflowService() workflow_node_execution = workflow_service.run_draft_workflow_node( - app_model=app_model, - node_id=node_id, - user_inputs=args.get('inputs'), - account=current_user + app_model=app_model, node_id=node_id, user_inputs=args.get("inputs"), account=current_user ) return workflow_node_execution class PublishedWorkflowApi(Resource): - @setup_required @login_required @account_initialization_required @@ -357,7 +341,7 @@ def get(self, app_model: App): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + # fetch published workflow by app_model workflow_service = WorkflowService() workflow = workflow_service.get_published_workflow(app_model=app_model) @@ -376,14 +360,11 @@ def post(self, app_model: App): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + workflow_service = WorkflowService() workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user) - return { - "result": "success", - "created_at": TimestampField().format(workflow.created_at) - } + return {"result": "success", "created_at": TimestampField().format(workflow.created_at)} class DefaultBlockConfigsApi(Resource): @@ -398,7 +379,7 @@ def get(self, app_model: App): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + # Get default block configs workflow_service = WorkflowService() return workflow_service.get_default_block_configs() @@ -416,24 +397,21 @@ def get(self, app_model: App, block_type: str): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('q', type=str, location='args') + parser.add_argument("q", type=str, location="args") args = parser.parse_args() filters = None - if args.get('q'): + if args.get("q"): try: - filters = json.loads(args.get('q')) + filters = json.loads(args.get("q")) except json.JSONDecodeError: - raise ValueError('Invalid filters') + raise ValueError("Invalid filters") # Get default block configs workflow_service = WorkflowService() - return workflow_service.get_default_block_config( - node_type=block_type, - filters=filters - ) + return workflow_service.get_default_block_config(node_type=block_type, filters=filters) class ConvertToWorkflowApi(Resource): @@ -450,40 +428,43 @@ def post(self, app_model: App): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + if request.data: parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, nullable=True, location='json') - parser.add_argument('icon', type=str, required=False, nullable=True, location='json') - parser.add_argument('icon_background', type=str, required=False, nullable=True, location='json') + parser.add_argument("name", type=str, required=False, nullable=True, location="json") + parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json") + parser.add_argument("icon", type=str, required=False, nullable=True, location="json") + parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") args = parser.parse_args() else: args = {} # convert to workflow mode workflow_service = WorkflowService() - new_app_model = workflow_service.convert_to_workflow( - app_model=app_model, - account=current_user, - args=args - ) + new_app_model = workflow_service.convert_to_workflow(app_model=app_model, account=current_user, args=args) # return app id return { - 'new_app_id': new_app_model.id, + "new_app_id": new_app_model.id, } -api.add_resource(DraftWorkflowApi, '/apps//workflows/draft') -api.add_resource(DraftWorkflowImportApi, '/apps//workflows/draft/import') -api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps//advanced-chat/workflows/draft/run') -api.add_resource(DraftWorkflowRunApi, '/apps//workflows/draft/run') -api.add_resource(WorkflowTaskStopApi, '/apps//workflow-runs/tasks//stop') -api.add_resource(DraftWorkflowNodeRunApi, '/apps//workflows/draft/nodes//run') -api.add_resource(AdvancedChatDraftRunIterationNodeApi, '/apps//advanced-chat/workflows/draft/iteration/nodes//run') -api.add_resource(WorkflowDraftRunIterationNodeApi, '/apps//workflows/draft/iteration/nodes//run') -api.add_resource(PublishedWorkflowApi, '/apps//workflows/publish') -api.add_resource(DefaultBlockConfigsApi, '/apps//workflows/default-workflow-block-configs') -api.add_resource(DefaultBlockConfigApi, '/apps//workflows/default-workflow-block-configs' - '/') -api.add_resource(ConvertToWorkflowApi, '/apps//convert-to-workflow') +api.add_resource(DraftWorkflowApi, "/apps//workflows/draft") +api.add_resource(DraftWorkflowImportApi, "/apps//workflows/draft/import") +api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps//advanced-chat/workflows/draft/run") +api.add_resource(DraftWorkflowRunApi, "/apps//workflows/draft/run") +api.add_resource(WorkflowTaskStopApi, "/apps//workflow-runs/tasks//stop") +api.add_resource(DraftWorkflowNodeRunApi, "/apps//workflows/draft/nodes//run") +api.add_resource( + AdvancedChatDraftRunIterationNodeApi, + "/apps//advanced-chat/workflows/draft/iteration/nodes//run", +) +api.add_resource( + WorkflowDraftRunIterationNodeApi, "/apps//workflows/draft/iteration/nodes//run" +) +api.add_resource(PublishedWorkflowApi, "/apps//workflows/publish") +api.add_resource(DefaultBlockConfigsApi, "/apps//workflows/default-workflow-block-configs") +api.add_resource( + DefaultBlockConfigApi, "/apps//workflows/default-workflow-block-configs" "/" +) +api.add_resource(ConvertToWorkflowApi, "/apps//convert-to-workflow") diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 6d1709ed8e65d9..dc962409cc4e86 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -22,20 +22,19 @@ def get(self, app_model: App): Get workflow app logs """ parser = reqparse.RequestParser() - parser.add_argument('keyword', type=str, location='args') - parser.add_argument('status', type=str, choices=['succeeded', 'failed', 'stopped'], location='args') - parser.add_argument('page', type=int_range(1, 99999), default=1, location='args') - parser.add_argument('limit', type=int_range(1, 100), default=20, location='args') + parser.add_argument("keyword", type=str, location="args") + parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") + parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") + parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") args = parser.parse_args() # get paginate workflow app logs workflow_app_service = WorkflowAppService() workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs( - app_model=app_model, - args=args + app_model=app_model, args=args ) return workflow_app_log_pagination -api.add_resource(WorkflowAppLogApi, '/apps//workflow-app-logs') +api.add_resource(WorkflowAppLogApi, "/apps//workflow-app-logs") diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 35d982e37ce4e3..a055d03deb7879 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -28,15 +28,12 @@ def get(self, app_model: App): Get advanced chat app workflow run list """ parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() workflow_run_service = WorkflowRunService() - result = workflow_run_service.get_paginate_advanced_chat_workflow_runs( - app_model=app_model, - args=args - ) + result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args) return result @@ -52,15 +49,12 @@ def get(self, app_model: App): Get workflow run list """ parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() workflow_run_service = WorkflowRunService() - result = workflow_run_service.get_paginate_workflow_runs( - app_model=app_model, - args=args - ) + result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args) return result @@ -98,12 +92,10 @@ def get(self, app_model: App, run_id): workflow_run_service = WorkflowRunService() node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id) - return { - 'data': node_executions - } + return {"data": node_executions} -api.add_resource(AdvancedChatAppWorkflowRunListApi, '/apps//advanced-chat/workflow-runs') -api.add_resource(WorkflowRunListApi, '/apps//workflow-runs') -api.add_resource(WorkflowRunDetailApi, '/apps//workflow-runs/') -api.add_resource(WorkflowRunNodeExecutionListApi, '/apps//workflow-runs//node-executions') +api.add_resource(AdvancedChatAppWorkflowRunListApi, "/apps//advanced-chat/workflow-runs") +api.add_resource(WorkflowRunListApi, "/apps//workflow-runs") +api.add_resource(WorkflowRunDetailApi, "/apps//workflow-runs/") +api.add_resource(WorkflowRunNodeExecutionListApi, "/apps//workflow-runs//node-executions") diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index 1d7dc395ff3b18..db2f6835898225 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -26,56 +26,56 @@ def get(self, app_model): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(id) AS runs FROM workflow_runs WHERE app_id = :app_id AND triggered_from = :triggered_from - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value} + """ + arg_dict = { + "tz": account.timezone, + "app_id": app_model.id, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + } timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'runs': i.runs - }) + response_data.append({"date": str(i.date), "runs": i.runs}) + + return jsonify({"data": response_data}) - return jsonify({ - 'data': response_data - }) class WorkflowDailyTerminalsStatistic(Resource): @setup_required @@ -86,56 +86,56 @@ def get(self, app_model): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct workflow_runs.created_by) AS terminal_count FROM workflow_runs WHERE app_id = :app_id AND triggered_from = :triggered_from - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value} + """ + arg_dict = { + "tz": account.timezone, + "app_id": app_model.id, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + } timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'terminal_count': i.terminal_count - }) + response_data.append({"date": str(i.date), "terminal_count": i.terminal_count}) + + return jsonify({"data": response_data}) - return jsonify({ - 'data': response_data - }) class WorkflowDailyTokenCostStatistic(Resource): @setup_required @@ -146,58 +146,63 @@ def get(self, app_model): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, SUM(workflow_runs.total_tokens) as token_count FROM workflow_runs WHERE app_id = :app_id AND triggered_from = :triggered_from - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value} + """ + arg_dict = { + "tz": account.timezone, + "app_id": app_model.id, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + } timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'token_count': i.token_count, - }) + response_data.append( + { + "date": str(i.date), + "token_count": i.token_count, + } + ) + + return jsonify({"data": response_data}) - return jsonify({ - 'data': response_data - }) class WorkflowAverageAppInteractionStatistic(Resource): @setup_required @@ -208,8 +213,8 @@ def get(self, app_model): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -229,50 +234,54 @@ def get(self, app_model): GROUP BY date, c.created_by) sub GROUP BY sub.date """ - arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value} + arg_dict = { + "tz": account.timezone, + "app_id": app_model.id, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + } timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query = sql_query.replace('{{start}}', ' AND c.created_at >= :start') - arg_dict['start'] = start_datetime_utc + sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start") + arg_dict["start"] = start_datetime_utc else: - sql_query = sql_query.replace('{{start}}', '') + sql_query = sql_query.replace("{{start}}", "") - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query = sql_query.replace('{{end}}', ' and c.created_at < :end') - arg_dict['end'] = end_datetime_utc + sql_query = sql_query.replace("{{end}}", " and c.created_at < :end") + arg_dict["end"] = end_datetime_utc else: - sql_query = sql_query.replace('{{end}}', '') + sql_query = sql_query.replace("{{end}}", "") response_data = [] - + with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'interactions': float(i.interactions.quantize(Decimal('0.01'))) - }) - - return jsonify({ - 'data': response_data - }) - -api.add_resource(WorkflowDailyRunsStatistic, '/apps//workflow/statistics/daily-conversations') -api.add_resource(WorkflowDailyTerminalsStatistic, '/apps//workflow/statistics/daily-terminals') -api.add_resource(WorkflowDailyTokenCostStatistic, '/apps//workflow/statistics/token-costs') -api.add_resource(WorkflowAverageAppInteractionStatistic, '/apps//workflow/statistics/average-app-interactions') + response_data.append( + {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))} + ) + + return jsonify({"data": response_data}) + + +api.add_resource(WorkflowDailyRunsStatistic, "/apps//workflow/statistics/daily-conversations") +api.add_resource(WorkflowDailyTerminalsStatistic, "/apps//workflow/statistics/daily-terminals") +api.add_resource(WorkflowDailyTokenCostStatistic, "/apps//workflow/statistics/token-costs") +api.add_resource( + WorkflowAverageAppInteractionStatistic, "/apps//workflow/statistics/average-app-interactions" +) diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index d61ab6d6ae8f28..5e0a4bc814633a 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -8,24 +8,23 @@ from models.model import App, AppMode -def get_app_model(view: Optional[Callable] = None, *, - mode: Union[AppMode, list[AppMode]] = None): +def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None): def decorator(view_func): @wraps(view_func) def decorated_view(*args, **kwargs): - if not kwargs.get('app_id'): - raise ValueError('missing app_id in path parameters') + if not kwargs.get("app_id"): + raise ValueError("missing app_id in path parameters") - app_id = kwargs.get('app_id') + app_id = kwargs.get("app_id") app_id = str(app_id) - del kwargs['app_id'] + del kwargs["app_id"] - app_model = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app_model = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app_model: raise AppNotFoundError() @@ -44,9 +43,10 @@ def decorated_view(*args, **kwargs): mode_values = {m.value for m in modes} raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}") - kwargs['app_model'] = app_model + kwargs["app_model"] = app_model return view_func(*args, **kwargs) + return decorated_view if view is None: diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index 8efb55cdb64edc..8ba6b53e7ead2b 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -17,60 +17,61 @@ class ActivateCheckApi(Resource): def get(self): parser = reqparse.RequestParser() - parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='args') - parser.add_argument('email', type=email, required=False, nullable=True, location='args') - parser.add_argument('token', type=str, required=True, nullable=False, location='args') + parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="args") + parser.add_argument("email", type=email, required=False, nullable=True, location="args") + parser.add_argument("token", type=str, required=True, nullable=False, location="args") args = parser.parse_args() - workspaceId = args['workspace_id'] - reg_email = args['email'] - token = args['token'] + workspaceId = args["workspace_id"] + reg_email = args["email"] + token = args["token"] invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) - return {'is_valid': invitation is not None, 'workspace_name': invitation['tenant'].name if invitation else None} + return {"is_valid": invitation is not None, "workspace_name": invitation["tenant"].name if invitation else None} class ActivateApi(Resource): def post(self): parser = reqparse.RequestParser() - parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='json') - parser.add_argument('email', type=email, required=False, nullable=True, location='json') - parser.add_argument('token', type=str, required=True, nullable=False, location='json') - parser.add_argument('name', type=str_len(30), required=True, nullable=False, location='json') - parser.add_argument('password', type=valid_password, required=True, nullable=False, location='json') - parser.add_argument('interface_language', type=supported_language, required=True, nullable=False, - location='json') - parser.add_argument('timezone', type=timezone, required=True, nullable=False, location='json') + parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") + parser.add_argument("email", type=email, required=False, nullable=True, location="json") + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json") + parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json") + parser.add_argument( + "interface_language", type=supported_language, required=True, nullable=False, location="json" + ) + parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json") args = parser.parse_args() - invitation = RegisterService.get_invitation_if_token_valid(args['workspace_id'], args['email'], args['token']) + invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"]) if invitation is None: raise AlreadyActivateError() - RegisterService.revoke_token(args['workspace_id'], args['email'], args['token']) + RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"]) - account = invitation['account'] - account.name = args['name'] + account = invitation["account"] + account.name = args["name"] # generate password salt salt = secrets.token_bytes(16) base64_salt = base64.b64encode(salt).decode() # encrypt password with salt - password_hashed = hash_password(args['password'], salt) + password_hashed = hash_password(args["password"], salt) base64_password_hashed = base64.b64encode(password_hashed).decode() account.password = base64_password_hashed account.password_salt = base64_salt - account.interface_language = args['interface_language'] - account.timezone = args['timezone'] - account.interface_theme = 'light' + account.interface_language = args["interface_language"] + account.timezone = args["timezone"] + account.interface_theme = "light" account.status = AccountStatus.ACTIVE.value account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() - return {'result': 'success'} + return {"result": "success"} -api.add_resource(ActivateCheckApi, '/activate/check') -api.add_resource(ActivateApi, '/activate') +api.add_resource(ActivateCheckApi, "/activate/check") +api.add_resource(ActivateApi, "/activate") diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index f79b93b74f6df3..50db6eebc13d3d 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -19,18 +19,19 @@ def get(self): data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id) if data_source_api_key_bindings: return { - 'sources': [{ - 'id': data_source_api_key_binding.id, - 'category': data_source_api_key_binding.category, - 'provider': data_source_api_key_binding.provider, - 'disabled': data_source_api_key_binding.disabled, - 'created_at': int(data_source_api_key_binding.created_at.timestamp()), - 'updated_at': int(data_source_api_key_binding.updated_at.timestamp()), - } - for data_source_api_key_binding in - data_source_api_key_bindings] + "sources": [ + { + "id": data_source_api_key_binding.id, + "category": data_source_api_key_binding.category, + "provider": data_source_api_key_binding.provider, + "disabled": data_source_api_key_binding.disabled, + "created_at": int(data_source_api_key_binding.created_at.timestamp()), + "updated_at": int(data_source_api_key_binding.updated_at.timestamp()), + } + for data_source_api_key_binding in data_source_api_key_bindings + ] } - return {'sources': []} + return {"sources": []} class ApiKeyAuthDataSourceBinding(Resource): @@ -42,16 +43,16 @@ def post(self): if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('category', type=str, required=True, nullable=False, location='json') - parser.add_argument('provider', type=str, required=True, nullable=False, location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("category", type=str, required=True, nullable=False, location="json") + parser.add_argument("provider", type=str, required=True, nullable=False, location="json") + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() ApiKeyAuthService.validate_api_key_auth_args(args) try: ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args) except Exception as e: raise ApiKeyAuthFailedError(str(e)) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ApiKeyAuthDataSourceBindingDelete(Resource): @@ -65,9 +66,9 @@ def delete(self, binding_id): ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(ApiKeyAuthDataSource, '/api-key-auth/data-source') -api.add_resource(ApiKeyAuthDataSourceBinding, '/api-key-auth/data-source/binding') -api.add_resource(ApiKeyAuthDataSourceBindingDelete, '/api-key-auth/data-source/') +api.add_resource(ApiKeyAuthDataSource, "/api-key-auth/data-source") +api.add_resource(ApiKeyAuthDataSourceBinding, "/api-key-auth/data-source/binding") +api.add_resource(ApiKeyAuthDataSourceBindingDelete, "/api-key-auth/data-source/") diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 45cfa9d7ebcb1b..fd31e5ccc3b99c 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -17,13 +17,13 @@ def get_oauth_providers(): with current_app.app_context(): - notion_oauth = NotionOAuth(client_id=dify_config.NOTION_CLIENT_ID, - client_secret=dify_config.NOTION_CLIENT_SECRET, - redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/data-source/callback/notion') + notion_oauth = NotionOAuth( + client_id=dify_config.NOTION_CLIENT_ID, + client_secret=dify_config.NOTION_CLIENT_SECRET, + redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/data-source/callback/notion", + ) - OAUTH_PROVIDERS = { - 'notion': notion_oauth - } + OAUTH_PROVIDERS = {"notion": notion_oauth} return OAUTH_PROVIDERS @@ -37,18 +37,16 @@ def get(self, provider: str): oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) print(vars(oauth_provider)) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 - if dify_config.NOTION_INTEGRATION_TYPE == 'internal': + return {"error": "Invalid provider"}, 400 + if dify_config.NOTION_INTEGRATION_TYPE == "internal": internal_secret = dify_config.NOTION_INTERNAL_SECRET if not internal_secret: - return {'error': 'Internal secret is not set'}, + return ({"error": "Internal secret is not set"},) oauth_provider.save_internal_access_token(internal_secret) - return { 'data': '' } + return {"data": ""} else: auth_url = oauth_provider.get_authorization_url() - return { 'data': auth_url }, 200 - - + return {"data": auth_url}, 200 class OAuthDataSourceCallback(Resource): @@ -57,18 +55,18 @@ def get(self, provider: str): with current_app.app_context(): oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 - if 'code' in request.args: - code = request.args.get('code') + return {"error": "Invalid provider"}, 400 + if "code" in request.args: + code = request.args.get("code") - return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}') - elif 'error' in request.args: - error = request.args.get('error') + return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}") + elif "error" in request.args: + error = request.args.get("error") - return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}') + return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}") else: - return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied') - + return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied") + class OAuthDataSourceBinding(Resource): def get(self, provider: str): @@ -76,17 +74,18 @@ def get(self, provider: str): with current_app.app_context(): oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 - if 'code' in request.args: - code = request.args.get('code') + return {"error": "Invalid provider"}, 400 + if "code" in request.args: + code = request.args.get("code") try: oauth_provider.get_access_token(code) except requests.exceptions.HTTPError as e: logging.exception( - f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}") - return {'error': 'OAuth data source process failed'}, 400 + f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}" + ) + return {"error": "OAuth data source process failed"}, 400 - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class OAuthDataSourceSync(Resource): @@ -100,18 +99,17 @@ def get(self, provider, binding_id): with current_app.app_context(): oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 + return {"error": "Invalid provider"}, 400 try: oauth_provider.sync_data_source(binding_id) except requests.exceptions.HTTPError as e: - logging.exception( - f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}") - return {'error': 'OAuth data source process failed'}, 400 + logging.exception(f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}") + return {"error": "OAuth data source process failed"}, 400 - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(OAuthDataSource, '/oauth/data-source/') -api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/') -api.add_resource(OAuthDataSourceBinding, '/oauth/data-source/binding/') -api.add_resource(OAuthDataSourceSync, '/oauth/data-source///sync') +api.add_resource(OAuthDataSource, "/oauth/data-source/") +api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/") +api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/") +api.add_resource(OAuthDataSourceSync, "/oauth/data-source///sync") diff --git a/api/controllers/console/auth/error.py b/api/controllers/console/auth/error.py index 53dab3298fbff3..ea23e097d0bf3e 100644 --- a/api/controllers/console/auth/error.py +++ b/api/controllers/console/auth/error.py @@ -2,31 +2,30 @@ class ApiKeyAuthFailedError(BaseHTTPException): - error_code = 'auth_failed' + error_code = "auth_failed" description = "{message}" code = 500 class InvalidEmailError(BaseHTTPException): - error_code = 'invalid_email' + error_code = "invalid_email" description = "The email address is not valid." code = 400 class PasswordMismatchError(BaseHTTPException): - error_code = 'password_mismatch' + error_code = "password_mismatch" description = "The passwords do not match." code = 400 class InvalidTokenError(BaseHTTPException): - error_code = 'invalid_or_expired_token' + error_code = "invalid_or_expired_token" description = "The token is invalid or has expired." code = 400 class PasswordResetRateLimitExceededError(BaseHTTPException): - error_code = 'password_reset_rate_limit_exceeded' + error_code = "password_reset_rate_limit_exceeded" description = "Password reset rate limit exceeded. Try again later." code = 429 - diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index d78be770abd094..0b01a4906adbcf 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -21,14 +21,13 @@ class ForgotPasswordSendEmailApi(Resource): - @setup_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('email', type=str, required=True, location='json') + parser.add_argument("email", type=str, required=True, location="json") args = parser.parse_args() - email = args['email'] + email = args["email"] if not email_validate(email): raise InvalidEmailError() @@ -49,38 +48,36 @@ def post(self): class ForgotPasswordCheckApi(Resource): - @setup_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('token', type=str, required=True, nullable=False, location='json') + parser.add_argument("token", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - token = args['token'] + token = args["token"] reset_data = AccountService.get_reset_password_data(token) if reset_data is None: - return {'is_valid': False, 'email': None} - return {'is_valid': True, 'email': reset_data.get('email')} + return {"is_valid": False, "email": None} + return {"is_valid": True, "email": reset_data.get("email")} class ForgotPasswordResetApi(Resource): - @setup_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('token', type=str, required=True, nullable=False, location='json') - parser.add_argument('new_password', type=valid_password, required=True, nullable=False, location='json') - parser.add_argument('password_confirm', type=valid_password, required=True, nullable=False, location='json') + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") + parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") args = parser.parse_args() - new_password = args['new_password'] - password_confirm = args['password_confirm'] + new_password = args["new_password"] + password_confirm = args["password_confirm"] if str(new_password).strip() != str(password_confirm).strip(): raise PasswordMismatchError() - token = args['token'] + token = args["token"] reset_data = AccountService.get_reset_password_data(token) if reset_data is None: @@ -94,14 +91,14 @@ def post(self): password_hashed = hash_password(new_password, salt) base64_password_hashed = base64.b64encode(password_hashed).decode() - account = Account.query.filter_by(email=reset_data.get('email')).first() + account = Account.query.filter_by(email=reset_data.get("email")).first() account.password = base64_password_hashed account.password_salt = base64_salt db.session.commit() - return {'result': 'success'} + return {"result": "success"} -api.add_resource(ForgotPasswordSendEmailApi, '/forgot-password') -api.add_resource(ForgotPasswordCheckApi, '/forgot-password/validity') -api.add_resource(ForgotPasswordResetApi, '/forgot-password/resets') +api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password") +api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity") +api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets") diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index c135ece67ef86c..62837af2b9b0eb 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -20,37 +20,39 @@ class LoginApi(Resource): def post(self): """Authenticate user and login.""" parser = reqparse.RequestParser() - parser.add_argument('email', type=email, required=True, location='json') - parser.add_argument('password', type=valid_password, required=True, location='json') - parser.add_argument('remember_me', type=bool, required=False, default=False, location='json') + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("password", type=valid_password, required=True, location="json") + parser.add_argument("remember_me", type=bool, required=False, default=False, location="json") args = parser.parse_args() # todo: Verify the recaptcha try: - account = AccountService.authenticate(args['email'], args['password']) + account = AccountService.authenticate(args["email"], args["password"]) except services.errors.account.AccountLoginError as e: - return {'code': 'unauthorized', 'message': str(e)}, 401 + return {"code": "unauthorized", "message": str(e)}, 401 # SELF_HOSTED only have one workspace tenants = TenantService.get_join_tenants(account) if len(tenants) == 0: - return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'} + return { + "result": "fail", + "data": "workspace not found, please contact system admin to invite you to join in a workspace", + } token = AccountService.login(account, ip_address=get_remote_ip(request)) - return {'result': 'success', 'data': token} + return {"result": "success", "data": token} class LogoutApi(Resource): - @setup_required def get(self): account = cast(Account, flask_login.current_user) - token = request.headers.get('Authorization', '').split(' ')[1] + token = request.headers.get("Authorization", "").split(" ")[1] AccountService.logout(account=account, token=token) flask_login.logout_user() - return {'result': 'success'} + return {"result": "success"} class ResetPasswordApi(Resource): @@ -80,11 +82,11 @@ def get(self): # 'subject': 'Reset your Dify password', # 'html': """ #

Dear User,

- #

The Dify team has generated a new password for you, details as follows:

+ #

The Dify team has generated a new password for you, details as follows:

#

{new_password}

#

Please change your password to log in as soon as possible.

#

Regards,

- #

The Dify Team

+ #

The Dify Team

# """ # } @@ -101,8 +103,8 @@ def get(self): # # handle error # pass - return {'result': 'success'} + return {"result": "success"} -api.add_resource(LoginApi, '/login') -api.add_resource(LogoutApi, '/logout') +api.add_resource(LoginApi, "/login") +api.add_resource(LogoutApi, "/logout") diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 4a651bfe7b009e..ae1b49f3ecf448 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -25,7 +25,7 @@ def get_oauth_providers(): github_oauth = GitHubOAuth( client_id=dify_config.GITHUB_CLIENT_ID, client_secret=dify_config.GITHUB_CLIENT_SECRET, - redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/github', + redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/github", ) if not dify_config.GOOGLE_CLIENT_ID or not dify_config.GOOGLE_CLIENT_SECRET: google_oauth = None @@ -33,10 +33,10 @@ def get_oauth_providers(): google_oauth = GoogleOAuth( client_id=dify_config.GOOGLE_CLIENT_ID, client_secret=dify_config.GOOGLE_CLIENT_SECRET, - redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/google', + redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google", ) - OAUTH_PROVIDERS = {'github': github_oauth, 'google': google_oauth} + OAUTH_PROVIDERS = {"github": github_oauth, "google": google_oauth} return OAUTH_PROVIDERS @@ -47,7 +47,7 @@ def get(self, provider: str): oauth_provider = OAUTH_PROVIDERS.get(provider) print(vars(oauth_provider)) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 + return {"error": "Invalid provider"}, 400 auth_url = oauth_provider.get_authorization_url() return redirect(auth_url) @@ -59,20 +59,20 @@ def get(self, provider: str): with current_app.app_context(): oauth_provider = OAUTH_PROVIDERS.get(provider) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 + return {"error": "Invalid provider"}, 400 - code = request.args.get('code') + code = request.args.get("code") try: token = oauth_provider.get_access_token(code) user_info = oauth_provider.get_user_info(token) except requests.exceptions.HTTPError as e: - logging.exception(f'An error occurred during the OAuth process with {provider}: {e.response.text}') - return {'error': 'OAuth process failed'}, 400 + logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}") + return {"error": "OAuth process failed"}, 400 account = _generate_account(provider, user_info) # Check account status if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: - return {'error': 'Account is banned or closed.'}, 403 + return {"error": "Account is banned or closed."}, 403 if account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value @@ -83,7 +83,7 @@ def get(self, provider: str): token = AccountService.login(account, ip_address=get_remote_ip(request)) - return redirect(f'{dify_config.CONSOLE_WEB_URL}?console_token={token}') + return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}") def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: @@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): if not account: # Create account - account_name = user_info.name if user_info.name else 'Dify' + account_name = user_info.name if user_info.name else "Dify" account = RegisterService.register( email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider ) @@ -121,5 +121,5 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): return account -api.add_resource(OAuthLogin, '/oauth/login/') -api.add_resource(OAuthCallback, '/oauth/authorize/') +api.add_resource(OAuthLogin, "/oauth/login/") +api.add_resource(OAuthCallback, "/oauth/authorize/") diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 72a6129efa3e4d..9a1d9148696349 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -9,28 +9,24 @@ class Subscription(Resource): - @setup_required @login_required @account_initialization_required @only_edition_cloud def get(self): - parser = reqparse.RequestParser() - parser.add_argument('plan', type=str, required=True, location='args', choices=['professional', 'team']) - parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year']) + parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"]) + parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"]) args = parser.parse_args() BillingService.is_tenant_owner_or_admin(current_user) - return BillingService.get_subscription(args['plan'], - args['interval'], - current_user.email, - current_user.current_tenant_id) + return BillingService.get_subscription( + args["plan"], args["interval"], current_user.email, current_user.current_tenant_id + ) class Invoices(Resource): - @setup_required @login_required @account_initialization_required @@ -40,5 +36,5 @@ def get(self): return BillingService.get_invoices(current_user.email, current_user.current_tenant_id) -api.add_resource(Subscription, '/billing/subscription') -api.add_resource(Invoices, '/billing/invoices') +api.add_resource(Subscription, "/billing/subscription") +api.add_resource(Invoices, "/billing/invoices") diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 0ca0f0a85653dc..0e1acab946ae5f 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -22,19 +22,22 @@ class DataSourceApi(Resource): - @setup_required @login_required @account_initialization_required @marshal_with(integrate_list_fields) def get(self): # get workspace data source integrates - data_source_integrates = db.session.query(DataSourceOauthBinding).filter( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.disabled == False - ).all() + data_source_integrates = ( + db.session.query(DataSourceOauthBinding) + .filter( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.disabled == False, + ) + .all() + ) - base_url = request.url_root.rstrip('/') + base_url = request.url_root.rstrip("/") data_source_oauth_base_path = "/console/api/oauth/data-source" providers = ["notion"] @@ -44,26 +47,30 @@ def get(self): existing_integrates = filter(lambda item: item.provider == provider, data_source_integrates) if existing_integrates: for existing_integrate in list(existing_integrates): - integrate_data.append({ - 'id': existing_integrate.id, - 'provider': provider, - 'created_at': existing_integrate.created_at, - 'is_bound': True, - 'disabled': existing_integrate.disabled, - 'source_info': existing_integrate.source_info, - 'link': f'{base_url}{data_source_oauth_base_path}/{provider}' - }) + integrate_data.append( + { + "id": existing_integrate.id, + "provider": provider, + "created_at": existing_integrate.created_at, + "is_bound": True, + "disabled": existing_integrate.disabled, + "source_info": existing_integrate.source_info, + "link": f"{base_url}{data_source_oauth_base_path}/{provider}", + } + ) else: - integrate_data.append({ - 'id': None, - 'provider': provider, - 'created_at': None, - 'source_info': None, - 'is_bound': False, - 'disabled': None, - 'link': f'{base_url}{data_source_oauth_base_path}/{provider}' - }) - return {'data': integrate_data}, 200 + integrate_data.append( + { + "id": None, + "provider": provider, + "created_at": None, + "source_info": None, + "is_bound": False, + "disabled": None, + "link": f"{base_url}{data_source_oauth_base_path}/{provider}", + } + ) + return {"data": integrate_data}, 200 @setup_required @login_required @@ -71,92 +78,82 @@ def get(self): def patch(self, binding_id, action): binding_id = str(binding_id) action = str(action) - data_source_binding = DataSourceOauthBinding.query.filter_by( - id=binding_id - ).first() + data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first() if data_source_binding is None: - raise NotFound('Data source binding not found.') + raise NotFound("Data source binding not found.") # enable binding - if action == 'enable': + if action == "enable": if data_source_binding.disabled: data_source_binding.disabled = False data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(data_source_binding) db.session.commit() else: - raise ValueError('Data source is not disabled.') + raise ValueError("Data source is not disabled.") # disable binding - if action == 'disable': + if action == "disable": if not data_source_binding.disabled: data_source_binding.disabled = True data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(data_source_binding) db.session.commit() else: - raise ValueError('Data source is disabled.') - return {'result': 'success'}, 200 + raise ValueError("Data source is disabled.") + return {"result": "success"}, 200 class DataSourceNotionListApi(Resource): - @setup_required @login_required @account_initialization_required @marshal_with(integrate_notion_info_list_fields) def get(self): - dataset_id = request.args.get('dataset_id', default=None, type=str) + dataset_id = request.args.get("dataset_id", default=None, type=str) exist_page_ids = [] # import notion in the exist dataset if dataset_id: dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') - if dataset.data_source_type != 'notion_import': - raise ValueError('Dataset is not notion type.') + raise NotFound("Dataset not found.") + if dataset.data_source_type != "notion_import": + raise ValueError("Dataset is not notion type.") documents = Document.query.filter_by( dataset_id=dataset_id, tenant_id=current_user.current_tenant_id, - data_source_type='notion_import', - enabled=True + data_source_type="notion_import", + enabled=True, ).all() if documents: for document in documents: data_source_info = json.loads(document.data_source_info) - exist_page_ids.append(data_source_info['notion_page_id']) + exist_page_ids.append(data_source_info["notion_page_id"]) # get all authorized pages data_source_bindings = DataSourceOauthBinding.query.filter_by( - tenant_id=current_user.current_tenant_id, - provider='notion', - disabled=False + tenant_id=current_user.current_tenant_id, provider="notion", disabled=False ).all() if not data_source_bindings: - return { - 'notion_info': [] - }, 200 + return {"notion_info": []}, 200 pre_import_info_list = [] for data_source_binding in data_source_bindings: source_info = data_source_binding.source_info - pages = source_info['pages'] + pages = source_info["pages"] # Filter out already bound pages for page in pages: - if page['page_id'] in exist_page_ids: - page['is_bound'] = True + if page["page_id"] in exist_page_ids: + page["is_bound"] = True else: - page['is_bound'] = False + page["is_bound"] = False pre_import_info = { - 'workspace_name': source_info['workspace_name'], - 'workspace_icon': source_info['workspace_icon'], - 'workspace_id': source_info['workspace_id'], - 'pages': pages, + "workspace_name": source_info["workspace_name"], + "workspace_icon": source_info["workspace_icon"], + "workspace_id": source_info["workspace_id"], + "pages": pages, } pre_import_info_list.append(pre_import_info) - return { - 'notion_info': pre_import_info_list - }, 200 + return {"notion_info": pre_import_info_list}, 200 class DataSourceNotionApi(Resource): - @setup_required @login_required @account_initialization_required @@ -166,64 +163,67 @@ def get(self, workspace_id, page_id, page_type): data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', ) ).first() if not data_source_binding: - raise NotFound('Data source binding not found.') + raise NotFound("Data source binding not found.") extractor = NotionExtractor( notion_workspace_id=workspace_id, notion_obj_id=page_id, notion_page_type=page_type, notion_access_token=data_source_binding.access_token, - tenant_id=current_user.current_tenant_id + tenant_id=current_user.current_tenant_id, ) text_docs = extractor.extract() - return { - 'content': "\n".join([doc.page_content for doc in text_docs]) - }, 200 + return {"content": "\n".join([doc.page_content for doc in text_docs])}, 200 @setup_required @login_required @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json') - parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json') + parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json") + parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) args = parser.parse_args() # validate args DocumentService.estimate_args_validate(args) - notion_info_list = args['notion_info_list'] + notion_info_list = args["notion_info_list"] extract_settings = [] for notion_info in notion_info_list: - workspace_id = notion_info['workspace_id'] - for page in notion_info['pages']: + workspace_id = notion_info["workspace_id"] + for page in notion_info["pages"]: extract_setting = ExtractSetting( datasource_type="notion_import", notion_info={ "notion_workspace_id": workspace_id, - "notion_obj_id": page['page_id'], - "notion_page_type": page['type'], - "tenant_id": current_user.current_tenant_id + "notion_obj_id": page["page_id"], + "notion_page_type": page["type"], + "tenant_id": current_user.current_tenant_id, }, - document_model=args['doc_form'] + document_model=args["doc_form"], ) extract_settings.append(extract_setting) indexing_runner = IndexingRunner() - response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, - args['process_rule'], args['doc_form'], - args['doc_language']) + response = indexing_runner.indexing_estimate( + current_user.current_tenant_id, + extract_settings, + args["process_rule"], + args["doc_form"], + args["doc_language"], + ) return response, 200 class DataSourceNotionDatasetSyncApi(Resource): - @setup_required @login_required @account_initialization_required @@ -240,7 +240,6 @@ def get(self, dataset_id): class DataSourceNotionDocumentSyncApi(Resource): - @setup_required @login_required @account_initialization_required @@ -258,10 +257,14 @@ def get(self, dataset_id, document_id): return 200 -api.add_resource(DataSourceApi, '/data-source/integrates', '/data-source/integrates//') -api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages') -api.add_resource(DataSourceNotionApi, - '/notion/workspaces//pages///preview', - '/datasets/notion-indexing-estimate') -api.add_resource(DataSourceNotionDatasetSyncApi, '/datasets//notion/sync') -api.add_resource(DataSourceNotionDocumentSyncApi, '/datasets//documents//notion/sync') +api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates//") +api.add_resource(DataSourceNotionListApi, "/notion/pre-import/pages") +api.add_resource( + DataSourceNotionApi, + "/notion/workspaces//pages///preview", + "/datasets/notion-indexing-estimate", +) +api.add_resource(DataSourceNotionDatasetSyncApi, "/datasets//notion/sync") +api.add_resource( + DataSourceNotionDocumentSyncApi, "/datasets//documents//notion/sync" +) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index be0281f07af27a..44c1390c14fa62 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -24,52 +24,47 @@ from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields from fields.document_fields import document_status_fields from libs.login import login_required -from models.dataset import Dataset, Document, DocumentSegment +from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment from models.model import ApiToken, UploadFile from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService def _validate_name(name): if not name or len(name) < 1 or len(name) > 40: - raise ValueError('Name must be between 1 to 40 characters.') + raise ValueError("Name must be between 1 to 40 characters.") return name def _validate_description_length(description): if len(description) > 400: - raise ValueError('Description cannot exceed 400 characters.') + raise ValueError("Description cannot exceed 400 characters.") return description class DatasetListApi(Resource): - @setup_required @login_required @account_initialization_required def get(self): - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - ids = request.args.getlist('ids') - provider = request.args.get('provider', default="vendor") - search = request.args.get('keyword', default=None, type=str) - tag_ids = request.args.getlist('tag_ids') + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + ids = request.args.getlist("ids") + provider = request.args.get("provider", default="vendor") + search = request.args.get("keyword", default=None, type=str) + tag_ids = request.args.getlist("tag_ids") if ids: datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id) else: - datasets, total = DatasetService.get_datasets(page, limit, provider, - current_user.current_tenant_id, current_user, search, tag_ids) + datasets, total = DatasetService.get_datasets( + page, limit, provider, current_user.current_tenant_id, current_user, search, tag_ids + ) # check embedding setting provider_manager = ProviderManager() - configurations = provider_manager.get_configurations( - tenant_id=current_user.current_tenant_id - ) + configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) - embedding_models = configurations.get_models( - model_type=ModelType.TEXT_EMBEDDING, - only_active=True - ) + embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) model_names = [] for embedding_model in embedding_models: @@ -77,28 +72,22 @@ def get(self): data = marshal(datasets, dataset_detail_fields) for item in data: - if item['indexing_technique'] == 'high_quality': + if item["indexing_technique"] == "high_quality": item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" if item_model in model_names: - item['embedding_available'] = True + item["embedding_available"] = True else: - item['embedding_available'] = False + item["embedding_available"] = False else: - item['embedding_available'] = True + item["embedding_available"] = True - if item.get('permission') == 'partial_members': - part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item['id']) - item.update({'partial_member_list': part_users_list}) + if item.get("permission") == "partial_members": + part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item["id"]) + item.update({"partial_member_list": part_users_list}) else: - item.update({'partial_member_list': []}) + item.update({"partial_member_list": []}) - response = { - 'data': data, - 'has_more': len(datasets) == limit, - 'limit': limit, - 'total': total, - 'page': page - } + response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} return response, 200 @setup_required @@ -106,13 +95,21 @@ def get(self): @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, required=True, - help='type is required. Name must be between 1 to 40 characters.', - type=_validate_name) - parser.add_argument('indexing_technique', type=str, location='json', - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, - help='Invalid indexing technique.') + parser.add_argument( + "name", + nullable=False, + required=True, + help="type is required. Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument( + "indexing_technique", + type=str, + location="json", + choices=Dataset.INDEXING_TECHNIQUE_LIST, + nullable=True, + help="Invalid indexing technique.", + ) args = parser.parse_args() # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator @@ -122,9 +119,10 @@ def post(self): try: dataset = DatasetService.create_empty_dataset( tenant_id=current_user.current_tenant_id, - name=args['name'], - indexing_technique=args['indexing_technique'], - account=current_user + name=args["name"], + indexing_technique=args["indexing_technique"], + account=current_user, + permission=DatasetPermissionEnum.ONLY_ME, ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() @@ -142,42 +140,36 @@ def get(self, dataset_id): if dataset is None: raise NotFound("Dataset not found.") try: - DatasetService.check_dataset_permission( - dataset, current_user) + DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) data = marshal(dataset, dataset_detail_fields) - if data.get('permission') == 'partial_members': + if data.get("permission") == "partial_members": part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) - data.update({'partial_member_list': part_users_list}) + data.update({"partial_member_list": part_users_list}) # check embedding setting provider_manager = ProviderManager() - configurations = provider_manager.get_configurations( - tenant_id=current_user.current_tenant_id - ) + configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) - embedding_models = configurations.get_models( - model_type=ModelType.TEXT_EMBEDDING, - only_active=True - ) + embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) model_names = [] for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - if data['indexing_technique'] == 'high_quality': + if data["indexing_technique"] == "high_quality": item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}" if item_model in model_names: - data['embedding_available'] = True + data["embedding_available"] = True else: - data['embedding_available'] = False + data["embedding_available"] = False else: - data['embedding_available'] = True + data["embedding_available"] = True - if data.get('permission') == 'partial_members': + if data.get("permission") == "partial_members": part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) - data.update({'partial_member_list': part_users_list}) + data.update({"partial_member_list": part_users_list}) return data, 200 @@ -191,42 +183,49 @@ def patch(self, dataset_id): raise NotFound("Dataset not found.") parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, - help='type is required. Name must be between 1 to 40 characters.', - type=_validate_name) - parser.add_argument('description', - location='json', store_missing=False, - type=_validate_description_length) - parser.add_argument('indexing_technique', type=str, location='json', - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, - help='Invalid indexing technique.') - parser.add_argument('permission', type=str, location='json', choices=( - 'only_me', 'all_team_members', 'partial_members'), help='Invalid permission.' - ) - parser.add_argument('embedding_model', type=str, - location='json', help='Invalid embedding model.') - parser.add_argument('embedding_model_provider', type=str, - location='json', help='Invalid embedding model provider.') - parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.') - parser.add_argument('partial_member_list', type=list, location='json', help='Invalid parent user list.') + parser.add_argument( + "name", + nullable=False, + help="type is required. Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length) + parser.add_argument( + "indexing_technique", + type=str, + location="json", + choices=Dataset.INDEXING_TECHNIQUE_LIST, + nullable=True, + help="Invalid indexing technique.", + ) + parser.add_argument( + "permission", + type=str, + location="json", + choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), + help="Invalid permission.", + ) + parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.") + parser.add_argument( + "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider." + ) + parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.") + parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.") args = parser.parse_args() data = request.get_json() # check embedding model setting - if data.get('indexing_technique') == 'high_quality': - DatasetService.check_embedding_model_setting(dataset.tenant_id, - data.get('embedding_model_provider'), - data.get('embedding_model') - ) + if data.get("indexing_technique") == "high_quality": + DatasetService.check_embedding_model_setting( + dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model") + ) # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator DatasetPermissionService.check_permission( - current_user, dataset, data.get('permission'), data.get('partial_member_list') + current_user, dataset, data.get("permission"), data.get("partial_member_list") ) - dataset = DatasetService.update_dataset( - dataset_id_str, args, current_user) + dataset = DatasetService.update_dataset(dataset_id_str, args, current_user) if dataset is None: raise NotFound("Dataset not found.") @@ -234,15 +233,19 @@ def patch(self, dataset_id): result_data = marshal(dataset, dataset_detail_fields) tenant_id = current_user.current_tenant_id - if data.get('partial_member_list') and data.get('permission') == 'partial_members': + if data.get("partial_member_list") and data.get("permission") == "partial_members": DatasetPermissionService.update_partial_member_list( - tenant_id, dataset_id_str, data.get('partial_member_list') + tenant_id, dataset_id_str, data.get("partial_member_list") ) - else: + # clear partial member list when permission is only_me or all_team_members + elif ( + data.get("permission") == DatasetPermissionEnum.ONLY_ME + or data.get("permission") == DatasetPermissionEnum.ALL_TEAM + ): DatasetPermissionService.clear_partial_member_list(dataset_id_str) partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) - result_data.update({'partial_member_list': partial_member_list}) + result_data.update({"partial_member_list": partial_member_list}) return result_data, 200 @@ -259,12 +262,13 @@ def delete(self, dataset_id): try: if DatasetService.delete_dataset(dataset_id_str, current_user): DatasetPermissionService.clear_partial_member_list(dataset_id_str) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 else: raise NotFound("Dataset not found.") except services.errors.dataset.DatasetInUseError: raise DatasetInUseError() + class DatasetUseCheckApi(Resource): @setup_required @login_required @@ -273,10 +277,10 @@ def get(self, dataset_id): dataset_id_str = str(dataset_id) dataset_is_using = DatasetService.dataset_use_check(dataset_id_str) - return {'is_using': dataset_is_using}, 200 + return {"is_using": dataset_is_using}, 200 -class DatasetQueryApi(Resource): +class DatasetQueryApi(Resource): @setup_required @login_required @account_initialization_required @@ -291,51 +295,53 @@ def get(self, dataset_id): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) - dataset_queries, total = DatasetService.get_dataset_queries( - dataset_id=dataset.id, - page=page, - per_page=limit - ) + dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit) response = { - 'data': marshal(dataset_queries, dataset_query_detail_fields), - 'has_more': len(dataset_queries) == limit, - 'limit': limit, - 'total': total, - 'page': page + "data": marshal(dataset_queries, dataset_query_detail_fields), + "has_more": len(dataset_queries) == limit, + "limit": limit, + "total": total, + "page": page, } return response, 200 class DatasetIndexingEstimateApi(Resource): - @setup_required @login_required @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json') - parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') - parser.add_argument('indexing_technique', type=str, required=True, - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') + parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json") + parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") + parser.add_argument( + "indexing_technique", + type=str, + required=True, + choices=Dataset.INDEXING_TECHNIQUE_LIST, + nullable=True, + location="json", + ) + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) args = parser.parse_args() # validate args DocumentService.estimate_args_validate(args) extract_settings = [] - if args['info_list']['data_source_type'] == 'upload_file': - file_ids = args['info_list']['file_info_list']['file_ids'] - file_details = db.session.query(UploadFile).filter( - UploadFile.tenant_id == current_user.current_tenant_id, - UploadFile.id.in_(file_ids) - ).all() + if args["info_list"]["data_source_type"] == "upload_file": + file_ids = args["info_list"]["file_info_list"]["file_ids"] + file_details = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)) + .all() + ) if file_details is None: raise NotFound("File not found.") @@ -343,55 +349,58 @@ def post(self): if file_details: for file_detail in file_details: extract_setting = ExtractSetting( - datasource_type="upload_file", - upload_file=file_detail, - document_model=args['doc_form'] + datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"] ) extract_settings.append(extract_setting) - elif args['info_list']['data_source_type'] == 'notion_import': - notion_info_list = args['info_list']['notion_info_list'] + elif args["info_list"]["data_source_type"] == "notion_import": + notion_info_list = args["info_list"]["notion_info_list"] for notion_info in notion_info_list: - workspace_id = notion_info['workspace_id'] - for page in notion_info['pages']: + workspace_id = notion_info["workspace_id"] + for page in notion_info["pages"]: extract_setting = ExtractSetting( datasource_type="notion_import", notion_info={ "notion_workspace_id": workspace_id, - "notion_obj_id": page['page_id'], - "notion_page_type": page['type'], - "tenant_id": current_user.current_tenant_id + "notion_obj_id": page["page_id"], + "notion_page_type": page["type"], + "tenant_id": current_user.current_tenant_id, }, - document_model=args['doc_form'] + document_model=args["doc_form"], ) extract_settings.append(extract_setting) - elif args['info_list']['data_source_type'] == 'website_crawl': - website_info_list = args['info_list']['website_info_list'] - for url in website_info_list['urls']: + elif args["info_list"]["data_source_type"] == "website_crawl": + website_info_list = args["info_list"]["website_info_list"] + for url in website_info_list["urls"]: extract_setting = ExtractSetting( datasource_type="website_crawl", website_info={ - "provider": website_info_list['provider'], - "job_id": website_info_list['job_id'], + "provider": website_info_list["provider"], + "job_id": website_info_list["job_id"], "url": url, "tenant_id": current_user.current_tenant_id, - "mode": 'crawl', - "only_main_content": website_info_list['only_main_content'] + "mode": "crawl", + "only_main_content": website_info_list["only_main_content"], }, - document_model=args['doc_form'] + document_model=args["doc_form"], ) extract_settings.append(extract_setting) else: - raise ValueError('Data source type not support') + raise ValueError("Data source type not support") indexing_runner = IndexingRunner() try: - response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, - args['process_rule'], args['doc_form'], - args['doc_language'], args['dataset_id'], - args['indexing_technique']) + response = indexing_runner.indexing_estimate( + current_user.current_tenant_id, + extract_settings, + args["process_rule"], + args["doc_form"], + args["doc_language"], + args["dataset_id"], + args["indexing_technique"], + ) except LLMBadRequestError: raise ProviderNotInitializeError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except Exception as e: @@ -401,7 +410,6 @@ def post(self): class DatasetRelatedAppListApi(Resource): - @setup_required @login_required @account_initialization_required @@ -425,52 +433,52 @@ def get(self, dataset_id): if app_model: related_apps.append(app_model) - return { - 'data': related_apps, - 'total': len(related_apps) - }, 200 + return {"data": related_apps, "total": len(related_apps)}, 200 class DatasetIndexingStatusApi(Resource): - @setup_required @login_required @account_initialization_required def get(self, dataset_id): dataset_id = str(dataset_id) - documents = db.session.query(Document).filter( - Document.dataset_id == dataset_id, - Document.tenant_id == current_user.current_tenant_id - ).all() + documents = ( + db.session.query(Document) + .filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id) + .all() + ) documents_status = [] for document in documents: - completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() - total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments documents_status.append(marshal(document, document_status_fields)) - data = { - 'data': documents_status - } + data = {"data": documents_status} return data class DatasetApiKeyApi(Resource): max_keys = 10 - token_prefix = 'dataset-' - resource_type = 'dataset' + token_prefix = "dataset-" + resource_type = "dataset" @setup_required @login_required @account_initialization_required @marshal_with(api_key_list) def get(self): - keys = db.session.query(ApiToken). \ - filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \ - all() + keys = ( + db.session.query(ApiToken) + .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) + .all() + ) return {"items": keys} @setup_required @@ -482,15 +490,17 @@ def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - current_key_count = db.session.query(ApiToken). \ - filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \ - count() + current_key_count = ( + db.session.query(ApiToken) + .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) + .count() + ) if current_key_count >= self.max_keys: flask_restful.abort( 400, message=f"Cannot create more than {self.max_keys} API keys for this resource type.", - code='max_keys_exceeded' + code="max_keys_exceeded", ) key = ApiToken.generate_api_key(self.token_prefix, 24) @@ -504,7 +514,7 @@ def post(self): class DatasetApiDeleteApi(Resource): - resource_type = 'dataset' + resource_type = "dataset" @setup_required @login_required @@ -516,18 +526,23 @@ def delete(self, api_key_id): if not current_user.is_admin_or_owner: raise Forbidden() - key = db.session.query(ApiToken). \ - filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type, - ApiToken.id == api_key_id). \ - first() + key = ( + db.session.query(ApiToken) + .filter( + ApiToken.tenant_id == current_user.current_tenant_id, + ApiToken.type == self.resource_type, + ApiToken.id == api_key_id, + ) + .first() + ) if key is None: - flask_restful.abort(404, message='API key not found') + flask_restful.abort(404, message="API key not found") db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class DatasetApiBaseUrlApi(Resource): @@ -536,8 +551,10 @@ class DatasetApiBaseUrlApi(Resource): @account_initialization_required def get(self): return { - 'api_base_url': (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL - else request.host_url.rstrip('/')) + '/v1' + "api_base_url": ( + dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/") + ) + + "/v1" } @@ -548,15 +565,26 @@ class DatasetRetrievalSettingApi(Resource): def get(self): vector_type = dify_config.VECTOR_STORE match vector_type: - case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT: - return { - 'retrieval_method': [ - RetrievalMethod.SEMANTIC_SEARCH.value - ] - } - case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE: + case ( + VectorType.MILVUS + | VectorType.RELYT + | VectorType.PGVECTOR + | VectorType.TIDB_VECTOR + | VectorType.CHROMA + | VectorType.TENCENT + ): + return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} + case ( + VectorType.QDRANT + | VectorType.WEAVIATE + | VectorType.OPENSEARCH + | VectorType.ANALYTICDB + | VectorType.MYSCALE + | VectorType.ORACLE + | VectorType.ELASTICSEARCH + ): return { - 'retrieval_method': [ + "retrieval_method": [ RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value, @@ -572,15 +600,27 @@ class DatasetRetrievalSettingMockApi(Resource): @account_initialization_required def get(self, vector_type): match vector_type: - case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT: + case ( + VectorType.MILVUS + | VectorType.RELYT + | VectorType.TIDB_VECTOR + | VectorType.CHROMA + | VectorType.TENCENT + | VectorType.PGVECTO_RS + ): + return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} + case ( + VectorType.QDRANT + | VectorType.WEAVIATE + | VectorType.OPENSEARCH + | VectorType.ANALYTICDB + | VectorType.MYSCALE + | VectorType.ORACLE + | VectorType.ELASTICSEARCH + | VectorType.PGVECTOR + ): return { - 'retrieval_method': [ - RetrievalMethod.SEMANTIC_SEARCH.value - ] - } - case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE: - return { - 'retrieval_method': [ + "retrieval_method": [ RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value, @@ -590,7 +630,6 @@ def get(self, vector_type): raise ValueError(f"Unsupported vector db type {vector_type}.") - class DatasetErrorDocs(Resource): @setup_required @login_required @@ -602,10 +641,7 @@ def get(self, dataset_id): raise NotFound("Dataset not found.") results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str) - return { - 'data': [marshal(item, document_status_fields) for item in results], - 'total': len(results) - }, 200 + return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200 class DatasetPermissionUserListApi(Resource): @@ -625,21 +661,21 @@ def get(self, dataset_id): partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) return { - 'data': partial_members_list, + "data": partial_members_list, }, 200 -api.add_resource(DatasetListApi, '/datasets') -api.add_resource(DatasetApi, '/datasets/') -api.add_resource(DatasetUseCheckApi, '/datasets//use-check') -api.add_resource(DatasetQueryApi, '/datasets//queries') -api.add_resource(DatasetErrorDocs, '/datasets//error-docs') -api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate') -api.add_resource(DatasetRelatedAppListApi, '/datasets//related-apps') -api.add_resource(DatasetIndexingStatusApi, '/datasets//indexing-status') -api.add_resource(DatasetApiKeyApi, '/datasets/api-keys') -api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/') -api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info') -api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting') -api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/') -api.add_resource(DatasetPermissionUserListApi, '/datasets//permission-part-users') +api.add_resource(DatasetListApi, "/datasets") +api.add_resource(DatasetApi, "/datasets/") +api.add_resource(DatasetUseCheckApi, "/datasets//use-check") +api.add_resource(DatasetQueryApi, "/datasets//queries") +api.add_resource(DatasetErrorDocs, "/datasets//error-docs") +api.add_resource(DatasetIndexingEstimateApi, "/datasets/indexing-estimate") +api.add_resource(DatasetRelatedAppListApi, "/datasets//related-apps") +api.add_resource(DatasetIndexingStatusApi, "/datasets//indexing-status") +api.add_resource(DatasetApiKeyApi, "/datasets/api-keys") +api.add_resource(DatasetApiDeleteApi, "/datasets/api-keys/") +api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info") +api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting") +api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/") +api.add_resource(DatasetPermissionUserListApi, "/datasets//permission-part-users") diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index afe0ca7c69b2b7..6bc29a86435fa4 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -57,7 +57,7 @@ class DocumentResource(Resource): def get_document(self, dataset_id: str, document_id: str) -> Document: dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) @@ -67,17 +67,17 @@ def get_document(self, dataset_id: str, document_id: str) -> Document: document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") if document.tenant_id != current_user.current_tenant_id: - raise Forbidden('No permission.') + raise Forbidden("No permission.") return document def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]: dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) @@ -87,7 +87,7 @@ def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]: documents = DocumentService.get_batch_documents(dataset_id, batch) if not documents: - raise NotFound('Documents not found.') + raise NotFound("Documents not found.") return documents @@ -99,11 +99,11 @@ class GetProcessRuleApi(Resource): def get(self): req_data = request.args - document_id = req_data.get('document_id') + document_id = req_data.get("document_id") # get default rules - mode = DocumentService.DEFAULT_RULES['mode'] - rules = DocumentService.DEFAULT_RULES['rules'] + mode = DocumentService.DEFAULT_RULES["mode"] + rules = DocumentService.DEFAULT_RULES["rules"] if document_id: # get the latest process rule document = Document.query.get_or_404(document_id) @@ -111,7 +111,7 @@ def get(self): dataset = DatasetService.get_dataset(document.dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) @@ -119,19 +119,18 @@ def get(self): raise Forbidden(str(e)) # get the latest process rule - dataset_process_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.dataset_id == document.dataset_id). \ - order_by(DatasetProcessRule.created_at.desc()). \ - limit(1). \ - one_or_none() + dataset_process_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.dataset_id == document.dataset_id) + .order_by(DatasetProcessRule.created_at.desc()) + .limit(1) + .one_or_none() + ) if dataset_process_rule: mode = dataset_process_rule.mode rules = dataset_process_rule.rules_dict - return { - 'mode': mode, - 'rules': rules - } + return {"mode": mode, "rules": rules} class DatasetDocumentListApi(Resource): @@ -140,92 +139,99 @@ class DatasetDocumentListApi(Resource): @account_initialization_required def get(self, dataset_id): dataset_id = str(dataset_id) - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - search = request.args.get('keyword', default=None, type=str) - sort = request.args.get('sort', default='-created_at', type=str) + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + search = request.args.get("keyword", default=None, type=str) + sort = request.args.get("sort", default="-created_at", type=str) # "yes", "true", "t", "y", "1" convert to True, while others convert to False. try: - fetch = string_to_bool(request.args.get('fetch', default='false')) + fetch = string_to_bool(request.args.get("fetch", default="false")) except (ArgumentTypeError, ValueError, Exception) as e: fetch = False dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - query = Document.query.filter_by( - dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) + query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) if search: - search = f'%{search}%' + search = f"%{search}%" query = query.filter(Document.name.like(search)) - if sort.startswith('-'): + if sort.startswith("-"): sort_logic = desc sort = sort[1:] else: sort_logic = asc - if sort == 'hit_count': - sub_query = db.select(DocumentSegment.document_id, - db.func.sum(DocumentSegment.hit_count).label("total_hit_count")) \ - .group_by(DocumentSegment.document_id) \ + if sort == "hit_count": + sub_query = ( + db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count")) + .group_by(DocumentSegment.document_id) .subquery() + ) - query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id) \ - .order_by(sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0))) - elif sort == 'created_at': - query = query.order_by(sort_logic(Document.created_at)) + query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by( + sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)), + sort_logic(Document.position), + ) + elif sort == "created_at": + query = query.order_by( + sort_logic(Document.created_at), + sort_logic(Document.position), + ) else: - query = query.order_by(desc(Document.created_at)) + query = query.order_by( + desc(Document.created_at), + desc(Document.position), + ) - paginated_documents = query.paginate( - page=page, per_page=limit, max_per_page=100, error_out=False) + paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items if fetch: for document in documents: - completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() - total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments data = marshal(documents, document_with_segments_fields) else: data = marshal(documents, document_fields) response = { - 'data': data, - 'has_more': len(documents) == limit, - 'limit': limit, - 'total': paginated_documents.total, - 'page': page + "data": data, + "has_more": len(documents) == limit, + "limit": limit, + "total": paginated_documents.total, + "page": page, } return response - documents_and_batch_fields = { - 'documents': fields.List(fields.Nested(document_fields)), - 'batch': fields.String - } + documents_and_batch_fields = {"documents": fields.List(fields.Nested(document_fields)), "batch": fields.String} @setup_required @login_required @account_initialization_required @marshal_with(documents_and_batch_fields) - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def post(self, dataset_id): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_dataset_editor: @@ -237,21 +243,22 @@ def post(self, dataset_id): raise Forbidden(str(e)) parser = reqparse.RequestParser() - parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, - location='json') - parser.add_argument('data_source', type=dict, required=False, location='json') - parser.add_argument('process_rule', type=dict, required=False, location='json') - parser.add_argument('duplicate', type=bool, default=True, nullable=False, location='json') - parser.add_argument('original_document_id', type=str, required=False, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') - parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, - location='json') + parser.add_argument( + "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" + ) + parser.add_argument("data_source", type=dict, required=False, location="json") + parser.add_argument("process_rule", type=dict, required=False, location="json") + parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json") + parser.add_argument("original_document_id", type=str, required=False, location="json") + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) + parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") args = parser.parse_args() - if not dataset.indexing_technique and not args['indexing_technique']: - raise ValueError('indexing_technique is required.') + if not dataset.indexing_technique and not args["indexing_technique"]: + raise ValueError("indexing_technique is required.") # validate args DocumentService.document_create_args_validate(args) @@ -265,51 +272,53 @@ def post(self, dataset_id): except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() - return { - 'documents': documents, - 'batch': batch - } + return {"documents": documents, "batch": batch} class DatasetInitApi(Resource): - @setup_required @login_required @account_initialization_required @marshal_with(dataset_and_document_fields) - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def post(self): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, required=True, - nullable=False, location='json') - parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json') - parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') - parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, - location='json') + parser.add_argument( + "indexing_technique", + type=str, + choices=Dataset.INDEXING_TECHNIQUE_LIST, + required=True, + nullable=False, + location="json", + ) + parser.add_argument("data_source", type=dict, required=True, nullable=True, location="json") + parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) + parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") args = parser.parse_args() # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator if not current_user.is_dataset_editor: raise Forbidden() - if args['indexing_technique'] == 'high_quality': + if args["indexing_technique"] == "high_quality": try: model_manager = ModelManager() model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.TEXT_EMBEDDING + tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING ) except InvokeAuthorizationError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -318,9 +327,7 @@ def post(self): try: dataset, documents, batch = DocumentService.save_document_without_dataset_id( - tenant_id=current_user.current_tenant_id, - document_data=args, - account=current_user + tenant_id=current_user.current_tenant_id, document_data=args, account=current_user ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -329,17 +336,12 @@ def post(self): except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() - response = { - 'dataset': dataset, - 'documents': documents, - 'batch': batch - } + response = {"dataset": dataset, "documents": documents, "batch": batch} return response class DocumentIndexingEstimateApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -348,50 +350,49 @@ def get(self, dataset_id, document_id): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - if document.indexing_status in ['completed', 'error']: + if document.indexing_status in ["completed", "error"]: raise DocumentAlreadyFinishedError() data_process_rule = document.dataset_process_rule data_process_rule_dict = data_process_rule.to_dict() - response = { - "tokens": 0, - "total_price": 0, - "currency": "USD", - "total_segments": 0, - "preview": [] - } + response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []} - if document.data_source_type == 'upload_file': + if document.data_source_type == "upload_file": data_source_info = document.data_source_info_dict - if data_source_info and 'upload_file_id' in data_source_info: - file_id = data_source_info['upload_file_id'] + if data_source_info and "upload_file_id" in data_source_info: + file_id = data_source_info["upload_file_id"] - file = db.session.query(UploadFile).filter( - UploadFile.tenant_id == document.tenant_id, - UploadFile.id == file_id - ).first() + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) + .first() + ) # raise error if file not found if not file: - raise NotFound('File not found.') + raise NotFound("File not found.") extract_setting = ExtractSetting( - datasource_type="upload_file", - upload_file=file, - document_model=document.doc_form + datasource_type="upload_file", upload_file=file, document_model=document.doc_form ) indexing_runner = IndexingRunner() try: - response = indexing_runner.indexing_estimate(current_user.current_tenant_id, [extract_setting], - data_process_rule_dict, document.doc_form, - 'English', dataset_id) + response = indexing_runner.indexing_estimate( + current_user.current_tenant_id, + [extract_setting], + data_process_rule_dict, + document.doc_form, + "English", + dataset_id, + ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except Exception as e: @@ -401,7 +402,6 @@ def get(self, dataset_id, document_id): class DocumentBatchIndexingEstimateApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -409,13 +409,7 @@ def get(self, dataset_id, batch): dataset_id = str(dataset_id) batch = str(batch) documents = self.get_batch_documents(dataset_id, batch) - response = { - "tokens": 0, - "total_price": 0, - "currency": "USD", - "total_segments": 0, - "preview": [] - } + response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []} if not documents: return response data_process_rule = documents[0].dataset_process_rule @@ -423,82 +417,83 @@ def get(self, dataset_id, batch): info_list = [] extract_settings = [] for document in documents: - if document.indexing_status in ['completed', 'error']: + if document.indexing_status in ["completed", "error"]: raise DocumentAlreadyFinishedError() data_source_info = document.data_source_info_dict # format document files info - if data_source_info and 'upload_file_id' in data_source_info: - file_id = data_source_info['upload_file_id'] + if data_source_info and "upload_file_id" in data_source_info: + file_id = data_source_info["upload_file_id"] info_list.append(file_id) # format document notion info - elif data_source_info and 'notion_workspace_id' in data_source_info and 'notion_page_id' in data_source_info: + elif ( + data_source_info and "notion_workspace_id" in data_source_info and "notion_page_id" in data_source_info + ): pages = [] - page = { - 'page_id': data_source_info['notion_page_id'], - 'type': data_source_info['type'] - } + page = {"page_id": data_source_info["notion_page_id"], "type": data_source_info["type"]} pages.append(page) - notion_info = { - 'workspace_id': data_source_info['notion_workspace_id'], - 'pages': pages - } + notion_info = {"workspace_id": data_source_info["notion_workspace_id"], "pages": pages} info_list.append(notion_info) - if document.data_source_type == 'upload_file': - file_id = data_source_info['upload_file_id'] - file_detail = db.session.query(UploadFile).filter( - UploadFile.tenant_id == current_user.current_tenant_id, - UploadFile.id == file_id - ).first() + if document.data_source_type == "upload_file": + file_id = data_source_info["upload_file_id"] + file_detail = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id) + .first() + ) if file_detail is None: raise NotFound("File not found.") extract_setting = ExtractSetting( - datasource_type="upload_file", - upload_file=file_detail, - document_model=document.doc_form + datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form ) extract_settings.append(extract_setting) - elif document.data_source_type == 'notion_import': + elif document.data_source_type == "notion_import": extract_setting = ExtractSetting( datasource_type="notion_import", notion_info={ - "notion_workspace_id": data_source_info['notion_workspace_id'], - "notion_obj_id": data_source_info['notion_page_id'], - "notion_page_type": data_source_info['type'], - "tenant_id": current_user.current_tenant_id + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], + "tenant_id": current_user.current_tenant_id, }, - document_model=document.doc_form + document_model=document.doc_form, ) extract_settings.append(extract_setting) - elif document.data_source_type == 'website_crawl': + elif document.data_source_type == "website_crawl": extract_setting = ExtractSetting( datasource_type="website_crawl", website_info={ - "provider": data_source_info['provider'], - "job_id": data_source_info['job_id'], - "url": data_source_info['url'], + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], + "url": data_source_info["url"], "tenant_id": current_user.current_tenant_id, - "mode": data_source_info['mode'], - "only_main_content": data_source_info['only_main_content'] + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], }, - document_model=document.doc_form + document_model=document.doc_form, ) extract_settings.append(extract_setting) else: - raise ValueError('Data source type not support') + raise ValueError("Data source type not support") indexing_runner = IndexingRunner() try: - response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, - data_process_rule_dict, document.doc_form, - 'English', dataset_id) + response = indexing_runner.indexing_estimate( + current_user.current_tenant_id, + extract_settings, + data_process_rule_dict, + document.doc_form, + "English", + dataset_id, + ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except Exception as e: @@ -507,7 +502,6 @@ def get(self, dataset_id, batch): class DocumentBatchIndexingStatusApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -517,24 +511,24 @@ def get(self, dataset_id, batch): documents = self.get_batch_documents(dataset_id, batch) documents_status = [] for document in documents: - completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() - total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments if document.is_paused: - document.indexing_status = 'paused' + document.indexing_status = "paused" documents_status.append(marshal(document, document_status_fields)) - data = { - 'data': documents_status - } + data = {"data": documents_status} return data class DocumentIndexingStatusApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -543,25 +537,24 @@ def get(self, dataset_id, document_id): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - completed_segments = DocumentSegment.query \ - .filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document_id), - DocumentSegment.status != 're_segment') \ - .count() - total_segments = DocumentSegment.query \ - .filter(DocumentSegment.document_id == str(document_id), - DocumentSegment.status != 're_segment') \ - .count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document_id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments if document.is_paused: - document.indexing_status = 'paused' + document.indexing_status = "paused" return marshal(document, document_status_fields) class DocumentDetailApi(DocumentResource): - METADATA_CHOICES = {'all', 'only', 'without'} + METADATA_CHOICES = {"all", "only", "without"} @setup_required @login_required @@ -571,77 +564,75 @@ def get(self, dataset_id, document_id): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - metadata = request.args.get('metadata', 'all') + metadata = request.args.get("metadata", "all") if metadata not in self.METADATA_CHOICES: - raise InvalidMetadataError(f'Invalid metadata value: {metadata}') + raise InvalidMetadataError(f"Invalid metadata value: {metadata}") - if metadata == 'only': - response = { - 'id': document.id, - 'doc_type': document.doc_type, - 'doc_metadata': document.doc_metadata - } - elif metadata == 'without': + if metadata == "only": + response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata} + elif metadata == "without": process_rules = DatasetService.get_process_rules(dataset_id) data_source_info = document.data_source_detail_dict response = { - 'id': document.id, - 'position': document.position, - 'data_source_type': document.data_source_type, - 'data_source_info': data_source_info, - 'dataset_process_rule_id': document.dataset_process_rule_id, - 'dataset_process_rule': process_rules, - 'name': document.name, - 'created_from': document.created_from, - 'created_by': document.created_by, - 'created_at': document.created_at.timestamp(), - 'tokens': document.tokens, - 'indexing_status': document.indexing_status, - 'completed_at': int(document.completed_at.timestamp()) if document.completed_at else None, - 'updated_at': int(document.updated_at.timestamp()) if document.updated_at else None, - 'indexing_latency': document.indexing_latency, - 'error': document.error, - 'enabled': document.enabled, - 'disabled_at': int(document.disabled_at.timestamp()) if document.disabled_at else None, - 'disabled_by': document.disabled_by, - 'archived': document.archived, - 'segment_count': document.segment_count, - 'average_segment_length': document.average_segment_length, - 'hit_count': document.hit_count, - 'display_status': document.display_status, - 'doc_form': document.doc_form + "id": document.id, + "position": document.position, + "data_source_type": document.data_source_type, + "data_source_info": data_source_info, + "dataset_process_rule_id": document.dataset_process_rule_id, + "dataset_process_rule": process_rules, + "name": document.name, + "created_from": document.created_from, + "created_by": document.created_by, + "created_at": document.created_at.timestamp(), + "tokens": document.tokens, + "indexing_status": document.indexing_status, + "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None, + "updated_at": int(document.updated_at.timestamp()) if document.updated_at else None, + "indexing_latency": document.indexing_latency, + "error": document.error, + "enabled": document.enabled, + "disabled_at": int(document.disabled_at.timestamp()) if document.disabled_at else None, + "disabled_by": document.disabled_by, + "archived": document.archived, + "segment_count": document.segment_count, + "average_segment_length": document.average_segment_length, + "hit_count": document.hit_count, + "display_status": document.display_status, + "doc_form": document.doc_form, + "doc_language": document.doc_language, } else: process_rules = DatasetService.get_process_rules(dataset_id) data_source_info = document.data_source_detail_dict response = { - 'id': document.id, - 'position': document.position, - 'data_source_type': document.data_source_type, - 'data_source_info': data_source_info, - 'dataset_process_rule_id': document.dataset_process_rule_id, - 'dataset_process_rule': process_rules, - 'name': document.name, - 'created_from': document.created_from, - 'created_by': document.created_by, - 'created_at': document.created_at.timestamp(), - 'tokens': document.tokens, - 'indexing_status': document.indexing_status, - 'completed_at': int(document.completed_at.timestamp()) if document.completed_at else None, - 'updated_at': int(document.updated_at.timestamp()) if document.updated_at else None, - 'indexing_latency': document.indexing_latency, - 'error': document.error, - 'enabled': document.enabled, - 'disabled_at': int(document.disabled_at.timestamp()) if document.disabled_at else None, - 'disabled_by': document.disabled_by, - 'archived': document.archived, - 'doc_type': document.doc_type, - 'doc_metadata': document.doc_metadata, - 'segment_count': document.segment_count, - 'average_segment_length': document.average_segment_length, - 'hit_count': document.hit_count, - 'display_status': document.display_status, - 'doc_form': document.doc_form + "id": document.id, + "position": document.position, + "data_source_type": document.data_source_type, + "data_source_info": data_source_info, + "dataset_process_rule_id": document.dataset_process_rule_id, + "dataset_process_rule": process_rules, + "name": document.name, + "created_from": document.created_from, + "created_by": document.created_by, + "created_at": document.created_at.timestamp(), + "tokens": document.tokens, + "indexing_status": document.indexing_status, + "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None, + "updated_at": int(document.updated_at.timestamp()) if document.updated_at else None, + "indexing_latency": document.indexing_latency, + "error": document.error, + "enabled": document.enabled, + "disabled_at": int(document.disabled_at.timestamp()) if document.disabled_at else None, + "disabled_by": document.disabled_by, + "archived": document.archived, + "doc_type": document.doc_type, + "doc_metadata": document.doc_metadata, + "segment_count": document.segment_count, + "average_segment_length": document.average_segment_length, + "hit_count": document.hit_count, + "display_status": document.display_status, + "doc_form": document.doc_form, + "doc_language": document.doc_language, } return response, 200 @@ -662,7 +653,7 @@ def patch(self, dataset_id, document_id, action): if action == "pause": if document.indexing_status != "indexing": - raise InvalidActionError('Document not in indexing state.') + raise InvalidActionError("Document not in indexing state.") document.paused_by = current_user.id document.paused_at = datetime.now(timezone.utc).replace(tzinfo=None) @@ -671,7 +662,7 @@ def patch(self, dataset_id, document_id, action): elif action == "resume": if document.indexing_status not in ["paused", "error"]: - raise InvalidActionError('Document not in paused or error state.') + raise InvalidActionError("Document not in paused or error state.") document.paused_by = None document.paused_at = None @@ -680,7 +671,7 @@ def patch(self, dataset_id, document_id, action): else: raise InvalidActionError() - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class DocumentDeleteApi(DocumentResource): @@ -701,9 +692,9 @@ def delete(self, dataset_id, document_id): try: DocumentService.delete_document(document) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Cannot delete document during indexing.') + raise DocumentIndexingError("Cannot delete document during indexing.") - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class DocumentMetadataApi(DocumentResource): @@ -717,26 +708,26 @@ def put(self, dataset_id, document_id): req_data = request.get_json() - doc_type = req_data.get('doc_type') - doc_metadata = req_data.get('doc_metadata') + doc_type = req_data.get("doc_type") + doc_metadata = req_data.get("doc_metadata") # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() if doc_type is None or doc_metadata is None: - raise ValueError('Both doc_type and doc_metadata must be provided.') + raise ValueError("Both doc_type and doc_metadata must be provided.") if doc_type not in DocumentService.DOCUMENT_METADATA_SCHEMA: - raise ValueError('Invalid doc_type.') + raise ValueError("Invalid doc_type.") if not isinstance(doc_metadata, dict): - raise ValueError('doc_metadata must be a dictionary.') + raise ValueError("doc_metadata must be a dictionary.") metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type] document.doc_metadata = {} - if doc_type == 'others': + if doc_type == "others": document.doc_metadata = doc_metadata else: for key, value_type in metadata_schema.items(): @@ -748,14 +739,14 @@ def put(self, dataset_id, document_id): document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() - return {'result': 'success', 'message': 'Document metadata updated.'}, 200 + return {"result": "success", "message": "Document metadata updated."}, 200 class DocumentStatusApi(DocumentResource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def patch(self, dataset_id, document_id, action): dataset_id = str(dataset_id) document_id = str(document_id) @@ -775,14 +766,14 @@ def patch(self, dataset_id, document_id, action): document = self.get_document(dataset_id, document_id) - indexing_cache_key = 'document_{}_indexing'.format(document.id) + indexing_cache_key = "document_{}_indexing".format(document.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise InvalidActionError("Document is being indexed, please try again later") if action == "enable": if document.enabled: - raise InvalidActionError('Document already enabled.') + raise InvalidActionError("Document already enabled.") document.enabled = True document.disabled_at = None @@ -795,13 +786,13 @@ def patch(self, dataset_id, document_id, action): add_document_to_index_task.delay(document_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 elif action == "disable": - if not document.completed_at or document.indexing_status != 'completed': - raise InvalidActionError('Document is not completed.') + if not document.completed_at or document.indexing_status != "completed": + raise InvalidActionError("Document is not completed.") if not document.enabled: - raise InvalidActionError('Document already disabled.') + raise InvalidActionError("Document already disabled.") document.enabled = False document.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None) @@ -814,11 +805,11 @@ def patch(self, dataset_id, document_id, action): remove_document_from_index_task.delay(document_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 elif action == "archive": if document.archived: - raise InvalidActionError('Document already archived.') + raise InvalidActionError("Document already archived.") document.archived = True document.archived_at = datetime.now(timezone.utc).replace(tzinfo=None) @@ -832,10 +823,10 @@ def patch(self, dataset_id, document_id, action): remove_document_from_index_task.delay(document_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 elif action == "un_archive": if not document.archived: - raise InvalidActionError('Document is not archived.') + raise InvalidActionError("Document is not archived.") document.archived = False document.archived_at = None @@ -848,13 +839,12 @@ def patch(self, dataset_id, document_id, action): add_document_to_index_task.delay(document_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 else: raise InvalidActionError() class DocumentPauseApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -865,7 +855,7 @@ def patch(self, dataset_id, document_id): dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") document = DocumentService.get_document(dataset.id, document_id) @@ -881,9 +871,9 @@ def patch(self, dataset_id, document_id): # pause document DocumentService.pause_document(document) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Cannot pause completed document.') + raise DocumentIndexingError("Cannot pause completed document.") - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class DocumentRecoverApi(DocumentResource): @@ -896,7 +886,7 @@ def patch(self, dataset_id, document_id): document_id = str(document_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") document = DocumentService.get_document(dataset.id, document_id) # 404 if document not found @@ -910,9 +900,9 @@ def patch(self, dataset_id, document_id): # pause document DocumentService.recover_document(document) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Document is not in paused status.') + raise DocumentIndexingError("Document is not in paused status.") - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class DocumentRetryApi(DocumentResource): @@ -923,15 +913,14 @@ def post(self, dataset_id): """retry document.""" parser = reqparse.RequestParser() - parser.add_argument('document_ids', type=list, required=True, nullable=False, - location='json') + parser.add_argument("document_ids", type=list, required=True, nullable=False, location="json") args = parser.parse_args() dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) retry_documents = [] if not dataset: - raise NotFound('Dataset not found.') - for document_id in args['document_ids']: + raise NotFound("Dataset not found.") + for document_id in args["document_ids"]: try: document_id = str(document_id) @@ -946,7 +935,7 @@ def post(self, dataset_id): raise ArchivedDocumentImmutableError() # 400 if document is completed - if document.indexing_status == 'completed': + if document.indexing_status == "completed": raise DocumentAlreadyFinishedError() retry_documents.append(document) except Exception as e: @@ -955,7 +944,7 @@ def post(self, dataset_id): # retry document DocumentService.retry_document(dataset_id, retry_documents) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class DocumentRenameApi(DocumentResource): @@ -970,13 +959,13 @@ def post(self, dataset_id, document_id): dataset = DatasetService.get_dataset(dataset_id) DatasetService.check_dataset_operator_permission(current_user, dataset) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, nullable=False, location='json') + parser.add_argument("name", type=str, required=True, nullable=False, location="json") args = parser.parse_args() try: - document = DocumentService.rename_document(dataset_id, document_id, args['name']) + document = DocumentService.rename_document(dataset_id, document_id, args["name"]) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Cannot delete document during indexing.') + raise DocumentIndexingError("Cannot delete document during indexing.") return document @@ -990,51 +979,43 @@ def get(self, dataset_id, document_id): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") if document.tenant_id != current_user.current_tenant_id: - raise Forbidden('No permission.') - if document.data_source_type != 'website_crawl': - raise ValueError('Document is not a website document.') + raise Forbidden("No permission.") + if document.data_source_type != "website_crawl": + raise ValueError("Document is not a website document.") # 403 if document is archived if DocumentService.check_archived(document): raise ArchivedDocumentImmutableError() # sync document DocumentService.sync_website_document(dataset_id, document) - return {'result': 'success'}, 200 - - -api.add_resource(GetProcessRuleApi, '/datasets/process-rule') -api.add_resource(DatasetDocumentListApi, - '/datasets//documents') -api.add_resource(DatasetInitApi, - '/datasets/init') -api.add_resource(DocumentIndexingEstimateApi, - '/datasets//documents//indexing-estimate') -api.add_resource(DocumentBatchIndexingEstimateApi, - '/datasets//batch//indexing-estimate') -api.add_resource(DocumentBatchIndexingStatusApi, - '/datasets//batch//indexing-status') -api.add_resource(DocumentIndexingStatusApi, - '/datasets//documents//indexing-status') -api.add_resource(DocumentDetailApi, - '/datasets//documents/') -api.add_resource(DocumentProcessingApi, - '/datasets//documents//processing/') -api.add_resource(DocumentDeleteApi, - '/datasets//documents/') -api.add_resource(DocumentMetadataApi, - '/datasets//documents//metadata') -api.add_resource(DocumentStatusApi, - '/datasets//documents//status/') -api.add_resource(DocumentPauseApi, '/datasets//documents//processing/pause') -api.add_resource(DocumentRecoverApi, '/datasets//documents//processing/resume') -api.add_resource(DocumentRetryApi, '/datasets//retry') -api.add_resource(DocumentRenameApi, - '/datasets//documents//rename') - -api.add_resource(WebsiteDocumentSyncApi, '/datasets//documents//website-sync') + return {"result": "success"}, 200 + + +api.add_resource(GetProcessRuleApi, "/datasets/process-rule") +api.add_resource(DatasetDocumentListApi, "/datasets//documents") +api.add_resource(DatasetInitApi, "/datasets/init") +api.add_resource( + DocumentIndexingEstimateApi, "/datasets//documents//indexing-estimate" +) +api.add_resource(DocumentBatchIndexingEstimateApi, "/datasets//batch//indexing-estimate") +api.add_resource(DocumentBatchIndexingStatusApi, "/datasets//batch//indexing-status") +api.add_resource(DocumentIndexingStatusApi, "/datasets//documents//indexing-status") +api.add_resource(DocumentDetailApi, "/datasets//documents/") +api.add_resource( + DocumentProcessingApi, "/datasets//documents//processing/" +) +api.add_resource(DocumentDeleteApi, "/datasets//documents/") +api.add_resource(DocumentMetadataApi, "/datasets//documents//metadata") +api.add_resource(DocumentStatusApi, "/datasets//documents//status/") +api.add_resource(DocumentPauseApi, "/datasets//documents//processing/pause") +api.add_resource(DocumentRecoverApi, "/datasets//documents//processing/resume") +api.add_resource(DocumentRetryApi, "/datasets//retry") +api.add_resource(DocumentRenameApi, "/datasets//documents//rename") + +api.add_resource(WebsiteDocumentSyncApi, "/datasets//documents//website-sync") diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 3dcade61528707..2405649387e3fa 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -40,7 +40,7 @@ def get(self, dataset_id, document_id): document_id = str(document_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) @@ -50,37 +50,33 @@ def get(self, dataset_id, document_id): document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") parser = reqparse.RequestParser() - parser.add_argument('last_id', type=str, default=None, location='args') - parser.add_argument('limit', type=int, default=20, location='args') - parser.add_argument('status', type=str, - action='append', default=[], location='args') - parser.add_argument('hit_count_gte', type=int, - default=None, location='args') - parser.add_argument('enabled', type=str, default='all', location='args') - parser.add_argument('keyword', type=str, default=None, location='args') + parser.add_argument("last_id", type=str, default=None, location="args") + parser.add_argument("limit", type=int, default=20, location="args") + parser.add_argument("status", type=str, action="append", default=[], location="args") + parser.add_argument("hit_count_gte", type=int, default=None, location="args") + parser.add_argument("enabled", type=str, default="all", location="args") + parser.add_argument("keyword", type=str, default=None, location="args") args = parser.parse_args() - last_id = args['last_id'] - limit = min(args['limit'], 100) - status_list = args['status'] - hit_count_gte = args['hit_count_gte'] - keyword = args['keyword'] + last_id = args["last_id"] + limit = min(args["limit"], 100) + status_list = args["status"] + hit_count_gte = args["hit_count_gte"] + keyword = args["keyword"] query = DocumentSegment.query.filter( - DocumentSegment.document_id == str(document_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id ) if last_id is not None: last_segment = db.session.get(DocumentSegment, str(last_id)) if last_segment: - query = query.filter( - DocumentSegment.position > last_segment.position) + query = query.filter(DocumentSegment.position > last_segment.position) else: - return {'data': [], 'has_more': False, 'limit': limit}, 200 + return {"data": [], "has_more": False, "limit": limit}, 200 if status_list: query = query.filter(DocumentSegment.status.in_(status_list)) @@ -89,12 +85,12 @@ def get(self, dataset_id, document_id): query = query.filter(DocumentSegment.hit_count >= hit_count_gte) if keyword: - query = query.where(DocumentSegment.content.ilike(f'%{keyword}%')) + query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) - if args['enabled'].lower() != 'all': - if args['enabled'].lower() == 'true': + if args["enabled"].lower() != "all": + if args["enabled"].lower() == "true": query = query.filter(DocumentSegment.enabled == True) - elif args['enabled'].lower() == 'false': + elif args["enabled"].lower() == "false": query = query.filter(DocumentSegment.enabled == False) total = query.count() @@ -106,11 +102,11 @@ def get(self, dataset_id, document_id): segments = segments[:-1] return { - 'data': marshal(segments, segment_fields), - 'doc_form': document.doc_form, - 'has_more': has_more, - 'limit': limit, - 'total': total + "data": marshal(segments, segment_fields), + "doc_form": document.doc_form, + "has_more": has_more, + "limit": limit, + "total": total, }, 200 @@ -118,12 +114,12 @@ class DatasetDocumentSegmentApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def patch(self, dataset_id, segment_id, action): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # The role of the current user in the ta table must be admin, owner, or editor @@ -134,7 +130,7 @@ def patch(self, dataset_id, segment_id, action): DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": # check embedding model setting try: model_manager = ModelManager() @@ -142,32 +138,32 @@ def patch(self, dataset_id, segment_id, action): tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") - if segment.status != 'completed': - raise NotFound('Segment is not completed, enable or disable function is not allowed') + if segment.status != "completed": + raise NotFound("Segment is not completed, enable or disable function is not allowed") - document_indexing_cache_key = 'document_{}_indexing'.format(segment.document_id) + document_indexing_cache_key = "document_{}_indexing".format(segment.document_id) cache_result = redis_client.get(document_indexing_cache_key) if cache_result is not None: raise InvalidActionError("Document is being indexed, please try again later") - indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_indexing".format(segment.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise InvalidActionError("Segment is being indexed, please try again later") @@ -186,7 +182,7 @@ def patch(self, dataset_id, segment_id, action): enable_segment_to_index_task.delay(segment.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 elif action == "disable": if not segment.enabled: raise InvalidActionError("Segment is already disabled.") @@ -201,7 +197,7 @@ def patch(self, dataset_id, segment_id, action): disable_segment_from_index_task.delay(segment.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 else: raise InvalidActionError() @@ -210,36 +206,36 @@ class DatasetDocumentSegmentAddApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') - @cloud_edition_billing_knowledge_limit_check('add_segment') + @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_knowledge_limit_check("add_segment") def post(self, dataset_id, document_id): # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') - # The role of the current user in the ta table must be admin or owner - if not current_user.is_admin_or_owner: + raise NotFound("Document not found.") + if not current_user.is_editor: raise Forbidden() # check embedding model setting - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": try: model_manager = ModelManager() model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) try: @@ -248,37 +244,34 @@ def post(self, dataset_id, document_id): raise Forbidden(str(e)) # validate args parser = reqparse.RequestParser() - parser.add_argument('content', type=str, required=True, nullable=False, location='json') - parser.add_argument('answer', type=str, required=False, nullable=True, location='json') - parser.add_argument('keywords', type=list, required=False, nullable=True, location='json') + parser.add_argument("content", type=str, required=True, nullable=False, location="json") + parser.add_argument("answer", type=str, required=False, nullable=True, location="json") + parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") args = parser.parse_args() SegmentService.segment_create_args_validate(args, document) segment = SegmentService.create_segment(args, document, dataset) - return { - 'data': marshal(segment, segment_fields), - 'doc_form': document.doc_form - }, 200 + return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 class DatasetDocumentSegmentUpdateApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def patch(self, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') - if dataset.indexing_technique == 'high_quality': + raise NotFound("Document not found.") + if dataset.indexing_technique == "high_quality": # check embedding model setting try: model_manager = ModelManager() @@ -286,22 +279,22 @@ def patch(self, dataset_id, document_id, segment_id): tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # check segment segment_id = str(segment_id) segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() @@ -311,16 +304,13 @@ def patch(self, dataset_id, document_id, segment_id): raise Forbidden(str(e)) # validate args parser = reqparse.RequestParser() - parser.add_argument('content', type=str, required=True, nullable=False, location='json') - parser.add_argument('answer', type=str, required=False, nullable=True, location='json') - parser.add_argument('keywords', type=list, required=False, nullable=True, location='json') + parser.add_argument("content", type=str, required=True, nullable=False, location="json") + parser.add_argument("answer", type=str, required=False, nullable=True, location="json") + parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") args = parser.parse_args() SegmentService.segment_create_args_validate(args, document) segment = SegmentService.update_segment(args, segment, document, dataset) - return { - 'data': marshal(segment, segment_fields), - 'doc_form': document.doc_form - }, 200 + return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 @setup_required @login_required @@ -330,60 +320,59 @@ def delete(self, dataset_id, document_id, segment_id): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") # check segment segment_id = str(segment_id) segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") # The role of the current user in the ta table must be admin or owner - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() try: DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) SegmentService.delete_segment(segment, document, dataset) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class DatasetDocumentSegmentBatchImportApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') - @cloud_edition_billing_knowledge_limit_check('add_segment') + @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_knowledge_limit_check("add_segment") def post(self, dataset_id, document_id): # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") # get file from request - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() # check file type - if not file.filename.endswith('.csv'): + if not file.filename.endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") try: @@ -391,51 +380,47 @@ def post(self, dataset_id, document_id): df = pd.read_csv(file) result = [] for index, row in df.iterrows(): - if document.doc_form == 'qa_model': - data = {'content': row[0], 'answer': row[1]} + if document.doc_form == "qa_model": + data = {"content": row[0], "answer": row[1]} else: - data = {'content': row[0]} + data = {"content": row[0]} result.append(data) if len(result) == 0: raise ValueError("The CSV file is empty.") # async job job_id = str(uuid.uuid4()) - indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id)) + indexing_cache_key = "segment_batch_import_{}".format(str(job_id)) # send batch add segments task - redis_client.setnx(indexing_cache_key, 'waiting') - batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id, - current_user.current_tenant_id, current_user.id) + redis_client.setnx(indexing_cache_key, "waiting") + batch_create_segment_to_index_task.delay( + str(job_id), result, dataset_id, document_id, current_user.current_tenant_id, current_user.id + ) except Exception as e: - return {'error': str(e)}, 500 - return { - 'job_id': job_id, - 'job_status': 'waiting' - }, 200 + return {"error": str(e)}, 500 + return {"job_id": job_id, "job_status": "waiting"}, 200 @setup_required @login_required @account_initialization_required def get(self, job_id): job_id = str(job_id) - indexing_cache_key = 'segment_batch_import_{}'.format(job_id) + indexing_cache_key = "segment_batch_import_{}".format(job_id) cache_result = redis_client.get(indexing_cache_key) if cache_result is None: raise ValueError("The job is not exist.") - return { - 'job_id': job_id, - 'job_status': cache_result.decode() - }, 200 + return {"job_id": job_id, "job_status": cache_result.decode()}, 200 -api.add_resource(DatasetDocumentSegmentListApi, - '/datasets//documents//segments') -api.add_resource(DatasetDocumentSegmentApi, - '/datasets//segments//') -api.add_resource(DatasetDocumentSegmentAddApi, - '/datasets//documents//segment') -api.add_resource(DatasetDocumentSegmentUpdateApi, - '/datasets//documents//segments/') -api.add_resource(DatasetDocumentSegmentBatchImportApi, - '/datasets//documents//segments/batch_import', - '/datasets/batch_import_status/') +api.add_resource(DatasetDocumentSegmentListApi, "/datasets//documents//segments") +api.add_resource(DatasetDocumentSegmentApi, "/datasets//segments//") +api.add_resource(DatasetDocumentSegmentAddApi, "/datasets//documents//segment") +api.add_resource( + DatasetDocumentSegmentUpdateApi, + "/datasets//documents//segments/", +) +api.add_resource( + DatasetDocumentSegmentBatchImportApi, + "/datasets//documents//segments/batch_import", + "/datasets/batch_import_status/", +) diff --git a/api/controllers/console/datasets/error.py b/api/controllers/console/datasets/error.py index 9270b610c28e8c..6a7a3971a8b33f 100644 --- a/api/controllers/console/datasets/error.py +++ b/api/controllers/console/datasets/error.py @@ -2,90 +2,90 @@ class NoFileUploadedError(BaseHTTPException): - error_code = 'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." code = 400 class FileTooLargeError(BaseHTTPException): - error_code = 'file_too_large' + error_code = "file_too_large" description = "File size exceeded. {message}" code = 413 class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 class HighQualityDatasetOnlyError(BaseHTTPException): - error_code = 'high_quality_dataset_only' + error_code = "high_quality_dataset_only" description = "Current operation only supports 'high-quality' datasets." code = 400 class DatasetNotInitializedError(BaseHTTPException): - error_code = 'dataset_not_initialized' + error_code = "dataset_not_initialized" description = "The dataset is still being initialized or indexing. Please wait a moment." code = 400 class ArchivedDocumentImmutableError(BaseHTTPException): - error_code = 'archived_document_immutable' + error_code = "archived_document_immutable" description = "The archived document is not editable." code = 403 class DatasetNameDuplicateError(BaseHTTPException): - error_code = 'dataset_name_duplicate' + error_code = "dataset_name_duplicate" description = "The dataset name already exists. Please modify your dataset name." code = 409 class InvalidActionError(BaseHTTPException): - error_code = 'invalid_action' + error_code = "invalid_action" description = "Invalid action." code = 400 class DocumentAlreadyFinishedError(BaseHTTPException): - error_code = 'document_already_finished' + error_code = "document_already_finished" description = "The document has been processed. Please refresh the page or go to the document details." code = 400 class DocumentIndexingError(BaseHTTPException): - error_code = 'document_indexing' + error_code = "document_indexing" description = "The document is being processed and cannot be edited." code = 400 class InvalidMetadataError(BaseHTTPException): - error_code = 'invalid_metadata' + error_code = "invalid_metadata" description = "The metadata content is incorrect. Please check and verify." code = 400 class WebsiteCrawlError(BaseHTTPException): - error_code = 'crawl_failed' + error_code = "crawl_failed" description = "{message}" code = 500 class DatasetInUseError(BaseHTTPException): - error_code = 'dataset_in_use' + error_code = "dataset_in_use" description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it." code = 409 class IndexingEstimateError(BaseHTTPException): - error_code = 'indexing_estimate_error' + error_code = "indexing_estimate_error" description = "Knowledge indexing estimate failed: {message}" code = 500 diff --git a/api/controllers/console/datasets/file.py b/api/controllers/console/datasets/file.py index 3b2083bcc3351c..846aa70e86815f 100644 --- a/api/controllers/console/datasets/file.py +++ b/api/controllers/console/datasets/file.py @@ -21,7 +21,6 @@ class FileApi(Resource): - @setup_required @login_required @account_initialization_required @@ -31,23 +30,22 @@ def get(self): batch_count_limit = dify_config.UPLOAD_FILE_BATCH_LIMIT image_file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT return { - 'file_size_limit': file_size_limit, - 'batch_count_limit': batch_count_limit, - 'image_file_size_limit': image_file_size_limit + "file_size_limit": file_size_limit, + "batch_count_limit": batch_count_limit, + "image_file_size_limit": image_file_size_limit, }, 200 @setup_required @login_required @account_initialization_required @marshal_with(file_fields) - @cloud_edition_billing_resource_check(resource='documents') + @cloud_edition_billing_resource_check("documents") def post(self): - # get file from request - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: @@ -69,7 +67,7 @@ class FilePreviewApi(Resource): def get(self, file_id): file_id = str(file_id) text = FileService.get_file_preview(file_id) - return {'content': text} + return {"content": text} class FileSupportTypeApi(Resource): @@ -78,10 +76,10 @@ class FileSupportTypeApi(Resource): @account_initialization_required def get(self): etl_type = dify_config.ETL_TYPE - allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS - return {'allowed_extensions': allowed_extensions} + allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS + return {"allowed_extensions": allowed_extensions} -api.add_resource(FileApi, '/files/upload') -api.add_resource(FilePreviewApi, '/files//preview') -api.add_resource(FileSupportTypeApi, '/files/support-type') +api.add_resource(FileApi, "/files/upload") +api.add_resource(FilePreviewApi, "/files//preview") +api.add_resource(FileSupportTypeApi, "/files/support-type") diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 8771bf909ed650..0b4a7be986e167 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -29,7 +29,6 @@ class HitTestingApi(Resource): - @setup_required @login_required @account_initialization_required @@ -46,8 +45,8 @@ def post(self, dataset_id): raise Forbidden(str(e)) parser = reqparse.RequestParser() - parser.add_argument('query', type=str, location='json') - parser.add_argument('retrieval_model', type=dict, required=False, location='json') + parser.add_argument("query", type=str, location="json") + parser.add_argument("retrieval_model", type=dict, required=False, location="json") args = parser.parse_args() HitTestingService.hit_testing_args_check(args) @@ -55,13 +54,13 @@ def post(self, dataset_id): try: response = HitTestingService.retrieve( dataset=dataset, - query=args['query'], + query=args["query"], account=current_user, - retrieval_model=args['retrieval_model'], - limit=10 + retrieval_model=args["retrieval_model"], + limit=10, ) - return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)} + return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)} except services.errors.index.IndexNotInitializedError: raise DatasetNotInitializedError() except ProviderTokenNotInitError as ex: @@ -73,7 +72,8 @@ def post(self, dataset_id): except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model or Reranking Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except InvokeError as e: raise CompletionRequestError(e.description) except ValueError as e: @@ -83,4 +83,4 @@ def post(self, dataset_id): raise InternalServerError(str(e)) -api.add_resource(HitTestingApi, '/datasets//hit-testing') +api.add_resource(HitTestingApi, "/datasets//hit-testing") diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py index bbd91256f1c29c..cb54f1aacbc2ab 100644 --- a/api/controllers/console/datasets/website.py +++ b/api/controllers/console/datasets/website.py @@ -9,16 +9,14 @@ class WebsiteCrawlApi(Resource): - @setup_required @login_required @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, choices=['firecrawl'], - required=True, nullable=True, location='json') - parser.add_argument('url', type=str, required=True, nullable=True, location='json') - parser.add_argument('options', type=dict, required=True, nullable=True, location='json') + parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, nullable=True, location="json") + parser.add_argument("url", type=str, required=True, nullable=True, location="json") + parser.add_argument("options", type=dict, required=True, nullable=True, location="json") args = parser.parse_args() WebsiteService.document_create_args_validate(args) # crawl url @@ -35,15 +33,15 @@ class WebsiteCrawlStatusApi(Resource): @account_initialization_required def get(self, job_id: str): parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, choices=['firecrawl'], required=True, location='args') + parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, location="args") args = parser.parse_args() # get crawl status try: - result = WebsiteService.get_crawl_status(job_id, args['provider']) + result = WebsiteService.get_crawl_status(job_id, args["provider"]) except Exception as e: raise WebsiteCrawlError(str(e)) return result, 200 -api.add_resource(WebsiteCrawlApi, '/website/crawl') -api.add_resource(WebsiteCrawlStatusApi, '/website/crawl/status/') +api.add_resource(WebsiteCrawlApi, "/website/crawl") +api.add_resource(WebsiteCrawlStatusApi, "/website/crawl/status/") diff --git a/api/controllers/console/error.py b/api/controllers/console/error.py index 888dad83ccda84..1c70ea6c591ecf 100644 --- a/api/controllers/console/error.py +++ b/api/controllers/console/error.py @@ -2,35 +2,41 @@ class AlreadySetupError(BaseHTTPException): - error_code = 'already_setup' + error_code = "already_setup" description = "Dify has been successfully installed. Please refresh the page or return to the dashboard homepage." code = 403 class NotSetupError(BaseHTTPException): - error_code = 'not_setup' - description = "Dify has not been initialized and installed yet. " \ - "Please proceed with the initialization and installation process first." + error_code = "not_setup" + description = ( + "Dify has not been initialized and installed yet. " + "Please proceed with the initialization and installation process first." + ) code = 401 + class NotInitValidateError(BaseHTTPException): - error_code = 'not_init_validated' - description = "Init validation has not been completed yet. " \ - "Please proceed with the init validation process first." + error_code = "not_init_validated" + description = ( + "Init validation has not been completed yet. " "Please proceed with the init validation process first." + ) code = 401 + class InitValidateFailedError(BaseHTTPException): - error_code = 'init_validate_failed' + error_code = "init_validate_failed" description = "Init validation failed. Please check the password and try again." code = 401 + class AccountNotLinkTenantError(BaseHTTPException): - error_code = 'account_not_link_tenant' + error_code = "account_not_link_tenant" description = "Account not link tenant." code = 403 class AlreadyActivateError(BaseHTTPException): - error_code = 'already_activate' + error_code = "already_activate" description = "Auth Token is invalid or account already activated, please check again." code = 403 diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 27cc83042a7822..71cb060ecc188e 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -33,14 +33,10 @@ class ChatAudioApi(InstalledAppResource): def post(self, installed_app): app_model = installed_app.app - file = request.files['file'] + file = request.files["file"] try: - response = AudioService.transcript_asr( - app_model=app_model, - file=file, - end_user=None - ) + response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None) return response except services.errors.app_model_config.AppModelConfigBrokenError: @@ -76,30 +72,31 @@ def post(self, installed_app): app_model = installed_app.app try: parser = reqparse.RequestParser() - parser.add_argument('message_id', type=str, required=False, location='json') - parser.add_argument('voice', type=str, location='json') - parser.add_argument('text', type=str, location='json') - parser.add_argument('streaming', type=bool, location='json') + parser.add_argument("message_id", type=str, required=False, location="json") + parser.add_argument("voice", type=str, location="json") + parser.add_argument("text", type=str, location="json") + parser.add_argument("streaming", type=bool, location="json") args = parser.parse_args() - message_id = args.get('message_id', None) - text = args.get('text', None) - if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] - and app_model.workflow - and app_model.workflow.features_dict): - text_to_speech = app_model.workflow.features_dict.get('text_to_speech') - voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') + message_id = args.get("message_id", None) + text = args.get("text", None) + if ( + app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + and app_model.workflow + and app_model.workflow.features_dict + ): + text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") else: try: - voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice') + voice = ( + args.get("voice") + if args.get("voice") + else app_model.app_model_config.text_to_speech_dict.get("voice") + ) except Exception: voice = None - response = AudioService.transcript_tts( - app_model=app_model, - message_id=message_id, - voice=voice, - text=text - ) + response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text) return response except services.errors.app_model_config.AppModelConfigBrokenError: logging.exception("App model config broken.") @@ -127,7 +124,7 @@ def post(self, installed_app): raise InternalServerError() -api.add_resource(ChatAudioApi, '/installed-apps//audio-to-text', endpoint='installed_app_audio') -api.add_resource(ChatTextApi, '/installed-apps//text-to-audio', endpoint='installed_app_text') +api.add_resource(ChatAudioApi, "/installed-apps//audio-to-text", endpoint="installed_app_audio") +api.add_resource(ChatTextApi, "/installed-apps//text-to-audio", endpoint="installed_app_text") # api.add_resource(ChatTextApiWithMessageId, '/installed-apps//text-to-audio/message-id', # endpoint='installed_app_text_with_message_id') diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 869b56e13bf939..c039e8bca52a29 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -30,33 +30,28 @@ # define completion api for user class CompletionApi(InstalledAppResource): - def post(self, installed_app): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, location='json', default='') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, location="json", default="") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' - args['auto_generate_name'] = False + streaming = args["response_mode"] == "streaming" + args["auto_generate_name"] = False installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() try: response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.EXPLORE, - streaming=streaming + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming ) return helper.compact_generate_response(response) @@ -85,12 +80,12 @@ def post(self, installed_app): class CompletionStopApi(InstalledAppResource): def post(self, installed_app, task_id): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ChatApi(InstalledAppResource): @@ -101,25 +96,21 @@ def post(self, installed_app): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, required=True, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') - parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, required=True, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") args = parser.parse_args() - args['auto_generate_name'] = False + args["auto_generate_name"] = False installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() try: response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.EXPLORE, - streaming=True + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True ) return helper.compact_generate_response(response) @@ -154,10 +145,22 @@ def post(self, installed_app, task_id): AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(CompletionApi, '/installed-apps//completion-messages', endpoint='installed_app_completion') -api.add_resource(CompletionStopApi, '/installed-apps//completion-messages//stop', endpoint='installed_app_stop_completion') -api.add_resource(ChatApi, '/installed-apps//chat-messages', endpoint='installed_app_chat_completion') -api.add_resource(ChatStopApi, '/installed-apps//chat-messages//stop', endpoint='installed_app_stop_chat_completion') +api.add_resource( + CompletionApi, "/installed-apps//completion-messages", endpoint="installed_app_completion" +) +api.add_resource( + CompletionStopApi, + "/installed-apps//completion-messages//stop", + endpoint="installed_app_stop_completion", +) +api.add_resource( + ChatApi, "/installed-apps//chat-messages", endpoint="installed_app_chat_completion" +) +api.add_resource( + ChatStopApi, + "/installed-apps//chat-messages//stop", + endpoint="installed_app_stop_chat_completion", +) diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index ea0fa4e17e9d44..2918024b64cc11 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -16,7 +16,6 @@ class ConversationListApi(InstalledAppResource): - @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, installed_app): app_model = installed_app.app @@ -25,21 +24,21 @@ def get(self, installed_app): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args") args = parser.parse_args() pinned = None - if 'pinned' in args and args['pinned'] is not None: - pinned = True if args['pinned'] == 'true' else False + if "pinned" in args and args["pinned"] is not None: + pinned = True if args["pinned"] == "true" else False try: return WebConversationService.pagination_by_last_id( app_model=app_model, user=current_user, - last_id=args['last_id'], - limit=args['limit'], + last_id=args["last_id"], + limit=args["limit"], invoke_from=InvokeFrom.EXPLORE, pinned=pinned, ) @@ -65,7 +64,6 @@ def delete(self, installed_app, c_id): class ConversationRenameApi(InstalledAppResource): - @marshal_with(simple_conversation_fields) def post(self, installed_app, c_id): app_model = installed_app.app @@ -76,24 +74,19 @@ def post(self, installed_app, c_id): conversation_id = str(c_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, location='json') - parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') + parser.add_argument("name", type=str, required=False, location="json") + parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") args = parser.parse_args() try: return ConversationService.rename( - app_model, - conversation_id, - current_user, - args['name'], - args['auto_generate'] + app_model, conversation_id, current_user, args["name"], args["auto_generate"] ) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") class ConversationPinApi(InstalledAppResource): - def patch(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) @@ -123,8 +116,26 @@ def patch(self, installed_app, c_id): return {"result": "success"} -api.add_resource(ConversationRenameApi, '/installed-apps//conversations//name', endpoint='installed_app_conversation_rename') -api.add_resource(ConversationListApi, '/installed-apps//conversations', endpoint='installed_app_conversations') -api.add_resource(ConversationApi, '/installed-apps//conversations/', endpoint='installed_app_conversation') -api.add_resource(ConversationPinApi, '/installed-apps//conversations//pin', endpoint='installed_app_conversation_pin') -api.add_resource(ConversationUnPinApi, '/installed-apps//conversations//unpin', endpoint='installed_app_conversation_unpin') +api.add_resource( + ConversationRenameApi, + "/installed-apps//conversations//name", + endpoint="installed_app_conversation_rename", +) +api.add_resource( + ConversationListApi, "/installed-apps//conversations", endpoint="installed_app_conversations" +) +api.add_resource( + ConversationApi, + "/installed-apps//conversations/", + endpoint="installed_app_conversation", +) +api.add_resource( + ConversationPinApi, + "/installed-apps//conversations//pin", + endpoint="installed_app_conversation_pin", +) +api.add_resource( + ConversationUnPinApi, + "/installed-apps//conversations//unpin", + endpoint="installed_app_conversation_unpin", +) diff --git a/api/controllers/console/explore/error.py b/api/controllers/console/explore/error.py index 9c3216ecc8c178..18221b7797cdb0 100644 --- a/api/controllers/console/explore/error.py +++ b/api/controllers/console/explore/error.py @@ -2,24 +2,24 @@ class NotCompletionAppError(BaseHTTPException): - error_code = 'not_completion_app' + error_code = "not_completion_app" description = "Not Completion App" code = 400 class NotChatAppError(BaseHTTPException): - error_code = 'not_chat_app' + error_code = "not_chat_app" description = "App mode is invalid." code = 400 class NotWorkflowAppError(BaseHTTPException): - error_code = 'not_workflow_app' + error_code = "not_workflow_app" description = "Only support workflow app." code = 400 class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException): - error_code = 'app_suggested_questions_after_answer_disabled' + error_code = "app_suggested_questions_after_answer_disabled" description = "Function Suggested questions after answer disabled." code = 403 diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index ec7bbed3074ad6..3f1e64a2478b4b 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -21,72 +21,72 @@ class InstalledAppsListApi(Resource): @marshal_with(installed_app_list_fields) def get(self): current_tenant_id = current_user.current_tenant_id - installed_apps = db.session.query(InstalledApp).filter( - InstalledApp.tenant_id == current_tenant_id - ).all() + installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all() current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) installed_apps = [ { - 'id': installed_app.id, - 'app': installed_app.app, - 'app_owner_tenant_id': installed_app.app_owner_tenant_id, - 'is_pinned': installed_app.is_pinned, - 'last_used_at': installed_app.last_used_at, - 'editable': current_user.role in ["owner", "admin"], - 'uninstallable': current_tenant_id == installed_app.app_owner_tenant_id + "id": installed_app.id, + "app": installed_app.app, + "app_owner_tenant_id": installed_app.app_owner_tenant_id, + "is_pinned": installed_app.is_pinned, + "last_used_at": installed_app.last_used_at, + "editable": current_user.role in ["owner", "admin"], + "uninstallable": current_tenant_id == installed_app.app_owner_tenant_id, } for installed_app in installed_apps + if installed_app.app is not None ] - installed_apps.sort(key=lambda app: (-app['is_pinned'], - app['last_used_at'] is None, - -app['last_used_at'].timestamp() if app['last_used_at'] is not None else 0)) + installed_apps.sort( + key=lambda app: ( + -app["is_pinned"], + app["last_used_at"] is None, + -app["last_used_at"].timestamp() if app["last_used_at"] is not None else 0, + ) + ) - return {'installed_apps': installed_apps} + return {"installed_apps": installed_apps} @login_required @account_initialization_required - @cloud_edition_billing_resource_check('apps') + @cloud_edition_billing_resource_check("apps") def post(self): parser = reqparse.RequestParser() - parser.add_argument('app_id', type=str, required=True, help='Invalid app_id') + parser.add_argument("app_id", type=str, required=True, help="Invalid app_id") args = parser.parse_args() - recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first() + recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first() if recommended_app is None: - raise NotFound('App not found') + raise NotFound("App not found") current_tenant_id = current_user.current_tenant_id - app = db.session.query(App).filter( - App.id == args['app_id'] - ).first() + app = db.session.query(App).filter(App.id == args["app_id"]).first() if app is None: - raise NotFound('App not found') + raise NotFound("App not found") if not app.is_public: - raise Forbidden('You can\'t install a non-public app') + raise Forbidden("You can't install a non-public app") - installed_app = InstalledApp.query.filter(and_( - InstalledApp.app_id == args['app_id'], - InstalledApp.tenant_id == current_tenant_id - )).first() + installed_app = InstalledApp.query.filter( + and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id) + ).first() if installed_app is None: # todo: position recommended_app.install_count += 1 new_installed_app = InstalledApp( - app_id=args['app_id'], + app_id=args["app_id"], tenant_id=current_tenant_id, app_owner_tenant_id=app.tenant_id, is_pinned=False, - last_used_at=datetime.now(timezone.utc).replace(tzinfo=None) + last_used_at=datetime.now(timezone.utc).replace(tzinfo=None), ) db.session.add(new_installed_app) db.session.commit() - return {'message': 'App installed successfully'} + return {"message": "App installed successfully"} class InstalledAppApi(InstalledAppResource): @@ -94,30 +94,31 @@ class InstalledAppApi(InstalledAppResource): update and delete an installed app use InstalledAppResource to apply default decorators and get installed_app """ + def delete(self, installed_app): if installed_app.app_owner_tenant_id == current_user.current_tenant_id: - raise BadRequest('You can\'t uninstall an app owned by the current tenant') + raise BadRequest("You can't uninstall an app owned by the current tenant") db.session.delete(installed_app) db.session.commit() - return {'result': 'success', 'message': 'App uninstalled successfully'} + return {"result": "success", "message": "App uninstalled successfully"} def patch(self, installed_app): parser = reqparse.RequestParser() - parser.add_argument('is_pinned', type=inputs.boolean) + parser.add_argument("is_pinned", type=inputs.boolean) args = parser.parse_args() commit_args = False - if 'is_pinned' in args: - installed_app.is_pinned = args['is_pinned'] + if "is_pinned" in args: + installed_app.is_pinned = args["is_pinned"] commit_args = True if commit_args: db.session.commit() - return {'result': 'success', 'message': 'App info updated successfully'} + return {"result": "success", "message": "App info updated successfully"} -api.add_resource(InstalledAppsListApi, '/installed-apps') -api.add_resource(InstalledAppApi, '/installed-apps/') +api.add_resource(InstalledAppsListApi, "/installed-apps") +api.add_resource(InstalledAppApi, "/installed-apps/") diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 3523a869003a32..f5eb18517258d5 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -44,19 +44,21 @@ def get(self, installed_app): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') - parser.add_argument('first_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") + parser.add_argument("first_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() try: - return MessageService.pagination_by_first_id(app_model, current_user, - args['conversation_id'], args['first_id'], args['limit']) + return MessageService.pagination_by_first_id( + app_model, current_user, args["conversation_id"], args["first_id"], args["limit"] + ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.message.FirstMessageNotExistsError: raise NotFound("First Message Not Exists.") + class MessageFeedbackApi(InstalledAppResource): def post(self, installed_app, message_id): app_model = installed_app.app @@ -64,30 +66,32 @@ def post(self, installed_app, message_id): message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') + parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, current_user, args['rating']) + MessageService.create_feedback(app_model, message_id, current_user, args["rating"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class MessageMoreLikeThisApi(InstalledAppResource): def get(self, installed_app, message_id): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args') + parser.add_argument( + "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" + ) args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' + streaming = args["response_mode"] == "streaming" try: response = AppGenerateService.generate_more_like_this( @@ -95,7 +99,7 @@ def get(self, installed_app, message_id): user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE, - streaming=streaming + streaming=streaming, ) return helper.compact_generate_response(response) except MessageNotExistsError: @@ -128,10 +132,7 @@ def get(self, installed_app, message_id): try: questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, - user=current_user, - message_id=message_id, - invoke_from=InvokeFrom.EXPLORE + app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE ) except MessageNotExistsError: raise NotFound("Message not found") @@ -151,10 +152,22 @@ def get(self, installed_app, message_id): logging.exception("internal server error.") raise InternalServerError() - return {'data': questions} + return {"data": questions} -api.add_resource(MessageListApi, '/installed-apps//messages', endpoint='installed_app_messages') -api.add_resource(MessageFeedbackApi, '/installed-apps//messages//feedbacks', endpoint='installed_app_message_feedback') -api.add_resource(MessageMoreLikeThisApi, '/installed-apps//messages//more-like-this', endpoint='installed_app_more_like_this') -api.add_resource(MessageSuggestedQuestionApi, '/installed-apps//messages//suggested-questions', endpoint='installed_app_suggested_question') +api.add_resource(MessageListApi, "/installed-apps//messages", endpoint="installed_app_messages") +api.add_resource( + MessageFeedbackApi, + "/installed-apps//messages//feedbacks", + endpoint="installed_app_message_feedback", +) +api.add_resource( + MessageMoreLikeThisApi, + "/installed-apps//messages//more-like-this", + endpoint="installed_app_more_like_this", +) +api.add_resource( + MessageSuggestedQuestionApi, + "/installed-apps//messages//suggested-questions", + endpoint="installed_app_suggested_question", +) diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index 0a168d6306c83e..ad55b040436d47 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,4 +1,3 @@ - from flask_restful import fields, marshal_with from configs import dify_config @@ -11,33 +10,32 @@ class AppParameterApi(InstalledAppResource): """Resource for app variables.""" + variable_fields = { - 'key': fields.String, - 'name': fields.String, - 'description': fields.String, - 'type': fields.String, - 'default': fields.String, - 'max_length': fields.Integer, - 'options': fields.List(fields.String) + "key": fields.String, + "name": fields.String, + "description": fields.String, + "type": fields.String, + "default": fields.String, + "max_length": fields.Integer, + "options": fields.List(fields.String), } - system_parameters_fields = { - 'image_file_size_limit': fields.String - } + system_parameters_fields = {"image_file_size_limit": fields.String} parameters_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw, - 'suggested_questions_after_answer': fields.Raw, - 'speech_to_text': fields.Raw, - 'text_to_speech': fields.Raw, - 'retriever_resource': fields.Raw, - 'annotation_reply': fields.Raw, - 'more_like_this': fields.Raw, - 'user_input_form': fields.Raw, - 'sensitive_word_avoidance': fields.Raw, - 'file_upload': fields.Raw, - 'system_parameters': fields.Nested(system_parameters_fields) + "opening_statement": fields.String, + "suggested_questions": fields.Raw, + "suggested_questions_after_answer": fields.Raw, + "speech_to_text": fields.Raw, + "text_to_speech": fields.Raw, + "retriever_resource": fields.Raw, + "annotation_reply": fields.Raw, + "more_like_this": fields.Raw, + "user_input_form": fields.Raw, + "sensitive_word_avoidance": fields.Raw, + "file_upload": fields.Raw, + "system_parameters": fields.Nested(system_parameters_fields), } @marshal_with(parameters_fields) @@ -56,30 +54,35 @@ def get(self, installed_app: InstalledApp): app_model_config = app_model.app_model_config features_dict = app_model_config.to_dict() - user_input_form = features_dict.get('user_input_form', []) + user_input_form = features_dict.get("user_input_form", []) return { - 'opening_statement': features_dict.get('opening_statement'), - 'suggested_questions': features_dict.get('suggested_questions', []), - 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', - {"enabled": False}), - 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), - 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), - 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), - 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), - 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), - 'user_input_form': user_input_form, - 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', - {"enabled": False, "type": "", "configs": []}), - 'file_upload': features_dict.get('file_upload', {"image": { - "enabled": False, - "number_limits": 3, - "detail": "high", - "transfer_methods": ["remote_url", "local_file"] - }}), - 'system_parameters': { - 'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT - } + "opening_statement": features_dict.get("opening_statement"), + "suggested_questions": features_dict.get("suggested_questions", []), + "suggested_questions_after_answer": features_dict.get( + "suggested_questions_after_answer", {"enabled": False} + ), + "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), + "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), + "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), + "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}), + "more_like_this": features_dict.get("more_like_this", {"enabled": False}), + "user_input_form": user_input_form, + "sensitive_word_avoidance": features_dict.get( + "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} + ), + "file_upload": features_dict.get( + "file_upload", + { + "image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"], + } + }, + ), + "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT}, } @@ -90,6 +93,7 @@ def get(self, installed_app: InstalledApp): return AppService().get_app_meta(app_model) -api.add_resource(AppParameterApi, '/installed-apps//parameters', - endpoint='installed_app_parameters') -api.add_resource(ExploreAppMetaApi, '/installed-apps//meta', endpoint='installed_app_meta') +api.add_resource( + AppParameterApi, "/installed-apps//parameters", endpoint="installed_app_parameters" +) +api.add_resource(ExploreAppMetaApi, "/installed-apps//meta", endpoint="installed_app_meta") diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 6e10e2ec92671e..5daaa1e7c38ba8 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -8,28 +8,28 @@ from services.recommended_app_service import RecommendedAppService app_fields = { - 'id': fields.String, - 'name': fields.String, - 'mode': fields.String, - 'icon': fields.String, - 'icon_background': fields.String + "id": fields.String, + "name": fields.String, + "mode": fields.String, + "icon": fields.String, + "icon_background": fields.String, } recommended_app_fields = { - 'app': fields.Nested(app_fields, attribute='app'), - 'app_id': fields.String, - 'description': fields.String(attribute='description'), - 'copyright': fields.String, - 'privacy_policy': fields.String, - 'custom_disclaimer': fields.String, - 'category': fields.String, - 'position': fields.Integer, - 'is_listed': fields.Boolean + "app": fields.Nested(app_fields, attribute="app"), + "app_id": fields.String, + "description": fields.String(attribute="description"), + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "category": fields.String, + "position": fields.Integer, + "is_listed": fields.Boolean, } recommended_app_list_fields = { - 'recommended_apps': fields.List(fields.Nested(recommended_app_fields)), - 'categories': fields.List(fields.String) + "recommended_apps": fields.List(fields.Nested(recommended_app_fields)), + "categories": fields.List(fields.String), } @@ -40,11 +40,11 @@ class RecommendedAppListApi(Resource): def get(self): # language args parser = reqparse.RequestParser() - parser.add_argument('language', type=str, location='args') + parser.add_argument("language", type=str, location="args") args = parser.parse_args() - if args.get('language') and args.get('language') in languages: - language_prefix = args.get('language') + if args.get("language") and args.get("language") in languages: + language_prefix = args.get("language") elif current_user and current_user.interface_language: language_prefix = current_user.interface_language else: @@ -61,5 +61,5 @@ def get(self, app_id): return RecommendedAppService.get_recommend_app_detail(app_id) -api.add_resource(RecommendedAppListApi, '/explore/apps') -api.add_resource(RecommendedAppApi, '/explore/apps/') +api.add_resource(RecommendedAppListApi, "/explore/apps") +api.add_resource(RecommendedAppApi, "/explore/apps/") diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index cf86b2fee1cead..a7ccf737a83973 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -11,56 +11,54 @@ from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService -feedback_fields = { - 'rating': fields.String -} +feedback_fields = {"rating": fields.String} message_fields = { - 'id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String, - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'created_at': TimestampField + "id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "answer": fields.String, + "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "created_at": TimestampField, } class SavedMessageListApi(InstalledAppResource): saved_message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_fields)), } @marshal_with(saved_message_infinite_scroll_pagination_fields) def get(self, installed_app): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - return SavedMessageService.pagination_by_last_id(app_model, current_user, args['last_id'], args['limit']) + return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"]) def post(self, installed_app): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('message_id', type=uuid_value, required=True, location='json') + parser.add_argument("message_id", type=uuid_value, required=True, location="json") args = parser.parse_args() try: - SavedMessageService.save(app_model, current_user, args['message_id']) + SavedMessageService.save(app_model, current_user, args["message_id"]) except MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class SavedMessageApi(InstalledAppResource): @@ -69,13 +67,21 @@ def delete(self, installed_app, message_id): message_id = str(message_id) - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() SavedMessageService.delete(app_model, current_user, message_id) - return {'result': 'success'} + return {"result": "success"} -api.add_resource(SavedMessageListApi, '/installed-apps//saved-messages', endpoint='installed_app_saved_messages') -api.add_resource(SavedMessageApi, '/installed-apps//saved-messages/', endpoint='installed_app_saved_message') +api.add_resource( + SavedMessageListApi, + "/installed-apps//saved-messages", + endpoint="installed_app_saved_messages", +) +api.add_resource( + SavedMessageApi, + "/installed-apps//saved-messages/", + endpoint="installed_app_saved_message", +) diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 7c5e211d4788bc..45f99b1db9fa9e 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -35,17 +35,13 @@ def post(self, installed_app: InstalledApp): raise NotWorkflowAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') - parser.add_argument('files', type=list, required=False, location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("files", type=list, required=False, location="json") args = parser.parse_args() try: response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.EXPLORE, - streaming=True + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True ) return helper.compact_generate_response(response) @@ -76,10 +72,10 @@ def post(self, installed_app: InstalledApp, task_id: str): AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) - return { - "result": "success" - } + return {"result": "success"} -api.add_resource(InstalledAppWorkflowRunApi, '/installed-apps//workflows/run') -api.add_resource(InstalledAppWorkflowTaskStopApi, '/installed-apps//workflows/tasks//stop') +api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps//workflows/run") +api.add_resource( + InstalledAppWorkflowTaskStopApi, "/installed-apps//workflows/tasks//stop" +) diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 84890f1b46471a..3c9317847b0ab3 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -14,29 +14,33 @@ def installed_app_required(view=None): def decorator(view): @wraps(view) def decorated(*args, **kwargs): - if not kwargs.get('installed_app_id'): - raise ValueError('missing installed_app_id in path parameters') + if not kwargs.get("installed_app_id"): + raise ValueError("missing installed_app_id in path parameters") - installed_app_id = kwargs.get('installed_app_id') + installed_app_id = kwargs.get("installed_app_id") installed_app_id = str(installed_app_id) - del kwargs['installed_app_id'] + del kwargs["installed_app_id"] - installed_app = db.session.query(InstalledApp).filter( - InstalledApp.id == str(installed_app_id), - InstalledApp.tenant_id == current_user.current_tenant_id - ).first() + installed_app = ( + db.session.query(InstalledApp) + .filter( + InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id + ) + .first() + ) if installed_app is None: - raise NotFound('Installed app not found') + raise NotFound("Installed app not found") if not installed_app.app: db.session.delete(installed_app) db.session.commit() - raise NotFound('Installed app not found') + raise NotFound("Installed app not found") return view(installed_app, *args, **kwargs) + return decorated if view: diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index fa73c44c220a06..5d6a8bf1524927 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,6 +1,7 @@ from flask_login import current_user from flask_restful import Resource, marshal_with, reqparse +from constants import HIDDEN_VALUE from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required @@ -12,23 +13,18 @@ class CodeBasedExtensionAPI(Resource): - @setup_required @login_required @account_initialization_required def get(self): parser = reqparse.RequestParser() - parser.add_argument('module', type=str, required=True, location='args') + parser.add_argument("module", type=str, required=True, location="args") args = parser.parse_args() - return { - 'module': args['module'], - 'data': CodeBasedExtensionService.get_code_based_extension(args['module']) - } + return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])} class APIBasedExtensionAPI(Resource): - @setup_required @login_required @account_initialization_required @@ -43,23 +39,22 @@ def get(self): @marshal_with(api_based_extension_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('api_endpoint', type=str, required=True, location='json') - parser.add_argument('api_key', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") + parser.add_argument("api_endpoint", type=str, required=True, location="json") + parser.add_argument("api_key", type=str, required=True, location="json") args = parser.parse_args() extension_data = APIBasedExtension( tenant_id=current_user.current_tenant_id, - name=args['name'], - api_endpoint=args['api_endpoint'], - api_key=args['api_key'] + name=args["name"], + api_endpoint=args["api_endpoint"], + api_key=args["api_key"], ) return APIBasedExtensionService.save(extension_data) class APIBasedExtensionDetailAPI(Resource): - @setup_required @login_required @account_initialization_required @@ -81,16 +76,16 @@ def post(self, id): extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('api_endpoint', type=str, required=True, location='json') - parser.add_argument('api_key', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") + parser.add_argument("api_endpoint", type=str, required=True, location="json") + parser.add_argument("api_key", type=str, required=True, location="json") args = parser.parse_args() - extension_data_from_db.name = args['name'] - extension_data_from_db.api_endpoint = args['api_endpoint'] + extension_data_from_db.name = args["name"] + extension_data_from_db.api_endpoint = args["api_endpoint"] - if args['api_key'] != '[__HIDDEN__]': - extension_data_from_db.api_key = args['api_key'] + if args["api_key"] != HIDDEN_VALUE: + extension_data_from_db.api_key = args["api_key"] return APIBasedExtensionService.save(extension_data_from_db) @@ -105,10 +100,10 @@ def delete(self, id): APIBasedExtensionService.delete(extension_data_from_db) - return {'result': 'success'} + return {"result": "success"} -api.add_resource(CodeBasedExtensionAPI, '/code-based-extension') +api.add_resource(CodeBasedExtensionAPI, "/code-based-extension") -api.add_resource(APIBasedExtensionAPI, '/api-based-extension') -api.add_resource(APIBasedExtensionDetailAPI, '/api-based-extension/') +api.add_resource(APIBasedExtensionAPI, "/api-based-extension") +api.add_resource(APIBasedExtensionDetailAPI, "/api-based-extension/") diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 8475cd848822a7..f0482f749d9127 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -10,7 +10,6 @@ class FeatureApi(Resource): - @setup_required @login_required @account_initialization_required @@ -24,5 +23,5 @@ def get(self): return FeatureService.get_system_features().model_dump() -api.add_resource(FeatureApi, '/features') -api.add_resource(SystemFeatureApi, '/system-features') +api.add_resource(FeatureApi, "/features") +api.add_resource(SystemFeatureApi, "/system-features") diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index 6feb1003a975f6..7d3ae677ee4f41 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -14,12 +14,11 @@ class InitValidateAPI(Resource): - def get(self): init_status = get_init_validate_status() if init_status: - return { 'status': 'finished' } - return {'status': 'not_started' } + return {"status": "finished"} + return {"status": "not_started"} @only_edition_self_hosted def post(self): @@ -29,22 +28,23 @@ def post(self): raise AlreadySetupError() parser = reqparse.RequestParser() - parser.add_argument('password', type=str_len(30), - required=True, location='json') - input_password = parser.parse_args()['password'] + parser.add_argument("password", type=str_len(30), required=True, location="json") + input_password = parser.parse_args()["password"] - if input_password != os.environ.get('INIT_PASSWORD'): - session['is_init_validated'] = False + if input_password != os.environ.get("INIT_PASSWORD"): + session["is_init_validated"] = False raise InitValidateFailedError() - - session['is_init_validated'] = True - return {'result': 'success'}, 201 + + session["is_init_validated"] = True + return {"result": "success"}, 201 + def get_init_validate_status(): - if dify_config.EDITION == 'SELF_HOSTED': - if os.environ.get('INIT_PASSWORD'): - return session.get('is_init_validated') or DifySetup.query.first() - + if dify_config.EDITION == "SELF_HOSTED": + if os.environ.get("INIT_PASSWORD"): + return session.get("is_init_validated") or DifySetup.query.first() + return True -api.add_resource(InitValidateAPI, '/init') + +api.add_resource(InitValidateAPI, "/init") diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py index 7664ba8c165db8..cd28cc946ee288 100644 --- a/api/controllers/console/ping.py +++ b/api/controllers/console/ping.py @@ -4,14 +4,11 @@ class PingApi(Resource): - def get(self): """ For connection health check """ - return { - "result": "pong" - } + return {"result": "pong"} -api.add_resource(PingApi, '/ping') +api.add_resource(PingApi, "/ping") diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index ef7cc6bc03ce3a..827695e00f45f9 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -16,17 +16,13 @@ class SetupApi(Resource): - def get(self): - if dify_config.EDITION == 'SELF_HOSTED': + if dify_config.EDITION == "SELF_HOSTED": setup_status = get_setup_status() if setup_status: - return { - 'step': 'finished', - 'setup_at': setup_status.setup_at.isoformat() - } - return {'step': 'not_started'} - return {'step': 'finished'} + return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()} + return {"step": "not_started"} + return {"step": "finished"} @only_edition_self_hosted def post(self): @@ -38,28 +34,22 @@ def post(self): tenant_count = TenantService.get_tenant_count() if tenant_count > 0: raise AlreadySetupError() - + if not get_init_validate_status(): raise NotInitValidateError() parser = reqparse.RequestParser() - parser.add_argument('email', type=email, - required=True, location='json') - parser.add_argument('name', type=str_len( - 30), required=True, location='json') - parser.add_argument('password', type=valid_password, - required=True, location='json') + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("name", type=str_len(30), required=True, location="json") + parser.add_argument("password", type=valid_password, required=True, location="json") args = parser.parse_args() # setup RegisterService.setup( - email=args['email'], - name=args['name'], - password=args['password'], - ip_address=get_remote_ip(request) + email=args["email"], name=args["name"], password=args["password"], ip_address=get_remote_ip(request) ) - return {'result': 'success'}, 201 + return {"result": "success"}, 201 def setup_required(view): @@ -68,7 +58,7 @@ def decorated(*args, **kwargs): # check setup if not get_init_validate_status(): raise NotInitValidateError() - + elif not get_setup_status(): raise NotSetupError() @@ -78,9 +68,10 @@ def decorated(*args, **kwargs): def get_setup_status(): - if dify_config.EDITION == 'SELF_HOSTED': + if dify_config.EDITION == "SELF_HOSTED": return DifySetup.query.first() else: return True -api.add_resource(SetupApi, '/setup') + +api.add_resource(SetupApi, "/setup") diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 004afaa531a086..7293aeeb34cef9 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -14,19 +14,18 @@ def _validate_name(name): if not name or len(name) < 1 or len(name) > 40: - raise ValueError('Name must be between 1 to 50 characters.') + raise ValueError("Name must be between 1 to 50 characters.") return name class TagListApi(Resource): - @setup_required @login_required @account_initialization_required @marshal_with(tag_fields) def get(self): - tag_type = request.args.get('type', type=str) - keyword = request.args.get('keyword', default=None, type=str) + tag_type = request.args.get("type", type=str) + keyword = request.args.get("keyword", default=None, type=str) tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword) return tags, 200 @@ -40,28 +39,21 @@ def post(self): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, required=True, - help='Name must be between 1 to 50 characters.', - type=_validate_name) - parser.add_argument('type', type=str, location='json', - choices=Tag.TAG_TYPE_LIST, - nullable=True, - help='Invalid tag type.') + parser.add_argument( + "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name + ) + parser.add_argument( + "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." + ) args = parser.parse_args() tag = TagService.save_tags(args) - response = { - 'id': tag.id, - 'name': tag.name, - 'type': tag.type, - 'binding_count': 0 - } + response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} return response, 200 class TagUpdateDeleteApi(Resource): - @setup_required @login_required @account_initialization_required @@ -72,20 +64,15 @@ def patch(self, tag_id): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, required=True, - help='Name must be between 1 to 50 characters.', - type=_validate_name) + parser.add_argument( + "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name + ) args = parser.parse_args() tag = TagService.update_tags(args, tag_id) binding_count = TagService.get_tag_binding_count(tag_id) - response = { - 'id': tag.id, - 'name': tag.name, - 'type': tag.type, - 'binding_count': binding_count - } + response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} return response, 200 @@ -104,7 +91,6 @@ def delete(self, tag_id): class TagBindingCreateApi(Resource): - @setup_required @login_required @account_initialization_required @@ -114,14 +100,15 @@ def post(self): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('tag_ids', type=list, nullable=False, required=True, location='json', - help='Tag IDs is required.') - parser.add_argument('target_id', type=str, nullable=False, required=True, location='json', - help='Target ID is required.') - parser.add_argument('type', type=str, location='json', - choices=Tag.TAG_TYPE_LIST, - nullable=True, - help='Invalid tag type.') + parser.add_argument( + "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required." + ) + parser.add_argument( + "target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required." + ) + parser.add_argument( + "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." + ) args = parser.parse_args() TagService.save_tag_binding(args) @@ -129,7 +116,6 @@ def post(self): class TagBindingDeleteApi(Resource): - @setup_required @login_required @account_initialization_required @@ -139,21 +125,18 @@ def post(self): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('tag_id', type=str, nullable=False, required=True, - help='Tag ID is required.') - parser.add_argument('target_id', type=str, nullable=False, required=True, - help='Target ID is required.') - parser.add_argument('type', type=str, location='json', - choices=Tag.TAG_TYPE_LIST, - nullable=True, - help='Invalid tag type.') + parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") + parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") + parser.add_argument( + "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." + ) args = parser.parse_args() TagService.delete_tag_binding(args) return 200 -api.add_resource(TagListApi, '/tags') -api.add_resource(TagUpdateDeleteApi, '/tags/') -api.add_resource(TagBindingCreateApi, '/tag-bindings/create') -api.add_resource(TagBindingDeleteApi, '/tag-bindings/remove') +api.add_resource(TagListApi, "/tags") +api.add_resource(TagUpdateDeleteApi, "/tags/") +api.add_resource(TagBindingCreateApi, "/tag-bindings/create") +api.add_resource(TagBindingDeleteApi, "/tag-bindings/remove") diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 1fcf4bdc00e5b1..76adbfe6a9d6cf 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -1,4 +1,3 @@ - import json import logging @@ -11,42 +10,39 @@ class VersionApi(Resource): - def get(self): parser = reqparse.RequestParser() - parser.add_argument('current_version', type=str, required=True, location='args') + parser.add_argument("current_version", type=str, required=True, location="args") args = parser.parse_args() check_update_url = dify_config.CHECK_UPDATE_URL result = { - 'version': dify_config.CURRENT_VERSION, - 'release_date': '', - 'release_notes': '', - 'can_auto_update': False, - 'features': { - 'can_replace_logo': dify_config.CAN_REPLACE_LOGO, - 'model_load_balancing_enabled': dify_config.MODEL_LB_ENABLED - } + "version": dify_config.CURRENT_VERSION, + "release_date": "", + "release_notes": "", + "can_auto_update": False, + "features": { + "can_replace_logo": dify_config.CAN_REPLACE_LOGO, + "model_load_balancing_enabled": dify_config.MODEL_LB_ENABLED, + }, } if not check_update_url: return result try: - response = requests.get(check_update_url, { - 'current_version': args.get('current_version') - }) + response = requests.get(check_update_url, {"current_version": args.get("current_version")}) except Exception as error: logging.warning("Check update version error: {}.".format(str(error))) - result['version'] = args.get('current_version') + result["version"] = args.get("current_version") return result content = json.loads(response.content) - result['version'] = content['version'] - result['release_date'] = content['releaseDate'] - result['release_notes'] = content['releaseNotes'] - result['can_auto_update'] = content['canAutoUpdate'] + result["version"] = content["version"] + result["release_date"] = content["releaseDate"] + result["release_notes"] = content["releaseNotes"] + result["can_auto_update"] = content["canAutoUpdate"] return result -api.add_resource(VersionApi, '/version') +api.add_resource(VersionApi, "/version") diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 1056d5eb62faf2..dec426128f5135 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -26,52 +26,53 @@ class AccountInitApi(Resource): - @setup_required @login_required def post(self): account = current_user - if account.status == 'active': + if account.status == "active": raise AccountAlreadyInitedError() parser = reqparse.RequestParser() - if dify_config.EDITION == 'CLOUD': - parser.add_argument('invitation_code', type=str, location='json') + if dify_config.EDITION == "CLOUD": + parser.add_argument("invitation_code", type=str, location="json") - parser.add_argument( - 'interface_language', type=supported_language, required=True, location='json') - parser.add_argument('timezone', type=timezone, - required=True, location='json') + parser.add_argument("interface_language", type=supported_language, required=True, location="json") + parser.add_argument("timezone", type=timezone, required=True, location="json") args = parser.parse_args() - if dify_config.EDITION == 'CLOUD': - if not args['invitation_code']: - raise ValueError('invitation_code is required') + if dify_config.EDITION == "CLOUD": + if not args["invitation_code"]: + raise ValueError("invitation_code is required") # check invitation code - invitation_code = db.session.query(InvitationCode).filter( - InvitationCode.code == args['invitation_code'], - InvitationCode.status == 'unused', - ).first() + invitation_code = ( + db.session.query(InvitationCode) + .filter( + InvitationCode.code == args["invitation_code"], + InvitationCode.status == "unused", + ) + .first() + ) if not invitation_code: raise InvalidInvitationCodeError() - invitation_code.status = 'used' + invitation_code.status = "used" invitation_code.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) invitation_code.used_by_tenant_id = account.current_tenant_id invitation_code.used_by_account_id = account.id - account.interface_language = args['interface_language'] - account.timezone = args['timezone'] - account.interface_theme = 'light' - account.status = 'active' + account.interface_language = args["interface_language"] + account.timezone = args["timezone"] + account.interface_theme = "light" + account.status = "active" account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() - return {'result': 'success'} + return {"result": "success"} class AccountProfileApi(Resource): @@ -90,15 +91,14 @@ class AccountNameApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() # Validate account name length - if len(args['name']) < 3 or len(args['name']) > 30: - raise ValueError( - "Account name must be between 3 and 30 characters.") + if len(args["name"]) < 3 or len(args["name"]) > 30: + raise ValueError("Account name must be between 3 and 30 characters.") - updated_account = AccountService.update_account(current_user, name=args['name']) + updated_account = AccountService.update_account(current_user, name=args["name"]) return updated_account @@ -110,10 +110,10 @@ class AccountAvatarApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('avatar', type=str, required=True, location='json') + parser.add_argument("avatar", type=str, required=True, location="json") args = parser.parse_args() - updated_account = AccountService.update_account(current_user, avatar=args['avatar']) + updated_account = AccountService.update_account(current_user, avatar=args["avatar"]) return updated_account @@ -125,11 +125,10 @@ class AccountInterfaceLanguageApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument( - 'interface_language', type=supported_language, required=True, location='json') + parser.add_argument("interface_language", type=supported_language, required=True, location="json") args = parser.parse_args() - updated_account = AccountService.update_account(current_user, interface_language=args['interface_language']) + updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"]) return updated_account @@ -141,11 +140,10 @@ class AccountInterfaceThemeApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('interface_theme', type=str, choices=[ - 'light', 'dark'], required=True, location='json') + parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json") args = parser.parse_args() - updated_account = AccountService.update_account(current_user, interface_theme=args['interface_theme']) + updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"]) return updated_account @@ -157,15 +155,14 @@ class AccountTimezoneApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('timezone', type=str, - required=True, location='json') + parser.add_argument("timezone", type=str, required=True, location="json") args = parser.parse_args() # Validate timezone string, e.g. America/New_York, Asia/Shanghai - if args['timezone'] not in pytz.all_timezones: + if args["timezone"] not in pytz.all_timezones: raise ValueError("Invalid timezone string.") - updated_account = AccountService.update_account(current_user, timezone=args['timezone']) + updated_account = AccountService.update_account(current_user, timezone=args["timezone"]) return updated_account @@ -177,20 +174,16 @@ class AccountPasswordApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('password', type=str, - required=False, location='json') - parser.add_argument('new_password', type=str, - required=True, location='json') - parser.add_argument('repeat_new_password', type=str, - required=True, location='json') + parser.add_argument("password", type=str, required=False, location="json") + parser.add_argument("new_password", type=str, required=True, location="json") + parser.add_argument("repeat_new_password", type=str, required=True, location="json") args = parser.parse_args() - if args['new_password'] != args['repeat_new_password']: + if args["new_password"] != args["repeat_new_password"]: raise RepeatPasswordNotMatchError() try: - AccountService.update_account_password( - current_user, args['password'], args['new_password']) + AccountService.update_account_password(current_user, args["password"], args["new_password"]) except ServiceCurrentPasswordIncorrectError: raise CurrentPasswordIncorrectError() @@ -199,14 +192,14 @@ def post(self): class AccountIntegrateApi(Resource): integrate_fields = { - 'provider': fields.String, - 'created_at': TimestampField, - 'is_bound': fields.Boolean, - 'link': fields.String + "provider": fields.String, + "created_at": TimestampField, + "is_bound": fields.Boolean, + "link": fields.String, } integrate_list_fields = { - 'data': fields.List(fields.Nested(integrate_fields)), + "data": fields.List(fields.Nested(integrate_fields)), } @setup_required @@ -216,10 +209,9 @@ class AccountIntegrateApi(Resource): def get(self): account = current_user - account_integrates = db.session.query(AccountIntegrate).filter( - AccountIntegrate.account_id == account.id).all() + account_integrates = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all() - base_url = request.url_root.rstrip('/') + base_url = request.url_root.rstrip("/") oauth_base_path = "/console/api/oauth/login" providers = ["github", "google"] @@ -227,36 +219,38 @@ def get(self): for provider in providers: existing_integrate = next((ai for ai in account_integrates if ai.provider == provider), None) if existing_integrate: - integrate_data.append({ - 'id': existing_integrate.id, - 'provider': provider, - 'created_at': existing_integrate.created_at, - 'is_bound': True, - 'link': None - }) + integrate_data.append( + { + "id": existing_integrate.id, + "provider": provider, + "created_at": existing_integrate.created_at, + "is_bound": True, + "link": None, + } + ) else: - integrate_data.append({ - 'id': None, - 'provider': provider, - 'created_at': None, - 'is_bound': False, - 'link': f'{base_url}{oauth_base_path}/{provider}' - }) - - return {'data': integrate_data} - + integrate_data.append( + { + "id": None, + "provider": provider, + "created_at": None, + "is_bound": False, + "link": f"{base_url}{oauth_base_path}/{provider}", + } + ) + return {"data": integrate_data} # Register API resources -api.add_resource(AccountInitApi, '/account/init') -api.add_resource(AccountProfileApi, '/account/profile') -api.add_resource(AccountNameApi, '/account/name') -api.add_resource(AccountAvatarApi, '/account/avatar') -api.add_resource(AccountInterfaceLanguageApi, '/account/interface-language') -api.add_resource(AccountInterfaceThemeApi, '/account/interface-theme') -api.add_resource(AccountTimezoneApi, '/account/timezone') -api.add_resource(AccountPasswordApi, '/account/password') -api.add_resource(AccountIntegrateApi, '/account/integrates') +api.add_resource(AccountInitApi, "/account/init") +api.add_resource(AccountProfileApi, "/account/profile") +api.add_resource(AccountNameApi, "/account/name") +api.add_resource(AccountAvatarApi, "/account/avatar") +api.add_resource(AccountInterfaceLanguageApi, "/account/interface-language") +api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme") +api.add_resource(AccountTimezoneApi, "/account/timezone") +api.add_resource(AccountPasswordApi, "/account/password") +api.add_resource(AccountIntegrateApi, "/account/integrates") # api.add_resource(AccountEmailApi, '/account/email') # api.add_resource(AccountEmailVerifyApi, '/account/email-verify') diff --git a/api/controllers/console/workspace/error.py b/api/controllers/console/workspace/error.py index 99f55835bc6b97..9e13c7b9241ff1 100644 --- a/api/controllers/console/workspace/error.py +++ b/api/controllers/console/workspace/error.py @@ -2,36 +2,36 @@ class RepeatPasswordNotMatchError(BaseHTTPException): - error_code = 'repeat_password_not_match' + error_code = "repeat_password_not_match" description = "New password and repeat password does not match." code = 400 class CurrentPasswordIncorrectError(BaseHTTPException): - error_code = 'current_password_incorrect' + error_code = "current_password_incorrect" description = "Current password is incorrect." code = 400 class ProviderRequestFailedError(BaseHTTPException): - error_code = 'provider_request_failed' + error_code = "provider_request_failed" description = None code = 400 class InvalidInvitationCodeError(BaseHTTPException): - error_code = 'invalid_invitation_code' + error_code = "invalid_invitation_code" description = "Invalid invitation code." code = 400 class AccountAlreadyInitedError(BaseHTTPException): - error_code = 'account_already_inited' + error_code = "account_already_inited" description = "The account has been initialized. Please refresh the page." code = 400 class AccountNotInitializedError(BaseHTTPException): - error_code = 'account_not_initialized' + error_code = "account_not_initialized" description = "The account has not been initialized yet. Please proceed with the initialization process first." code = 400 diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 50514e39f6aad1..771a866624744d 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -22,10 +22,16 @@ def post(self, provider: str): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() # validate model load balancing credentials @@ -38,18 +44,18 @@ def post(self, provider: str): model_load_balancing_service.validate_load_balancing_credentials( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - credentials=args['credentials'] + model=args["model"], + model_type=args["model_type"], + credentials=args["credentials"], ) except CredentialsValidateFailedError as ex: result = False error = str(ex) - response = {'result': 'success' if result else 'error'} + response = {"result": "success" if result else "error"} if not result: - response['error'] = error + response["error"] = error return response @@ -65,10 +71,16 @@ def post(self, provider: str, config_id: str): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() # validate model load balancing config credentials @@ -81,26 +93,30 @@ def post(self, provider: str, config_id: str): model_load_balancing_service.validate_load_balancing_credentials( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - credentials=args['credentials'], + model=args["model"], + model_type=args["model_type"], + credentials=args["credentials"], config_id=config_id, ) except CredentialsValidateFailedError as ex: result = False error = str(ex) - response = {'result': 'success' if result else 'error'} + response = {"result": "success" if result else "error"} if not result: - response['error'] = error + response["error"] = error return response # Load Balancing Config -api.add_resource(LoadBalancingCredentialsValidateApi, - '/workspaces/current/model-providers//models/load-balancing-configs/credentials-validate') - -api.add_resource(LoadBalancingConfigCredentialsValidateApi, - '/workspaces/current/model-providers//models/load-balancing-configs//credentials-validate') +api.add_resource( + LoadBalancingCredentialsValidateApi, + "/workspaces/current/model-providers//models/load-balancing-configs/credentials-validate", +) + +api.add_resource( + LoadBalancingConfigCredentialsValidateApi, + "/workspaces/current/model-providers//models/load-balancing-configs//credentials-validate", +) diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 34e9da384106a7..3e87bebf591505 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -23,7 +23,7 @@ class MemberListApi(Resource): @marshal_with(account_with_role_list_fields) def get(self): members = TenantService.get_tenant_members(current_user.current_tenant) - return {'result': 'success', 'accounts': members}, 200 + return {"result": "success", "accounts": members}, 200 class MemberInviteEmailApi(Resource): @@ -32,48 +32,46 @@ class MemberInviteEmailApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('members') + @cloud_edition_billing_resource_check("members") def post(self): parser = reqparse.RequestParser() - parser.add_argument('emails', type=str, required=True, location='json', action='append') - parser.add_argument('role', type=str, required=True, default='admin', location='json') - parser.add_argument('language', type=str, required=False, location='json') + parser.add_argument("emails", type=str, required=True, location="json", action="append") + parser.add_argument("role", type=str, required=True, default="admin", location="json") + parser.add_argument("language", type=str, required=False, location="json") args = parser.parse_args() - invitee_emails = args['emails'] - invitee_role = args['role'] - interface_language = args['language'] + invitee_emails = args["emails"] + invitee_role = args["role"] + interface_language = args["language"] if not TenantAccountRole.is_non_owner_role(invitee_role): - return {'code': 'invalid-role', 'message': 'Invalid role'}, 400 + return {"code": "invalid-role", "message": "Invalid role"}, 400 inviter = current_user invitation_results = [] console_web_url = dify_config.CONSOLE_WEB_URL for invitee_email in invitee_emails: try: - token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter) - invitation_results.append({ - 'status': 'success', - 'email': invitee_email, - 'url': f'{console_web_url}/activate?email={invitee_email}&token={token}' - }) + token = RegisterService.invite_new_member( + inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter + ) + invitation_results.append( + { + "status": "success", + "email": invitee_email, + "url": f"{console_web_url}/activate?email={invitee_email}&token={token}", + } + ) except AccountAlreadyInTenantError: - invitation_results.append({ - 'status': 'success', - 'email': invitee_email, - 'url': f'{console_web_url}/signin' - }) + invitation_results.append( + {"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"} + ) break except Exception as e: - invitation_results.append({ - 'status': 'failed', - 'email': invitee_email, - 'message': str(e) - }) + invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)}) return { - 'result': 'success', - 'invitation_results': invitation_results, + "result": "success", + "invitation_results": invitation_results, }, 201 @@ -91,15 +89,15 @@ def delete(self, member_id): try: TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user) except services.errors.account.CannotOperateSelfError as e: - return {'code': 'cannot-operate-self', 'message': str(e)}, 400 + return {"code": "cannot-operate-self", "message": str(e)}, 400 except services.errors.account.NoPermissionError as e: - return {'code': 'forbidden', 'message': str(e)}, 403 + return {"code": "forbidden", "message": str(e)}, 403 except services.errors.account.MemberNotInTenantError as e: - return {'code': 'member-not-found', 'message': str(e)}, 404 + return {"code": "member-not-found", "message": str(e)}, 404 except Exception as e: raise ValueError(str(e)) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class MemberUpdateRoleApi(Resource): @@ -110,12 +108,12 @@ class MemberUpdateRoleApi(Resource): @account_initialization_required def put(self, member_id): parser = reqparse.RequestParser() - parser.add_argument('role', type=str, required=True, location='json') + parser.add_argument("role", type=str, required=True, location="json") args = parser.parse_args() - new_role = args['role'] + new_role = args["role"] if not TenantAccountRole.is_valid_role(new_role): - return {'code': 'invalid-role', 'message': 'Invalid role'}, 400 + return {"code": "invalid-role", "message": "Invalid role"}, 400 member = db.session.get(Account, str(member_id)) if not member: @@ -128,7 +126,7 @@ def put(self, member_id): # todo: 403 - return {'result': 'success'} + return {"result": "success"} class DatasetOperatorMemberListApi(Resource): @@ -140,11 +138,11 @@ class DatasetOperatorMemberListApi(Resource): @marshal_with(account_with_role_list_fields) def get(self): members = TenantService.get_dataset_operator_members(current_user.current_tenant) - return {'result': 'success', 'accounts': members}, 200 + return {"result": "success", "accounts": members}, 200 -api.add_resource(MemberListApi, '/workspaces/current/members') -api.add_resource(MemberInviteEmailApi, '/workspaces/current/members/invite-email') -api.add_resource(MemberCancelInviteApi, '/workspaces/current/members/') -api.add_resource(MemberUpdateRoleApi, '/workspaces/current/members//update-role') -api.add_resource(DatasetOperatorMemberListApi, '/workspaces/current/dataset-operators') +api.add_resource(MemberListApi, "/workspaces/current/members") +api.add_resource(MemberInviteEmailApi, "/workspaces/current/members/invite-email") +api.add_resource(MemberCancelInviteApi, "/workspaces/current/members/") +api.add_resource(MemberUpdateRoleApi, "/workspaces/current/members//update-role") +api.add_resource(DatasetOperatorMemberListApi, "/workspaces/current/dataset-operators") diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index c888159f839719..8c384202260a00 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -17,7 +17,6 @@ class ModelProviderListApi(Resource): - @setup_required @login_required @account_initialization_required @@ -25,21 +24,23 @@ def get(self): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model_type', type=str, required=False, nullable=True, - choices=[mt.value for mt in ModelType], location='args') + parser.add_argument( + "model_type", + type=str, + required=False, + nullable=True, + choices=[mt.value for mt in ModelType], + location="args", + ) args = parser.parse_args() model_provider_service = ModelProviderService() - provider_list = model_provider_service.get_provider_list( - tenant_id=tenant_id, - model_type=args.get('model_type') - ) + provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type")) return jsonable_encoder({"data": provider_list}) class ModelProviderCredentialApi(Resource): - @setup_required @login_required @account_initialization_required @@ -47,25 +48,18 @@ def get(self, provider: str): tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() - credentials = model_provider_service.get_provider_credentials( - tenant_id=tenant_id, - provider=provider - ) + credentials = model_provider_service.get_provider_credentials(tenant_id=tenant_id, provider=provider) - return { - "credentials": credentials - } + return {"credentials": credentials} class ModelProviderValidateApi(Resource): - @setup_required @login_required @account_initialization_required def post(self, provider: str): - parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() tenant_id = current_user.current_tenant_id @@ -77,24 +71,21 @@ def post(self, provider: str): try: model_provider_service.provider_credentials_validate( - tenant_id=tenant_id, - provider=provider, - credentials=args['credentials'] + tenant_id=tenant_id, provider=provider, credentials=args["credentials"] ) except CredentialsValidateFailedError as ex: result = False error = str(ex) - response = {'result': 'success' if result else 'error'} + response = {"result": "success" if result else "error"} if not result: - response['error'] = error + response["error"] = error return response class ModelProviderApi(Resource): - @setup_required @login_required @account_initialization_required @@ -103,21 +94,19 @@ def post(self, provider: str): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() model_provider_service = ModelProviderService() try: model_provider_service.save_provider_credentials( - tenant_id=current_user.current_tenant_id, - provider=provider, - credentials=args['credentials'] + tenant_id=current_user.current_tenant_id, provider=provider, credentials=args["credentials"] ) except CredentialsValidateFailedError as ex: raise ValueError(str(ex)) - return {'result': 'success'}, 201 + return {"result": "success"}, 201 @setup_required @login_required @@ -127,12 +116,9 @@ def delete(self, provider: str): raise Forbidden() model_provider_service = ModelProviderService() - model_provider_service.remove_provider_credentials( - tenant_id=current_user.current_tenant_id, - provider=provider - ) + model_provider_service.remove_provider_credentials(tenant_id=current_user.current_tenant_id, provider=provider) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class ModelProviderIconApi(Resource): @@ -146,16 +132,13 @@ class ModelProviderIconApi(Resource): def get(self, provider: str, icon_type: str, lang: str): model_provider_service = ModelProviderService() icon, mimetype = model_provider_service.get_model_provider_icon( - provider=provider, - icon_type=icon_type, - lang=lang + provider=provider, icon_type=icon_type, lang=lang ) return send_file(io.BytesIO(icon), mimetype=mimetype) class PreferredProviderTypeUpdateApi(Resource): - @setup_required @login_required @account_initialization_required @@ -166,18 +149,22 @@ def post(self, provider: str): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False, - choices=['system', 'custom'], location='json') + parser.add_argument( + "preferred_provider_type", + type=str, + required=True, + nullable=False, + choices=["system", "custom"], + location="json", + ) args = parser.parse_args() model_provider_service = ModelProviderService() model_provider_service.switch_preferred_provider( - tenant_id=tenant_id, - provider=provider, - preferred_provider_type=args['preferred_provider_type'] + tenant_id=tenant_id, provider=provider, preferred_provider_type=args["preferred_provider_type"] ) - return {'result': 'success'} + return {"result": "success"} class ModelProviderPaymentCheckoutUrlApi(Resource): @@ -185,13 +172,15 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): @login_required @account_initialization_required def get(self, provider: str): - if provider != 'anthropic': - raise ValueError(f'provider name {provider} is invalid') + if provider != "anthropic": + raise ValueError(f"provider name {provider} is invalid") BillingService.is_tenant_owner_or_admin(current_user) - data = BillingService.get_model_provider_payment_link(provider_name=provider, - tenant_id=current_user.current_tenant_id, - account_id=current_user.id, - prefilled_email=current_user.email) + data = BillingService.get_model_provider_payment_link( + provider_name=provider, + tenant_id=current_user.current_tenant_id, + account_id=current_user.id, + prefilled_email=current_user.email, + ) return data @@ -201,10 +190,7 @@ class ModelProviderFreeQuotaSubmitApi(Resource): @account_initialization_required def post(self, provider: str): model_provider_service = ModelProviderService() - result = model_provider_service.free_quota_submit( - tenant_id=current_user.current_tenant_id, - provider=provider - ) + result = model_provider_service.free_quota_submit(tenant_id=current_user.current_tenant_id, provider=provider) return result @@ -215,32 +201,36 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource): @account_initialization_required def get(self, provider: str): parser = reqparse.RequestParser() - parser.add_argument('token', type=str, required=False, nullable=True, location='args') + parser.add_argument("token", type=str, required=False, nullable=True, location="args") args = parser.parse_args() model_provider_service = ModelProviderService() result = model_provider_service.free_quota_qualification_verify( - tenant_id=current_user.current_tenant_id, - provider=provider, - token=args['token'] + tenant_id=current_user.current_tenant_id, provider=provider, token=args["token"] ) return result -api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers') - -api.add_resource(ModelProviderCredentialApi, '/workspaces/current/model-providers//credentials') -api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers//credentials/validate') -api.add_resource(ModelProviderApi, '/workspaces/current/model-providers/') -api.add_resource(ModelProviderIconApi, '/workspaces/current/model-providers//' - '/') - -api.add_resource(PreferredProviderTypeUpdateApi, - '/workspaces/current/model-providers//preferred-provider-type') -api.add_resource(ModelProviderPaymentCheckoutUrlApi, - '/workspaces/current/model-providers//checkout-url') -api.add_resource(ModelProviderFreeQuotaSubmitApi, - '/workspaces/current/model-providers//free-quota-submit') -api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi, - '/workspaces/current/model-providers//free-quota-qualification-verify') +api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers") + +api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers//credentials") +api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers//credentials/validate") +api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/") +api.add_resource( + ModelProviderIconApi, "/workspaces/current/model-providers//" "/" +) + +api.add_resource( + PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers//preferred-provider-type" +) +api.add_resource( + ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers//checkout-url" +) +api.add_resource( + ModelProviderFreeQuotaSubmitApi, "/workspaces/current/model-providers//free-quota-submit" +) +api.add_resource( + ModelProviderFreeQuotaQualificationVerifyApi, + "/workspaces/current/model-providers//free-quota-qualification-verify", +) diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 69f2253e97ea40..dc88f6b8127fbc 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -16,27 +16,29 @@ class DefaultModelApi(Resource): - @setup_required @login_required @account_initialization_required def get(self): parser = reqparse.RequestParser() - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='args') + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="args", + ) args = parser.parse_args() tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() default_model_entity = model_provider_service.get_default_model_of_model_type( - tenant_id=tenant_id, - model_type=args['model_type'] + tenant_id=tenant_id, model_type=args["model_type"] ) - return jsonable_encoder({ - "data": default_model_entity - }) + return jsonable_encoder({"data": default_model_entity}) @setup_required @login_required @@ -44,40 +46,39 @@ def get(self): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json') + parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json") args = parser.parse_args() tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() - model_settings = args['model_settings'] + model_settings = args["model_settings"] for model_setting in model_settings: - if 'model_type' not in model_setting or model_setting['model_type'] not in [mt.value for mt in ModelType]: - raise ValueError('invalid model type') + if "model_type" not in model_setting or model_setting["model_type"] not in [mt.value for mt in ModelType]: + raise ValueError("invalid model type") - if 'provider' not in model_setting: + if "provider" not in model_setting: continue - if 'model' not in model_setting: - raise ValueError('invalid model') + if "model" not in model_setting: + raise ValueError("invalid model") try: model_provider_service.update_default_model_of_model_type( tenant_id=tenant_id, - model_type=model_setting['model_type'], - provider=model_setting['provider'], - model=model_setting['model'] + model_type=model_setting["model_type"], + provider=model_setting["provider"], + model=model_setting["model"], ) except Exception: logging.warning(f"{model_setting['model_type']} save error") - return {'result': 'success'} + return {"result": "success"} class ModelProviderModelApi(Resource): - @setup_required @login_required @account_initialization_required @@ -85,14 +86,9 @@ def get(self, provider): tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() - models = model_provider_service.get_models_by_provider( - tenant_id=tenant_id, - provider=provider - ) + models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider) - return jsonable_encoder({ - "data": models - }) + return jsonable_encoder({"data": models}) @setup_required @login_required @@ -104,61 +100,66 @@ def post(self, provider: str): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') - parser.add_argument('credentials', type=dict, required=False, nullable=True, location='json') - parser.add_argument('load_balancing', type=dict, required=False, nullable=True, location='json') - parser.add_argument('config_from', type=str, required=False, nullable=True, location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") + parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json") + parser.add_argument("config_from", type=str, required=False, nullable=True, location="json") args = parser.parse_args() model_load_balancing_service = ModelLoadBalancingService() - if ('load_balancing' in args and args['load_balancing'] and - 'enabled' in args['load_balancing'] and args['load_balancing']['enabled']): - if 'configs' not in args['load_balancing']: - raise ValueError('invalid load balancing configs') + if ( + "load_balancing" in args + and args["load_balancing"] + and "enabled" in args["load_balancing"] + and args["load_balancing"]["enabled"] + ): + if "configs" not in args["load_balancing"]: + raise ValueError("invalid load balancing configs") # save load balancing configs model_load_balancing_service.update_load_balancing_configs( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - configs=args['load_balancing']['configs'] + model=args["model"], + model_type=args["model_type"], + configs=args["load_balancing"]["configs"], ) # enable load balancing model_load_balancing_service.enable_model_load_balancing( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) else: # disable load balancing model_load_balancing_service.disable_model_load_balancing( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - if args.get('config_from', '') != 'predefined-model': + if args.get("config_from", "") != "predefined-model": model_provider_service = ModelProviderService() try: model_provider_service.save_model_credentials( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - credentials=args['credentials'] + model=args["model"], + model_type=args["model_type"], + credentials=args["credentials"], ) except CredentialsValidateFailedError as ex: + logging.exception(f"save model credentials error: {ex}") raise ValueError(str(ex)) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 @setup_required @login_required @@ -170,24 +171,26 @@ def delete(self, provider: str): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) args = parser.parse_args() model_provider_service = ModelProviderService() model_provider_service.remove_model_credentials( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class ModelProviderModelCredentialApi(Resource): - @setup_required @login_required @account_initialization_required @@ -195,38 +198,34 @@ def get(self, provider: str): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='args') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='args') + parser.add_argument("model", type=str, required=True, nullable=False, location="args") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="args", + ) args = parser.parse_args() model_provider_service = ModelProviderService() credentials = model_provider_service.get_model_credentials( - tenant_id=tenant_id, - provider=provider, - model_type=args['model_type'], - model=args['model'] + tenant_id=tenant_id, provider=provider, model_type=args["model_type"], model=args["model"] ) model_load_balancing_service = ModelLoadBalancingService() is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) return { "credentials": credentials, - "load_balancing": { - "enabled": is_load_balancing_enabled, - "configs": load_balancing_configs - } + "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs}, } class ModelProviderModelEnableApi(Resource): - @setup_required @login_required @account_initialization_required @@ -234,24 +233,26 @@ def patch(self, provider: str): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) args = parser.parse_args() model_provider_service = ModelProviderService() model_provider_service.enable_model( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - return {'result': 'success'} + return {"result": "success"} class ModelProviderModelDisableApi(Resource): - @setup_required @login_required @account_initialization_required @@ -259,24 +260,26 @@ def patch(self, provider: str): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) args = parser.parse_args() model_provider_service = ModelProviderService() model_provider_service.disable_model( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - return {'result': 'success'} + return {"result": "success"} class ModelProviderModelValidateApi(Resource): - @setup_required @login_required @account_initialization_required @@ -284,10 +287,16 @@ def post(self, provider: str): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() model_provider_service = ModelProviderService() @@ -299,48 +308,42 @@ def post(self, provider: str): model_provider_service.model_credentials_validate( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - credentials=args['credentials'] + model=args["model"], + model_type=args["model_type"], + credentials=args["credentials"], ) except CredentialsValidateFailedError as ex: result = False error = str(ex) - response = {'result': 'success' if result else 'error'} + response = {"result": "success" if result else "error"} if not result: - response['error'] = error + response["error"] = error return response class ModelProviderModelParameterRuleApi(Resource): - @setup_required @login_required @account_initialization_required def get(self, provider: str): parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='args') + parser.add_argument("model", type=str, required=True, nullable=False, location="args") args = parser.parse_args() tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() parameter_rules = model_provider_service.get_model_parameter_rules( - tenant_id=tenant_id, - provider=provider, - model=args['model'] + tenant_id=tenant_id, provider=provider, model=args["model"] ) - return jsonable_encoder({ - "data": parameter_rules - }) + return jsonable_encoder({"data": parameter_rules}) class ModelProviderAvailableModelApi(Resource): - @setup_required @login_required @account_initialization_required @@ -348,27 +351,31 @@ def get(self, model_type): tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() - models = model_provider_service.get_models_by_model_type( - tenant_id=tenant_id, - model_type=model_type - ) - - return jsonable_encoder({ - "data": models - }) - - -api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers//models') -api.add_resource(ModelProviderModelEnableApi, '/workspaces/current/model-providers//models/enable', - endpoint='model-provider-model-enable') -api.add_resource(ModelProviderModelDisableApi, '/workspaces/current/model-providers//models/disable', - endpoint='model-provider-model-disable') -api.add_resource(ModelProviderModelCredentialApi, - '/workspaces/current/model-providers//models/credentials') -api.add_resource(ModelProviderModelValidateApi, - '/workspaces/current/model-providers//models/credentials/validate') - -api.add_resource(ModelProviderModelParameterRuleApi, - '/workspaces/current/model-providers//models/parameter-rules') -api.add_resource(ModelProviderAvailableModelApi, '/workspaces/current/models/model-types/') -api.add_resource(DefaultModelApi, '/workspaces/current/default-model') + models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type) + + return jsonable_encoder({"data": models}) + + +api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers//models") +api.add_resource( + ModelProviderModelEnableApi, + "/workspaces/current/model-providers//models/enable", + endpoint="model-provider-model-enable", +) +api.add_resource( + ModelProviderModelDisableApi, + "/workspaces/current/model-providers//models/disable", + endpoint="model-provider-model-disable", +) +api.add_resource( + ModelProviderModelCredentialApi, "/workspaces/current/model-providers//models/credentials" +) +api.add_resource( + ModelProviderModelValidateApi, "/workspaces/current/model-providers//models/credentials/validate" +) + +api.add_resource( + ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers//models/parameter-rules" +) +api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/") +api.add_resource(DefaultModelApi, "/workspaces/current/default-model") diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index bafeabb08ae2c7..c41a898fdca27e 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -28,10 +28,18 @@ def get(self): tenant_id = current_user.current_tenant_id req = reqparse.RequestParser() - req.add_argument('type', type=str, choices=['builtin', 'model', 'api', 'workflow'], required=False, nullable=True, location='args') + req.add_argument( + "type", + type=str, + choices=["builtin", "model", "api", "workflow"], + required=False, + nullable=True, + location="args", + ) args = req.parse_args() - return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get('type', None)) + return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None)) + class ToolBuiltinProviderListToolsApi(Resource): @setup_required @@ -41,11 +49,14 @@ def get(self, provider): user_id = current_user.id tenant_id = current_user.current_tenant_id - return jsonable_encoder(BuiltinToolManageService.list_builtin_tool_provider_tools( - user_id, - tenant_id, - provider, - )) + return jsonable_encoder( + BuiltinToolManageService.list_builtin_tool_provider_tools( + user_id, + tenant_id, + provider, + ) + ) + class ToolBuiltinProviderDeleteApi(Resource): @setup_required @@ -54,7 +65,7 @@ class ToolBuiltinProviderDeleteApi(Resource): def post(self, provider): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id @@ -63,7 +74,8 @@ def post(self, provider): tenant_id, provider, ) - + + class ToolBuiltinProviderUpdateApi(Resource): @setup_required @login_required @@ -71,12 +83,12 @@ class ToolBuiltinProviderUpdateApi(Resource): def post(self, provider): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() @@ -84,9 +96,10 @@ def post(self, provider): user_id, tenant_id, provider, - args['credentials'], + args["credentials"], ) - + + class ToolBuiltinProviderGetCredentialsApi(Resource): @setup_required @login_required @@ -101,6 +114,7 @@ def get(self, provider): provider, ) + class ToolBuiltinProviderIconApi(Resource): @setup_required def get(self, provider): @@ -108,6 +122,7 @@ def get(self, provider): icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) + class ToolApiProviderAddApi(Resource): @setup_required @login_required @@ -115,35 +130,36 @@ class ToolApiProviderAddApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') - parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json') - parser.add_argument('schema', type=str, required=True, nullable=False, location='json') - parser.add_argument('provider', type=str, required=True, nullable=False, location='json') - parser.add_argument('icon', type=dict, required=True, nullable=False, location='json') - parser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json') - parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json', default=[]) - parser.add_argument('custom_disclaimer', type=str, required=False, nullable=True, location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") + parser.add_argument("schema", type=str, required=True, nullable=False, location="json") + parser.add_argument("provider", type=str, required=True, nullable=False, location="json") + parser.add_argument("icon", type=dict, required=True, nullable=False, location="json") + parser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json") + parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[]) + parser.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json") args = parser.parse_args() return ApiToolManageService.create_api_tool_provider( user_id, tenant_id, - args['provider'], - args['icon'], - args['credentials'], - args['schema_type'], - args['schema'], - args.get('privacy_policy', ''), - args.get('custom_disclaimer', ''), - args.get('labels', []), + args["provider"], + args["icon"], + args["credentials"], + args["schema_type"], + args["schema"], + args.get("privacy_policy", ""), + args.get("custom_disclaimer", ""), + args.get("labels", []), ) + class ToolApiProviderGetRemoteSchemaApi(Resource): @setup_required @login_required @@ -151,16 +167,17 @@ class ToolApiProviderGetRemoteSchemaApi(Resource): def get(self): parser = reqparse.RequestParser() - parser.add_argument('url', type=str, required=True, nullable=False, location='args') + parser.add_argument("url", type=str, required=True, nullable=False, location="args") args = parser.parse_args() return ApiToolManageService.get_api_tool_provider_remote_schema( current_user.id, current_user.current_tenant_id, - args['url'], + args["url"], ) - + + class ToolApiProviderListToolsApi(Resource): @setup_required @login_required @@ -171,15 +188,18 @@ def get(self): parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, required=True, nullable=False, location='args') + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") args = parser.parse_args() - return jsonable_encoder(ApiToolManageService.list_api_tool_provider_tools( - user_id, - tenant_id, - args['provider'], - )) + return jsonable_encoder( + ApiToolManageService.list_api_tool_provider_tools( + user_id, + tenant_id, + args["provider"], + ) + ) + class ToolApiProviderUpdateApi(Resource): @setup_required @@ -188,37 +208,38 @@ class ToolApiProviderUpdateApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') - parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json') - parser.add_argument('schema', type=str, required=True, nullable=False, location='json') - parser.add_argument('provider', type=str, required=True, nullable=False, location='json') - parser.add_argument('original_provider', type=str, required=True, nullable=False, location='json') - parser.add_argument('icon', type=dict, required=True, nullable=False, location='json') - parser.add_argument('privacy_policy', type=str, required=True, nullable=True, location='json') - parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json') - parser.add_argument('custom_disclaimer', type=str, required=True, nullable=True, location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") + parser.add_argument("schema", type=str, required=True, nullable=False, location="json") + parser.add_argument("provider", type=str, required=True, nullable=False, location="json") + parser.add_argument("original_provider", type=str, required=True, nullable=False, location="json") + parser.add_argument("icon", type=dict, required=True, nullable=False, location="json") + parser.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json") + parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") + parser.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json") args = parser.parse_args() return ApiToolManageService.update_api_tool_provider( user_id, tenant_id, - args['provider'], - args['original_provider'], - args['icon'], - args['credentials'], - args['schema_type'], - args['schema'], - args['privacy_policy'], - args['custom_disclaimer'], - args.get('labels', []), + args["provider"], + args["original_provider"], + args["icon"], + args["credentials"], + args["schema_type"], + args["schema"], + args["privacy_policy"], + args["custom_disclaimer"], + args.get("labels", []), ) + class ToolApiProviderDeleteApi(Resource): @setup_required @login_required @@ -226,22 +247,23 @@ class ToolApiProviderDeleteApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, required=True, nullable=False, location='json') + parser.add_argument("provider", type=str, required=True, nullable=False, location="json") args = parser.parse_args() return ApiToolManageService.delete_api_tool_provider( user_id, tenant_id, - args['provider'], + args["provider"], ) + class ToolApiProviderGetApi(Resource): @setup_required @login_required @@ -252,16 +274,17 @@ def get(self): parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, required=True, nullable=False, location='args') + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") args = parser.parse_args() return ApiToolManageService.get_api_tool_provider( user_id, tenant_id, - args['provider'], + args["provider"], ) + class ToolBuiltinProviderCredentialsSchemaApi(Resource): @setup_required @login_required @@ -269,6 +292,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): def get(self, provider): return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider) + class ToolApiProviderSchemaApi(Resource): @setup_required @login_required @@ -276,14 +300,15 @@ class ToolApiProviderSchemaApi(Resource): def post(self): parser = reqparse.RequestParser() - parser.add_argument('schema', type=str, required=True, nullable=False, location='json') + parser.add_argument("schema", type=str, required=True, nullable=False, location="json") args = parser.parse_args() return ApiToolManageService.parser_api_schema( - schema=args['schema'], + schema=args["schema"], ) + class ToolApiProviderPreviousTestApi(Resource): @setup_required @login_required @@ -291,25 +316,26 @@ class ToolApiProviderPreviousTestApi(Resource): def post(self): parser = reqparse.RequestParser() - parser.add_argument('tool_name', type=str, required=True, nullable=False, location='json') - parser.add_argument('provider_name', type=str, required=False, nullable=False, location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') - parser.add_argument('parameters', type=dict, required=True, nullable=False, location='json') - parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json') - parser.add_argument('schema', type=str, required=True, nullable=False, location='json') + parser.add_argument("tool_name", type=str, required=True, nullable=False, location="json") + parser.add_argument("provider_name", type=str, required=False, nullable=False, location="json") + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("parameters", type=dict, required=True, nullable=False, location="json") + parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") + parser.add_argument("schema", type=str, required=True, nullable=False, location="json") args = parser.parse_args() return ApiToolManageService.test_api_tool_preview( current_user.current_tenant_id, - args['provider_name'] if args['provider_name'] else '', - args['tool_name'], - args['credentials'], - args['parameters'], - args['schema_type'], - args['schema'], + args["provider_name"] if args["provider_name"] else "", + args["tool_name"], + args["credentials"], + args["parameters"], + args["schema_type"], + args["schema"], ) + class ToolWorkflowProviderCreateApi(Resource): @setup_required @login_required @@ -317,35 +343,36 @@ class ToolWorkflowProviderCreateApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id reqparser = reqparse.RequestParser() - reqparser.add_argument('workflow_app_id', type=uuid_value, required=True, nullable=False, location='json') - reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json') - reqparser.add_argument('label', type=str, required=True, nullable=False, location='json') - reqparser.add_argument('description', type=str, required=True, nullable=False, location='json') - reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json') - reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json') - reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='') - reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json') + reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json") + reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") + reqparser.add_argument("label", type=str, required=True, nullable=False, location="json") + reqparser.add_argument("description", type=str, required=True, nullable=False, location="json") + reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json") + reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") + reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") + reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") args = reqparser.parse_args() return WorkflowToolManageService.create_workflow_tool( user_id, tenant_id, - args['workflow_app_id'], - args['name'], - args['label'], - args['icon'], - args['description'], - args['parameters'], - args['privacy_policy'], - args.get('labels', []), + args["workflow_app_id"], + args["name"], + args["label"], + args["icon"], + args["description"], + args["parameters"], + args["privacy_policy"], + args.get("labels", []), ) + class ToolWorkflowProviderUpdateApi(Resource): @setup_required @login_required @@ -353,38 +380,39 @@ class ToolWorkflowProviderUpdateApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id reqparser = reqparse.RequestParser() - reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json') - reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json') - reqparser.add_argument('label', type=str, required=True, nullable=False, location='json') - reqparser.add_argument('description', type=str, required=True, nullable=False, location='json') - reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json') - reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json') - reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='') - reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json') - + reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") + reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") + reqparser.add_argument("label", type=str, required=True, nullable=False, location="json") + reqparser.add_argument("description", type=str, required=True, nullable=False, location="json") + reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json") + reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") + reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") + reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") + args = reqparser.parse_args() - if not args['workflow_tool_id']: - raise ValueError('incorrect workflow_tool_id') - + if not args["workflow_tool_id"]: + raise ValueError("incorrect workflow_tool_id") + return WorkflowToolManageService.update_workflow_tool( user_id, tenant_id, - args['workflow_tool_id'], - args['name'], - args['label'], - args['icon'], - args['description'], - args['parameters'], - args['privacy_policy'], - args.get('labels', []), + args["workflow_tool_id"], + args["name"], + args["label"], + args["icon"], + args["description"], + args["parameters"], + args["privacy_policy"], + args.get("labels", []), ) + class ToolWorkflowProviderDeleteApi(Resource): @setup_required @login_required @@ -392,21 +420,22 @@ class ToolWorkflowProviderDeleteApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id reqparser = reqparse.RequestParser() - reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json') + reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") args = reqparser.parse_args() return WorkflowToolManageService.delete_workflow_tool( user_id, tenant_id, - args['workflow_tool_id'], + args["workflow_tool_id"], ) - + + class ToolWorkflowProviderGetApi(Resource): @setup_required @login_required @@ -416,28 +445,29 @@ def get(self): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('workflow_tool_id', type=uuid_value, required=False, nullable=True, location='args') - parser.add_argument('workflow_app_id', type=uuid_value, required=False, nullable=True, location='args') + parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args") + parser.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args") args = parser.parse_args() - if args.get('workflow_tool_id'): + if args.get("workflow_tool_id"): tool = WorkflowToolManageService.get_workflow_tool_by_tool_id( user_id, tenant_id, - args['workflow_tool_id'], + args["workflow_tool_id"], ) - elif args.get('workflow_app_id'): + elif args.get("workflow_app_id"): tool = WorkflowToolManageService.get_workflow_tool_by_app_id( user_id, tenant_id, - args['workflow_app_id'], + args["workflow_app_id"], ) else: - raise ValueError('incorrect workflow_tool_id or workflow_app_id') + raise ValueError("incorrect workflow_tool_id or workflow_app_id") return jsonable_encoder(tool) - + + class ToolWorkflowProviderListToolApi(Resource): @setup_required @login_required @@ -447,15 +477,18 @@ def get(self): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='args') + parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args") args = parser.parse_args() - return jsonable_encoder(WorkflowToolManageService.list_single_workflow_tools( - user_id, - tenant_id, - args['workflow_tool_id'], - )) + return jsonable_encoder( + WorkflowToolManageService.list_single_workflow_tools( + user_id, + tenant_id, + args["workflow_tool_id"], + ) + ) + class ToolBuiltinListApi(Resource): @setup_required @@ -465,11 +498,17 @@ def get(self): user_id = current_user.id tenant_id = current_user.current_tenant_id - return jsonable_encoder([provider.to_dict() for provider in BuiltinToolManageService.list_builtin_tools( - user_id, - tenant_id, - )]) - + return jsonable_encoder( + [ + provider.to_dict() + for provider in BuiltinToolManageService.list_builtin_tools( + user_id, + tenant_id, + ) + ] + ) + + class ToolApiListApi(Resource): @setup_required @login_required @@ -478,11 +517,17 @@ def get(self): user_id = current_user.id tenant_id = current_user.current_tenant_id - return jsonable_encoder([provider.to_dict() for provider in ApiToolManageService.list_api_tools( - user_id, - tenant_id, - )]) - + return jsonable_encoder( + [ + provider.to_dict() + for provider in ApiToolManageService.list_api_tools( + user_id, + tenant_id, + ) + ] + ) + + class ToolWorkflowListApi(Resource): @setup_required @login_required @@ -491,11 +536,17 @@ def get(self): user_id = current_user.id tenant_id = current_user.current_tenant_id - return jsonable_encoder([provider.to_dict() for provider in WorkflowToolManageService.list_tenant_workflow_tools( - user_id, - tenant_id, - )]) - + return jsonable_encoder( + [ + provider.to_dict() + for provider in WorkflowToolManageService.list_tenant_workflow_tools( + user_id, + tenant_id, + ) + ] + ) + + class ToolLabelsApi(Resource): @setup_required @login_required @@ -503,36 +554,41 @@ class ToolLabelsApi(Resource): def get(self): return jsonable_encoder(ToolLabelsService.list_tool_labels()) + # tool provider -api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers') +api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") # builtin tool provider -api.add_resource(ToolBuiltinProviderListToolsApi, '/workspaces/current/tool-provider/builtin//tools') -api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provider/builtin//delete') -api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin//update') -api.add_resource(ToolBuiltinProviderGetCredentialsApi, '/workspaces/current/tool-provider/builtin//credentials') -api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin//credentials_schema') -api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin//icon') +api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin//tools") +api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin//delete") +api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin//update") +api.add_resource( + ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin//credentials" +) +api.add_resource( + ToolBuiltinProviderCredentialsSchemaApi, "/workspaces/current/tool-provider/builtin//credentials_schema" +) +api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin//icon") # api tool provider -api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add') -api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote') -api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools') -api.add_resource(ToolApiProviderUpdateApi, '/workspaces/current/tool-provider/api/update') -api.add_resource(ToolApiProviderDeleteApi, '/workspaces/current/tool-provider/api/delete') -api.add_resource(ToolApiProviderGetApi, '/workspaces/current/tool-provider/api/get') -api.add_resource(ToolApiProviderSchemaApi, '/workspaces/current/tool-provider/api/schema') -api.add_resource(ToolApiProviderPreviousTestApi, '/workspaces/current/tool-provider/api/test/pre') +api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add") +api.add_resource(ToolApiProviderGetRemoteSchemaApi, "/workspaces/current/tool-provider/api/remote") +api.add_resource(ToolApiProviderListToolsApi, "/workspaces/current/tool-provider/api/tools") +api.add_resource(ToolApiProviderUpdateApi, "/workspaces/current/tool-provider/api/update") +api.add_resource(ToolApiProviderDeleteApi, "/workspaces/current/tool-provider/api/delete") +api.add_resource(ToolApiProviderGetApi, "/workspaces/current/tool-provider/api/get") +api.add_resource(ToolApiProviderSchemaApi, "/workspaces/current/tool-provider/api/schema") +api.add_resource(ToolApiProviderPreviousTestApi, "/workspaces/current/tool-provider/api/test/pre") # workflow tool provider -api.add_resource(ToolWorkflowProviderCreateApi, '/workspaces/current/tool-provider/workflow/create') -api.add_resource(ToolWorkflowProviderUpdateApi, '/workspaces/current/tool-provider/workflow/update') -api.add_resource(ToolWorkflowProviderDeleteApi, '/workspaces/current/tool-provider/workflow/delete') -api.add_resource(ToolWorkflowProviderGetApi, '/workspaces/current/tool-provider/workflow/get') -api.add_resource(ToolWorkflowProviderListToolApi, '/workspaces/current/tool-provider/workflow/tools') +api.add_resource(ToolWorkflowProviderCreateApi, "/workspaces/current/tool-provider/workflow/create") +api.add_resource(ToolWorkflowProviderUpdateApi, "/workspaces/current/tool-provider/workflow/update") +api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provider/workflow/delete") +api.add_resource(ToolWorkflowProviderGetApi, "/workspaces/current/tool-provider/workflow/get") +api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools") -api.add_resource(ToolBuiltinListApi, '/workspaces/current/tools/builtin') -api.add_resource(ToolApiListApi, '/workspaces/current/tools/api') -api.add_resource(ToolWorkflowListApi, '/workspaces/current/tools/workflow') +api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin") +api.add_resource(ToolApiListApi, "/workspaces/current/tools/api") +api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow") -api.add_resource(ToolLabelsApi, '/workspaces/current/tool-labels') \ No newline at end of file +api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels") diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 7a11a45ae81900..623f0b8b74dfdb 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -26,39 +26,34 @@ from services.workspace_service import WorkspaceService provider_fields = { - 'provider_name': fields.String, - 'provider_type': fields.String, - 'is_valid': fields.Boolean, - 'token_is_set': fields.Boolean, + "provider_name": fields.String, + "provider_type": fields.String, + "is_valid": fields.Boolean, + "token_is_set": fields.Boolean, } tenant_fields = { - 'id': fields.String, - 'name': fields.String, - 'plan': fields.String, - 'status': fields.String, - 'created_at': TimestampField, - 'role': fields.String, - 'in_trial': fields.Boolean, - 'trial_end_reason': fields.String, - 'custom_config': fields.Raw(attribute='custom_config'), + "id": fields.String, + "name": fields.String, + "plan": fields.String, + "status": fields.String, + "created_at": TimestampField, + "role": fields.String, + "in_trial": fields.Boolean, + "trial_end_reason": fields.String, + "custom_config": fields.Raw(attribute="custom_config"), } tenants_fields = { - 'id': fields.String, - 'name': fields.String, - 'plan': fields.String, - 'status': fields.String, - 'created_at': TimestampField, - 'current': fields.Boolean + "id": fields.String, + "name": fields.String, + "plan": fields.String, + "status": fields.String, + "created_at": TimestampField, + "current": fields.Boolean, } -workspace_fields = { - 'id': fields.String, - 'name': fields.String, - 'status': fields.String, - 'created_at': TimestampField -} +workspace_fields = {"id": fields.String, "name": fields.String, "status": fields.String, "created_at": TimestampField} class TenantListApi(Resource): @@ -71,7 +66,7 @@ def get(self): for tenant in tenants: if tenant.id == current_user.current_tenant_id: tenant.current = True # Set current=True for current tenant - return {'workspaces': marshal(tenants, tenants_fields)}, 200 + return {"workspaces": marshal(tenants, tenants_fields)}, 200 class WorkspaceListApi(Resource): @@ -79,31 +74,37 @@ class WorkspaceListApi(Resource): @admin_required def get(self): parser = reqparse.RequestParser() - parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args') - parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") + parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc())\ - .paginate(page=args['page'], per_page=args['limit']) + tenants = ( + db.session.query(Tenant) + .order_by(Tenant.created_at.desc()) + .paginate(page=args["page"], per_page=args["limit"]) + ) has_more = False - if len(tenants.items) == args['limit']: + if len(tenants.items) == args["limit"]: current_page_first_tenant = tenants[-1] - rest_count = db.session.query(Tenant).filter( - Tenant.created_at < current_page_first_tenant.created_at, - Tenant.id != current_page_first_tenant.id - ).count() + rest_count = ( + db.session.query(Tenant) + .filter( + Tenant.created_at < current_page_first_tenant.created_at, Tenant.id != current_page_first_tenant.id + ) + .count() + ) if rest_count > 0: has_more = True total = db.session.query(Tenant).count() return { - 'data': marshal(tenants.items, workspace_fields), - 'has_more': has_more, - 'limit': args['limit'], - 'page': args['page'], - 'total': total - }, 200 + "data": marshal(tenants.items, workspace_fields), + "has_more": has_more, + "limit": args["limit"], + "page": args["page"], + "total": total, + }, 200 class TenantApi(Resource): @@ -112,8 +113,8 @@ class TenantApi(Resource): @account_initialization_required @marshal_with(tenant_fields) def get(self): - if request.path == '/info': - logging.warning('Deprecated URL /info was used.') + if request.path == "/info": + logging.warning("Deprecated URL /info was used.") tenant = current_user.current_tenant @@ -125,7 +126,7 @@ def get(self): tenant = tenants[0] # else, raise Unauthorized else: - raise Unauthorized('workspace is archived') + raise Unauthorized("workspace is archived") return WorkspaceService.get_tenant_info(tenant), 200 @@ -136,62 +137,64 @@ class SwitchWorkspaceApi(Resource): @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('tenant_id', type=str, required=True, location='json') + parser.add_argument("tenant_id", type=str, required=True, location="json") args = parser.parse_args() # check if tenant_id is valid, 403 if not try: - TenantService.switch_tenant(current_user, args['tenant_id']) + TenantService.switch_tenant(current_user, args["tenant_id"]) except Exception: raise AccountNotLinkTenantError("Account not link tenant") - new_tenant = db.session.query(Tenant).get(args['tenant_id']) # Get new tenant + new_tenant = db.session.query(Tenant).get(args["tenant_id"]) # Get new tenant + + return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)} - return {'result': 'success', 'new_tenant': marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)} - class CustomConfigWorkspaceApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('workspace_custom') + @cloud_edition_billing_resource_check("workspace_custom") def post(self): parser = reqparse.RequestParser() - parser.add_argument('remove_webapp_brand', type=bool, location='json') - parser.add_argument('replace_webapp_logo', type=str, location='json') + parser.add_argument("remove_webapp_brand", type=bool, location="json") + parser.add_argument("replace_webapp_logo", type=str, location="json") args = parser.parse_args() tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404() custom_config_dict = { - 'remove_webapp_brand': args['remove_webapp_brand'], - 'replace_webapp_logo': args['replace_webapp_logo'] if args['replace_webapp_logo'] is not None else tenant.custom_config_dict.get('replace_webapp_logo') , + "remove_webapp_brand": args["remove_webapp_brand"], + "replace_webapp_logo": args["replace_webapp_logo"] + if args["replace_webapp_logo"] is not None + else tenant.custom_config_dict.get("replace_webapp_logo"), } tenant.custom_config_dict = custom_config_dict db.session.commit() - return {'result': 'success', 'tenant': marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)} - + return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)} + class WebappLogoWorkspaceApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('workspace_custom') + @cloud_edition_billing_resource_check("workspace_custom") def post(self): # get file from request - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() - extension = file.filename.split('.')[-1] - if extension.lower() not in ['svg', 'png']: + extension = file.filename.split(".")[-1] + if extension.lower() not in ["svg", "png"]: raise UnsupportedFileTypeError() try: @@ -201,14 +204,14 @@ def post(self): raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - - return { 'id': upload_file.id }, 201 + + return {"id": upload_file.id}, 201 -api.add_resource(TenantListApi, '/workspaces') # GET for getting all tenants -api.add_resource(WorkspaceListApi, '/all-workspaces') # GET for getting all tenants -api.add_resource(TenantApi, '/workspaces/current', endpoint='workspaces_current') # GET for getting current tenant info -api.add_resource(TenantApi, '/info', endpoint='info') # Deprecated -api.add_resource(SwitchWorkspaceApi, '/workspaces/switch') # POST for switching tenant -api.add_resource(CustomConfigWorkspaceApi, '/workspaces/custom-config') -api.add_resource(WebappLogoWorkspaceApi, '/workspaces/custom-config/webapp-logo/upload') +api.add_resource(TenantListApi, "/workspaces") # GET for getting all tenants +api.add_resource(WorkspaceListApi, "/all-workspaces") # GET for getting all tenants +api.add_resource(TenantApi, "/workspaces/current", endpoint="workspaces_current") # GET for getting current tenant info +api.add_resource(TenantApi, "/info", endpoint="info") # Deprecated +api.add_resource(SwitchWorkspaceApi, "/workspaces/switch") # POST for switching tenant +api.add_resource(CustomConfigWorkspaceApi, "/workspaces/custom-config") +api.add_resource(WebappLogoWorkspaceApi, "/workspaces/custom-config/webapp-logo/upload") diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 3baf69acfd2d5f..7667b30e348a98 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -16,7 +16,7 @@ def decorated(*args, **kwargs): # check account initialization account = current_user - if account.status == 'uninitialized': + if account.status == "uninitialized": raise AccountNotInitializedError() return view(*args, **kwargs) @@ -27,7 +27,7 @@ def decorated(*args, **kwargs): def only_edition_cloud(view): @wraps(view) def decorated(*args, **kwargs): - if dify_config.EDITION != 'CLOUD': + if dify_config.EDITION != "CLOUD": abort(404) return view(*args, **kwargs) @@ -38,7 +38,7 @@ def decorated(*args, **kwargs): def only_edition_self_hosted(view): @wraps(view) def decorated(*args, **kwargs): - if dify_config.EDITION != 'SELF_HOSTED': + if dify_config.EDITION != "SELF_HOSTED": abort(404) return view(*args, **kwargs) @@ -46,8 +46,7 @@ def decorated(*args, **kwargs): return decorated -def cloud_edition_billing_resource_check(resource: str, - error_msg: str = "You have reached the limit of your subscription."): +def cloud_edition_billing_resource_check(resource: str): def interceptor(view): @wraps(view) def decorated(*args, **kwargs): @@ -58,23 +57,23 @@ def decorated(*args, **kwargs): vector_space = features.vector_space documents_upload_quota = features.documents_upload_quota annotation_quota_limit = features.annotation_quota_limit - if resource == 'members' and 0 < members.limit <= members.size: - abort(403, error_msg) - elif resource == 'apps' and 0 < apps.limit <= apps.size: - abort(403, error_msg) - elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size: - abort(403, error_msg) - elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size: + if resource == "members" and 0 < members.limit <= members.size: + abort(403, "The number of members has reached the limit of your subscription.") + elif resource == "apps" and 0 < apps.limit <= apps.size: + abort(403, "The number of apps has reached the limit of your subscription.") + elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size: + abort(403, "The capacity of the vector space has reached the limit of your subscription.") + elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size: # The api of file upload is used in the multiple places, so we need to check the source of the request from datasets - source = request.args.get('source') - if source == 'datasets': - abort(403, error_msg) + source = request.args.get("source") + if source == "datasets": + abort(403, "The number of documents has reached the limit of your subscription.") else: return view(*args, **kwargs) - elif resource == 'workspace_custom' and not features.can_replace_logo: - abort(403, error_msg) - elif resource == 'annotation' and 0 < annotation_quota_limit.limit < annotation_quota_limit.size: - abort(403, error_msg) + elif resource == "workspace_custom" and not features.can_replace_logo: + abort(403, "The workspace custom feature has reached the limit of your subscription.") + elif resource == "annotation" and 0 < annotation_quota_limit.limit < annotation_quota_limit.size: + abort(403, "The annotation quota has reached the limit of your subscription.") else: return view(*args, **kwargs) @@ -85,16 +84,18 @@ def decorated(*args, **kwargs): return interceptor -def cloud_edition_billing_knowledge_limit_check(resource: str, - error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."): +def cloud_edition_billing_knowledge_limit_check(resource: str): def interceptor(view): @wraps(view) def decorated(*args, **kwargs): features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: - if resource == 'add_segment': - if features.billing.subscription.plan == 'sandbox': - abort(403, error_msg) + if resource == "add_segment": + if features.billing.subscription.plan == "sandbox": + abort( + 403, + "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.", + ) else: return view(*args, **kwargs) @@ -112,7 +113,7 @@ def decorated(*args, **kwargs): features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: - utm_info = request.cookies.get('utm_info') + utm_info = request.cookies.get("utm_info") if utm_info: utm_info = json.loads(utm_info) diff --git a/api/controllers/files/__init__.py b/api/controllers/files/__init__.py index 8d38ab9866a023..97d5c3f88fb522 100644 --- a/api/controllers/files/__init__.py +++ b/api/controllers/files/__init__.py @@ -2,7 +2,7 @@ from libs.external_api import ExternalApi -bp = Blueprint('files', __name__) +bp = Blueprint("files", __name__) api = ExternalApi(bp) diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py index 247b5d45e190e5..2432285d9365a3 100644 --- a/api/controllers/files/image_preview.py +++ b/api/controllers/files/image_preview.py @@ -13,35 +13,30 @@ class ImagePreviewApi(Resource): def get(self, file_id): file_id = str(file_id) - timestamp = request.args.get('timestamp') - nonce = request.args.get('nonce') - sign = request.args.get('sign') + timestamp = request.args.get("timestamp") + nonce = request.args.get("nonce") + sign = request.args.get("sign") if not timestamp or not nonce or not sign: - return {'content': 'Invalid request.'}, 400 + return {"content": "Invalid request."}, 400 try: - generator, mimetype = FileService.get_image_preview( - file_id, - timestamp, - nonce, - sign - ) + generator, mimetype = FileService.get_image_preview(file_id, timestamp, nonce, sign) except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() return Response(generator, mimetype=mimetype) - + class WorkspaceWebappLogoApi(Resource): def get(self, workspace_id): workspace_id = str(workspace_id) custom_config = TenantService.get_custom_config(workspace_id) - webapp_logo_file_id = custom_config.get('replace_webapp_logo') if custom_config is not None else None + webapp_logo_file_id = custom_config.get("replace_webapp_logo") if custom_config is not None else None if not webapp_logo_file_id: - raise NotFound('webapp logo is not found') + raise NotFound("webapp logo is not found") try: generator, mimetype = FileService.get_public_image_preview( @@ -53,11 +48,11 @@ def get(self, workspace_id): return Response(generator, mimetype=mimetype) -api.add_resource(ImagePreviewApi, '/files//image-preview') -api.add_resource(WorkspaceWebappLogoApi, '/files/workspaces//webapp-logo') +api.add_resource(ImagePreviewApi, "/files//image-preview") +api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces//webapp-logo") class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index 5a07ad2ea51800..38ac0815daa100 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -13,36 +13,39 @@ def get(self, file_id, extension): parser = reqparse.RequestParser() - parser.add_argument('timestamp', type=str, required=True, location='args') - parser.add_argument('nonce', type=str, required=True, location='args') - parser.add_argument('sign', type=str, required=True, location='args') + parser.add_argument("timestamp", type=str, required=True, location="args") + parser.add_argument("nonce", type=str, required=True, location="args") + parser.add_argument("sign", type=str, required=True, location="args") args = parser.parse_args() - if not ToolFileManager.verify_file(file_id=file_id, - timestamp=args['timestamp'], - nonce=args['nonce'], - sign=args['sign'], + if not ToolFileManager.verify_file( + file_id=file_id, + timestamp=args["timestamp"], + nonce=args["nonce"], + sign=args["sign"], ): - raise Forbidden('Invalid request.') - + raise Forbidden("Invalid request.") + try: result = ToolFileManager.get_file_generator_by_tool_file_id( file_id, ) if not result: - raise NotFound('file is not found') - + raise NotFound("file is not found") + generator, mimetype = result except Exception: raise UnsupportedFileTypeError() return Response(generator, mimetype=mimetype) -api.add_resource(ToolFilePreviewApi, '/files/tools/.') + +api.add_resource(ToolFilePreviewApi, "/files/tools/.") + class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py index ad49a649caab66..9f124736a966ea 100644 --- a/api/controllers/inner_api/__init__.py +++ b/api/controllers/inner_api/__init__.py @@ -2,8 +2,7 @@ from libs.external_api import ExternalApi -bp = Blueprint('inner_api', __name__, url_prefix='/inner/api') +bp = Blueprint("inner_api", __name__, url_prefix="/inner/api") api = ExternalApi(bp) from .workspace import workspace - diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index 06610d89330837..914b60f26388a0 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -9,29 +9,24 @@ class EnterpriseWorkspace(Resource): - @setup_required @inner_api_only def post(self): parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('owner_email', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") + parser.add_argument("owner_email", type=str, required=True, location="json") args = parser.parse_args() - account = Account.query.filter_by(email=args['owner_email']).first() + account = Account.query.filter_by(email=args["owner_email"]).first() if account is None: - return { - 'message': 'owner account not found.' - }, 404 + return {"message": "owner account not found."}, 404 - tenant = TenantService.create_tenant(args['name']) - TenantService.create_tenant_member(tenant, account, role='owner') + tenant = TenantService.create_tenant(args["name"]) + TenantService.create_tenant_member(tenant, account, role="owner") tenant_was_created.send(tenant) - return { - 'message': 'enterprise workspace created.' - } + return {"message": "enterprise workspace created."} -api.add_resource(EnterpriseWorkspace, '/enterprise/workspace') +api.add_resource(EnterpriseWorkspace, "/enterprise/workspace") diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index 5c37f5276f5562..51ffe683ff40ad 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -17,7 +17,7 @@ def decorated(*args, **kwargs): abort(404) # get header 'X-Inner-Api-Key' - inner_api_key = request.headers.get('X-Inner-Api-Key') + inner_api_key = request.headers.get("X-Inner-Api-Key") if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY: abort(401) @@ -33,29 +33,29 @@ def decorated(*args, **kwargs): return view(*args, **kwargs) # get header 'X-Inner-Api-Key' - authorization = request.headers.get('Authorization') + authorization = request.headers.get("Authorization") if not authorization: return view(*args, **kwargs) - parts = authorization.split(':') + parts = authorization.split(":") if len(parts) != 2: return view(*args, **kwargs) user_id, token = parts - if ' ' in user_id: - user_id = user_id.split(' ')[1] + if " " in user_id: + user_id = user_id.split(" ")[1] - inner_api_key = request.headers.get('X-Inner-Api-Key') + inner_api_key = request.headers.get("X-Inner-Api-Key") - data_to_sign = f'DIFY {user_id}' + data_to_sign = f"DIFY {user_id}" - signature = hmac_new(inner_api_key.encode('utf-8'), data_to_sign.encode('utf-8'), sha1) - signature = b64encode(signature.digest()).decode('utf-8') + signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1) + signature = b64encode(signature.digest()).decode("utf-8") if signature != token: return view(*args, **kwargs) - kwargs['user'] = db.session.query(EndUser).filter(EndUser.id == user_id).first() + kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first() return view(*args, **kwargs) diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index 082660a8915aa0..ad39c160ac2669 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -2,7 +2,7 @@ from libs.external_api import ExternalApi -bp = Blueprint('service_api', __name__, url_prefix='/v1') +bp = Blueprint("service_api", __name__, url_prefix="/v1") api = ExternalApi(bp) diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 3b3cf1b026d568..ecc2d73deb49fb 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,4 +1,3 @@ - from flask_restful import Resource, fields, marshal_with from configs import dify_config @@ -13,32 +12,30 @@ class AppParameterApi(Resource): """Resource for app variables.""" variable_fields = { - 'key': fields.String, - 'name': fields.String, - 'description': fields.String, - 'type': fields.String, - 'default': fields.String, - 'max_length': fields.Integer, - 'options': fields.List(fields.String) + "key": fields.String, + "name": fields.String, + "description": fields.String, + "type": fields.String, + "default": fields.String, + "max_length": fields.Integer, + "options": fields.List(fields.String), } - system_parameters_fields = { - 'image_file_size_limit': fields.String - } + system_parameters_fields = {"image_file_size_limit": fields.String} parameters_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw, - 'suggested_questions_after_answer': fields.Raw, - 'speech_to_text': fields.Raw, - 'text_to_speech': fields.Raw, - 'retriever_resource': fields.Raw, - 'annotation_reply': fields.Raw, - 'more_like_this': fields.Raw, - 'user_input_form': fields.Raw, - 'sensitive_word_avoidance': fields.Raw, - 'file_upload': fields.Raw, - 'system_parameters': fields.Nested(system_parameters_fields) + "opening_statement": fields.String, + "suggested_questions": fields.Raw, + "suggested_questions_after_answer": fields.Raw, + "speech_to_text": fields.Raw, + "text_to_speech": fields.Raw, + "retriever_resource": fields.Raw, + "annotation_reply": fields.Raw, + "more_like_this": fields.Raw, + "user_input_form": fields.Raw, + "sensitive_word_avoidance": fields.Raw, + "file_upload": fields.Raw, + "system_parameters": fields.Nested(system_parameters_fields), } @validate_app_token @@ -56,30 +53,35 @@ def get(self, app_model: App): app_model_config = app_model.app_model_config features_dict = app_model_config.to_dict() - user_input_form = features_dict.get('user_input_form', []) + user_input_form = features_dict.get("user_input_form", []) return { - 'opening_statement': features_dict.get('opening_statement'), - 'suggested_questions': features_dict.get('suggested_questions', []), - 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', - {"enabled": False}), - 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), - 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), - 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), - 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), - 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), - 'user_input_form': user_input_form, - 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', - {"enabled": False, "type": "", "configs": []}), - 'file_upload': features_dict.get('file_upload', {"image": { - "enabled": False, - "number_limits": 3, - "detail": "high", - "transfer_methods": ["remote_url", "local_file"] - }}), - 'system_parameters': { - 'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT - } + "opening_statement": features_dict.get("opening_statement"), + "suggested_questions": features_dict.get("suggested_questions", []), + "suggested_questions_after_answer": features_dict.get( + "suggested_questions_after_answer", {"enabled": False} + ), + "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), + "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), + "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), + "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}), + "more_like_this": features_dict.get("more_like_this", {"enabled": False}), + "user_input_form": user_input_form, + "sensitive_word_avoidance": features_dict.get( + "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} + ), + "file_upload": features_dict.get( + "file_upload", + { + "image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"], + } + }, + ), + "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT}, } @@ -89,16 +91,14 @@ def get(self, app_model: App): """Get app meta""" return AppService().get_app_meta(app_model) + class AppInfoApi(Resource): @validate_app_token def get(self, app_model: App): - """Get app infomation""" - return { - 'name':app_model.name, - 'description':app_model.description - } + """Get app information""" + return {"name": app_model.name, "description": app_model.description} -api.add_resource(AppParameterApi, '/parameters') -api.add_resource(AppMetaApi, '/meta') -api.add_resource(AppInfoApi, '/info') +api.add_resource(AppParameterApi, "/parameters") +api.add_resource(AppMetaApi, "/meta") +api.add_resource(AppInfoApi, "/info") diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 3c009af343582d..85aab047a78239 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -33,14 +33,10 @@ class AudioApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) def post(self, app_model: App, end_user: EndUser): - file = request.files['file'] + file = request.files["file"] try: - response = AudioService.transcript_asr( - app_model=app_model, - file=file, - end_user=end_user - ) + response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user) return response except services.errors.app_model_config.AppModelConfigBrokenError: @@ -74,30 +70,32 @@ class TextApi(Resource): def post(self, app_model: App, end_user: EndUser): try: parser = reqparse.RequestParser() - parser.add_argument('message_id', type=str, required=False, location='json') - parser.add_argument('voice', type=str, location='json') - parser.add_argument('text', type=str, location='json') - parser.add_argument('streaming', type=bool, location='json') + parser.add_argument("message_id", type=str, required=False, location="json") + parser.add_argument("voice", type=str, location="json") + parser.add_argument("text", type=str, location="json") + parser.add_argument("streaming", type=bool, location="json") args = parser.parse_args() - message_id = args.get('message_id', None) - text = args.get('text', None) - if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] - and app_model.workflow - and app_model.workflow.features_dict): - text_to_speech = app_model.workflow.features_dict.get('text_to_speech') - voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') + message_id = args.get("message_id", None) + text = args.get("text", None) + if ( + app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + and app_model.workflow + and app_model.workflow.features_dict + ): + text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") else: try: - voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice') + voice = ( + args.get("voice") + if args.get("voice") + else app_model.app_model_config.text_to_speech_dict.get("voice") + ) except Exception: voice = None response = AudioService.transcript_tts( - app_model=app_model, - message_id=message_id, - end_user=end_user.external_user_id, - voice=voice, - text=text + app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text ) return response @@ -127,5 +125,5 @@ def post(self, app_model: App, end_user: EndUser): raise InternalServerError() -api.add_resource(AudioApi, '/audio-to-text') -api.add_resource(TextApi, '/text-to-audio') +api.add_resource(AudioApi, "/audio-to-text") +api.add_resource(TextApi, "/text-to-audio") diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 2511f46bafacc6..f1771baf314b47 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -33,21 +33,21 @@ class CompletionApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise AppUnavailableError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, location='json', default='') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, location="json", default="") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' + streaming = args["response_mode"] == "streaming" - args['auto_generate_name'] = False + args["auto_generate_name"] = False try: response = AppGenerateService.generate( @@ -84,12 +84,12 @@ def post(self, app_model: App, end_user: EndUser): class CompletionStopApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser, task_id): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise AppUnavailableError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ChatApi(Resource): @@ -100,25 +100,21 @@ def post(self, app_model: App, end_user: EndUser): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, required=True, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') - parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') - parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, required=True, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") + parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' + streaming = args["response_mode"] == "streaming" try: response = AppGenerateService.generate( - app_model=app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.SERVICE_API, - streaming=streaming + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming ) return helper.compact_generate_response(response) @@ -153,10 +149,10 @@ def post(self, app_model: App, end_user: EndUser, task_id): AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(CompletionApi, '/completion-messages') -api.add_resource(CompletionStopApi, '/completion-messages//stop') -api.add_resource(ChatApi, '/chat-messages') -api.add_resource(ChatStopApi, '/chat-messages//stop') +api.add_resource(CompletionApi, "/completion-messages") +api.add_resource(CompletionStopApi, "/completion-messages//stop") +api.add_resource(ChatApi, "/chat-messages") +api.add_resource(ChatStopApi, "/chat-messages//stop") diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 44bda8e771c97c..734027a1c51e5c 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -14,7 +14,6 @@ class ConversationApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model: App, end_user: EndUser): @@ -23,17 +22,26 @@ def get(self, app_model: App, end_user: EndUser): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser.add_argument( + "sort_by", + type=str, + choices=["created_at", "-created_at", "updated_at", "-updated_at"], + required=False, + default="-updated_at", + location="args", + ) args = parser.parse_args() try: return ConversationService.pagination_by_last_id( app_model=app_model, user=end_user, - last_id=args['last_id'], - limit=args['limit'], - invoke_from=InvokeFrom.SERVICE_API + last_id=args["last_id"], + limit=args["limit"], + invoke_from=InvokeFrom.SERVICE_API, + sort_by=args["sort_by"], ) except services.errors.conversation.LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @@ -53,11 +61,10 @@ def delete(self, app_model: App, end_user: EndUser, c_id): ConversationService.delete(app_model, conversation_id, end_user) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ConversationRenameApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) @marshal_with(simple_conversation_fields) def post(self, app_model: App, end_user: EndUser, c_id): @@ -68,22 +75,16 @@ def post(self, app_model: App, end_user: EndUser, c_id): conversation_id = str(c_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, location='json') - parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') + parser.add_argument("name", type=str, required=False, location="json") + parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") args = parser.parse_args() try: - return ConversationService.rename( - app_model, - conversation_id, - end_user, - args['name'], - args['auto_generate'] - ) + return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") -api.add_resource(ConversationRenameApi, '/conversations//name', endpoint='conversation_name') -api.add_resource(ConversationApi, '/conversations') -api.add_resource(ConversationDetailApi, '/conversations/', endpoint='conversation_detail') +api.add_resource(ConversationRenameApi, "/conversations//name", endpoint="conversation_name") +api.add_resource(ConversationApi, "/conversations") +api.add_resource(ConversationDetailApi, "/conversations/", endpoint="conversation_detail") diff --git a/api/controllers/service_api/app/error.py b/api/controllers/service_api/app/error.py index ac9edb1b4f6cbe..ca91da80c19f8e 100644 --- a/api/controllers/service_api/app/error.py +++ b/api/controllers/service_api/app/error.py @@ -2,104 +2,108 @@ class AppUnavailableError(BaseHTTPException): - error_code = 'app_unavailable' + error_code = "app_unavailable" description = "App unavailable, please check your app configurations." code = 400 class NotCompletionAppError(BaseHTTPException): - error_code = 'not_completion_app' + error_code = "not_completion_app" description = "Please check if your Completion app mode matches the right API route." code = 400 class NotChatAppError(BaseHTTPException): - error_code = 'not_chat_app' + error_code = "not_chat_app" description = "Please check if your app mode matches the right API route." code = 400 class NotWorkflowAppError(BaseHTTPException): - error_code = 'not_workflow_app' + error_code = "not_workflow_app" description = "Please check if your app mode matches the right API route." code = 400 class ConversationCompletedError(BaseHTTPException): - error_code = 'conversation_completed' + error_code = "conversation_completed" description = "The conversation has ended. Please start a new conversation." code = 400 class ProviderNotInitializeError(BaseHTTPException): - error_code = 'provider_not_initialize' - description = "No valid model provider credentials found. " \ - "Please go to Settings -> Model Provider to complete your provider credentials." + error_code = "provider_not_initialize" + description = ( + "No valid model provider credentials found. " + "Please go to Settings -> Model Provider to complete your provider credentials." + ) code = 400 class ProviderQuotaExceededError(BaseHTTPException): - error_code = 'provider_quota_exceeded' - description = "Your quota for Dify Hosted OpenAI has been exhausted. " \ - "Please go to Settings -> Model Provider to complete your own provider credentials." + error_code = "provider_quota_exceeded" + description = ( + "Your quota for Dify Hosted OpenAI has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials." + ) code = 400 class ProviderModelCurrentlyNotSupportError(BaseHTTPException): - error_code = 'model_currently_not_support' + error_code = "model_currently_not_support" description = "Dify Hosted OpenAI trial currently not support the GPT-4 model." code = 400 class CompletionRequestError(BaseHTTPException): - error_code = 'completion_request_error' + error_code = "completion_request_error" description = "Completion request failed." code = 400 class NoAudioUploadedError(BaseHTTPException): - error_code = 'no_audio_uploaded' + error_code = "no_audio_uploaded" description = "Please upload your audio." code = 400 class AudioTooLargeError(BaseHTTPException): - error_code = 'audio_too_large' + error_code = "audio_too_large" description = "Audio size exceeded. {message}" code = 413 class UnsupportedAudioTypeError(BaseHTTPException): - error_code = 'unsupported_audio_type' + error_code = "unsupported_audio_type" description = "Audio type not allowed." code = 415 class ProviderNotSupportSpeechToTextError(BaseHTTPException): - error_code = 'provider_not_support_speech_to_text' + error_code = "provider_not_support_speech_to_text" description = "Provider not support speech to text." code = 400 class NoFileUploadedError(BaseHTTPException): - error_code = 'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." code = 400 class FileTooLargeError(BaseHTTPException): - error_code = 'file_too_large' + error_code = "file_too_large" description = "File size exceeded. {message}" code = 413 class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index 5dbc1b1d1bbe0d..e0a772eb31ddc1 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -16,15 +16,13 @@ class FileApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) @marshal_with(file_fields) def post(self, app_model: App, end_user: EndUser): - - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if not file.mimetype: @@ -43,4 +41,4 @@ def post(self, app_model: App, end_user: EndUser): return upload_file, 201 -api.add_resource(FileApi, '/files/upload') +api.add_resource(FileApi, "/files/upload") diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index c8b44cfa38114c..b39aaf7dd804d0 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -17,61 +17,59 @@ class MessageListApi(Resource): - feedback_fields = { - 'rating': fields.String - } + feedback_fields = {"rating": fields.String} retriever_resource_fields = { - 'id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'dataset_id': fields.String, - 'dataset_name': fields.String, - 'document_id': fields.String, - 'document_name': fields.String, - 'data_source_type': fields.String, - 'segment_id': fields.String, - 'score': fields.Float, - 'hit_count': fields.Integer, - 'word_count': fields.Integer, - 'segment_position': fields.Integer, - 'index_node_hash': fields.String, - 'content': fields.String, - 'created_at': TimestampField + "id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "dataset_id": fields.String, + "dataset_name": fields.String, + "document_id": fields.String, + "document_name": fields.String, + "data_source_type": fields.String, + "segment_id": fields.String, + "score": fields.Float, + "hit_count": fields.Integer, + "word_count": fields.Integer, + "segment_position": fields.Integer, + "index_node_hash": fields.String, + "content": fields.String, + "created_at": TimestampField, } agent_thought_fields = { - 'id': fields.String, - 'chain_id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'thought': fields.String, - 'tool': fields.String, - 'tool_labels': fields.Raw, - 'tool_input': fields.String, - 'created_at': TimestampField, - 'observation': fields.String, - 'message_files': fields.List(fields.String, attribute='files') + "id": fields.String, + "chain_id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "thought": fields.String, + "tool": fields.String, + "tool_labels": fields.Raw, + "tool_input": fields.String, + "created_at": TimestampField, + "observation": fields.String, + "message_files": fields.List(fields.String, attribute="files"), } message_fields = { - 'id': fields.String, - 'conversation_id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String(attribute='re_sign_file_url_answer'), - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), - 'created_at': TimestampField, - 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)), - 'status': fields.String, - 'error': fields.String, + "id": fields.String, + "conversation_id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "answer": fields.String(attribute="re_sign_file_url_answer"), + "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), + "created_at": TimestampField, + "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "status": fields.String, + "error": fields.String, } message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_fields)), } @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) @@ -82,14 +80,15 @@ def get(self, app_model: App, end_user: EndUser): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') - parser.add_argument('first_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") + parser.add_argument("first_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() try: - return MessageService.pagination_by_first_id(app_model, end_user, - args['conversation_id'], args['first_id'], args['limit']) + return MessageService.pagination_by_first_id( + app_model, end_user, args["conversation_id"], args["first_id"], args["limit"] + ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.message.FirstMessageNotExistsError: @@ -102,15 +101,15 @@ def post(self, app_model: App, end_user: EndUser, message_id): message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') + parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args['rating']) + MessageService.create_feedback(app_model, message_id, end_user, args["rating"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class MessageSuggestedApi(Resource): @@ -123,22 +122,19 @@ def get(self, app_model: App, end_user: EndUser, message_id): try: questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, - user=end_user, - message_id=message_id, - invoke_from=InvokeFrom.SERVICE_API + app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.SERVICE_API ) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") except SuggestedQuestionsAfterAnswerDisabledError: - raise BadRequest("Message Not Exists.") + raise BadRequest("Suggested Questions Is Disabled.") except Exception: logging.exception("internal server error.") raise InternalServerError() - return {'result': 'success', 'data': questions} + return {"result": "success", "data": questions} -api.add_resource(MessageListApi, '/messages') -api.add_resource(MessageFeedbackApi, '/messages//feedbacks') -api.add_resource(MessageSuggestedApi, '/messages//suggested') +api.add_resource(MessageListApi, "/messages") +api.add_resource(MessageFeedbackApi, "/messages//feedbacks") +api.add_resource(MessageSuggestedApi, "/messages//suggested") diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 9446f9d5886fa2..5822e0921bc70a 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -30,19 +30,20 @@ logger = logging.getLogger(__name__) workflow_run_fields = { - 'id': fields.String, - 'workflow_id': fields.String, - 'status': fields.String, - 'inputs': fields.Raw, - 'outputs': fields.Raw, - 'error': fields.String, - 'total_steps': fields.Integer, - 'total_tokens': fields.Integer, - 'created_at': fields.DateTime, - 'finished_at': fields.DateTime, - 'elapsed_time': fields.Float, + "id": fields.String, + "workflow_id": fields.String, + "status": fields.String, + "inputs": fields.Raw, + "outputs": fields.Raw, + "error": fields.String, + "total_steps": fields.Integer, + "total_tokens": fields.Integer, + "created_at": fields.DateTime, + "finished_at": fields.DateTime, + "elapsed_time": fields.Float, } + class WorkflowRunDetailApi(Resource): @validate_app_token @marshal_with(workflow_run_fields) @@ -56,6 +57,8 @@ def get(self, app_model: App, workflow_id: str): workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_id).first() return workflow_run + + class WorkflowRunApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): @@ -67,20 +70,16 @@ def post(self, app_model: App, end_user: EndUser): raise NotWorkflowAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") args = parser.parse_args() - streaming = args.get('response_mode') == 'streaming' + streaming = args.get("response_mode") == "streaming" try: response = AppGenerateService.generate( - app_model=app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.SERVICE_API, - streaming=streaming + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming ) return helper.compact_generate_response(response) @@ -111,11 +110,9 @@ def post(self, app_model: App, end_user: EndUser, task_id: str): AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) - return { - "result": "success" - } + return {"result": "success"} -api.add_resource(WorkflowRunApi, '/workflows/run') -api.add_resource(WorkflowRunDetailApi, '/workflows/run/') -api.add_resource(WorkflowTaskStopApi, '/workflows/tasks//stop') +api.add_resource(WorkflowRunApi, "/workflows/run") +api.add_resource(WorkflowRunDetailApi, "/workflows/run/") +api.add_resource(WorkflowTaskStopApi, "/workflows/tasks//stop") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 8dd16c0787cbca..c2c0672a0346dd 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -10,13 +10,13 @@ from core.provider_manager import ProviderManager from fields.dataset_fields import dataset_detail_fields from libs.login import current_user -from models.dataset import Dataset +from models.dataset import Dataset, DatasetPermissionEnum from services.dataset_service import DatasetService def _validate_name(name): if not name or len(name) < 1 or len(name) > 40: - raise ValueError('Name must be between 1 to 40 characters.') + raise ValueError("Name must be between 1 to 40 characters.") return name @@ -26,24 +26,18 @@ class DatasetListApi(DatasetApiResource): def get(self, tenant_id): """Resource for getting datasets.""" - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - provider = request.args.get('provider', default="vendor") - search = request.args.get('keyword', default=None, type=str) - tag_ids = request.args.getlist('tag_ids') + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + provider = request.args.get("provider", default="vendor") + search = request.args.get("keyword", default=None, type=str) + tag_ids = request.args.getlist("tag_ids") - datasets, total = DatasetService.get_datasets(page, limit, provider, - tenant_id, current_user, search, tag_ids) + datasets, total = DatasetService.get_datasets(page, limit, provider, tenant_id, current_user, search, tag_ids) # check embedding setting provider_manager = ProviderManager() - configurations = provider_manager.get_configurations( - tenant_id=current_user.current_tenant_id - ) + configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) - embedding_models = configurations.get_models( - model_type=ModelType.TEXT_EMBEDDING, - only_active=True - ) + embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) model_names = [] for embedding_model in embedding_models: @@ -51,47 +45,59 @@ def get(self, tenant_id): data = marshal(datasets, dataset_detail_fields) for item in data: - if item['indexing_technique'] == 'high_quality': + if item["indexing_technique"] == "high_quality": item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" if item_model in model_names: - item['embedding_available'] = True + item["embedding_available"] = True else: - item['embedding_available'] = False + item["embedding_available"] = False else: - item['embedding_available'] = True - response = { - 'data': data, - 'has_more': len(datasets) == limit, - 'limit': limit, - 'total': total, - 'page': page - } + item["embedding_available"] = True + response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} return response, 200 - def post(self, tenant_id): """Resource for creating datasets.""" parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, required=True, - help='type is required. Name must be between 1 to 40 characters.', - type=_validate_name) - parser.add_argument('indexing_technique', type=str, location='json', - choices=Dataset.INDEXING_TECHNIQUE_LIST, - help='Invalid indexing technique.') + parser.add_argument( + "name", + nullable=False, + required=True, + help="type is required. Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument( + "indexing_technique", + type=str, + location="json", + choices=Dataset.INDEXING_TECHNIQUE_LIST, + help="Invalid indexing technique.", + ) + parser.add_argument( + "permission", + type=str, + location="json", + choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), + help="Invalid permission.", + required=False, + nullable=False, + ) args = parser.parse_args() try: dataset = DatasetService.create_empty_dataset( tenant_id=tenant_id, - name=args['name'], - indexing_technique=args['indexing_technique'], - account=current_user + name=args["name"], + indexing_technique=args["indexing_technique"], + account=current_user, + permission=args["permission"], ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() return marshal(dataset, dataset_detail_fields), 200 + class DatasetApi(DatasetApiResource): """Resource for dataset.""" @@ -103,7 +109,7 @@ def delete(self, _, dataset_id): dataset_id (UUID): The ID of the dataset to be deleted. Returns: - dict: A dictionary with a key 'result' and a value 'success' + dict: A dictionary with a key 'result' and a value 'success' if the dataset was successfully deleted. Omitted in HTTP response. int: HTTP status code 204 indicating that the operation was successful. @@ -115,11 +121,12 @@ def delete(self, _, dataset_id): try: if DatasetService.delete_dataset(dataset_id_str, current_user): - return {'result': 'success'}, 204 + return {"result": "success"}, 204 else: raise NotFound("Dataset not found.") except services.errors.dataset.DatasetInUseError: raise DatasetInUseError() -api.add_resource(DatasetListApi, '/datasets') -api.add_resource(DatasetApi, '/datasets/') + +api.add_resource(DatasetListApi, "/datasets") +api.add_resource(DatasetApi, "/datasets/") diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index ac1ea820a646ba..fb48a6c76c05e1 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -27,47 +27,40 @@ class DocumentAddByTextApi(DatasetApiResource): """Resource for documents.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') - @cloud_edition_billing_resource_check('documents', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") + @cloud_edition_billing_resource_check("documents", "dataset") def post(self, tenant_id, dataset_id): """Create document by text.""" parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, nullable=False, location='json') - parser.add_argument('text', type=str, required=True, nullable=False, location='json') - parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json') - parser.add_argument('original_document_id', type=str, required=False, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') - parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, - location='json') - parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, - location='json') + parser.add_argument("name", type=str, required=True, nullable=False, location="json") + parser.add_argument("text", type=str, required=True, nullable=False, location="json") + parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") + parser.add_argument("original_document_id", type=str, required=False, location="json") + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) + parser.add_argument( + "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" + ) + parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") args = parser.parse_args() dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset is not exist.') + raise ValueError("Dataset is not exist.") - if not dataset.indexing_technique and not args['indexing_technique']: - raise ValueError('indexing_technique is required.') + if not dataset.indexing_technique and not args["indexing_technique"]: + raise ValueError("indexing_technique is required.") - upload_file = FileService.upload_text(args.get('text'), args.get('name')) + upload_file = FileService.upload_text(args.get("text"), args.get("name")) data_source = { - 'type': 'upload_file', - 'info_list': { - 'data_source_type': 'upload_file', - 'file_info_list': { - 'file_ids': [upload_file.id] - } - } + "type": "upload_file", + "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, } - args['data_source'] = data_source + args["data_source"] = data_source # validate args DocumentService.document_create_args_validate(args) @@ -76,60 +69,49 @@ def post(self, tenant_id, dataset_id): dataset=dataset, document_data=args, account=current_user, - dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, - created_from='api' + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) document = documents[0] - documents_and_batch_fields = { - 'document': marshal(document, document_fields), - 'batch': batch - } + documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} return documents_and_batch_fields, 200 class DocumentUpdateByTextApi(DatasetApiResource): """Resource for update documents.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by text.""" parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, nullable=True, location='json') - parser.add_argument('text', type=str, required=False, nullable=True, location='json') - parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') - parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, - location='json') + parser.add_argument("name", type=str, required=False, nullable=True, location="json") + parser.add_argument("text", type=str, required=False, nullable=True, location="json") + parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) + parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") args = parser.parse_args() dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset is not exist.') + raise ValueError("Dataset is not exist.") - if args['text']: - upload_file = FileService.upload_text(args.get('text'), args.get('name')) + if args["text"]: + upload_file = FileService.upload_text(args.get("text"), args.get("name")) data_source = { - 'type': 'upload_file', - 'info_list': { - 'data_source_type': 'upload_file', - 'file_info_list': { - 'file_ids': [upload_file.id] - } - } + "type": "upload_file", + "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, } - args['data_source'] = data_source + args["data_source"] = data_source # validate args - args['original_document_id'] = str(document_id) + args["original_document_id"] = str(document_id) DocumentService.document_create_args_validate(args) try: @@ -137,65 +119,53 @@ def post(self, tenant_id, dataset_id, document_id): dataset=dataset, document_data=args, account=current_user, - dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, - created_from='api' + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) document = documents[0] - documents_and_batch_fields = { - 'document': marshal(document, document_fields), - 'batch': batch - } + documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} return documents_and_batch_fields, 200 class DocumentAddByFileApi(DatasetApiResource): """Resource for documents.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') - @cloud_edition_billing_resource_check('documents', 'dataset') + + @cloud_edition_billing_resource_check("vector_space", "dataset") + @cloud_edition_billing_resource_check("documents", "dataset") def post(self, tenant_id, dataset_id): """Create document by upload file.""" args = {} - if 'data' in request.form: - args = json.loads(request.form['data']) - if 'doc_form' not in args: - args['doc_form'] = 'text_model' - if 'doc_language' not in args: - args['doc_language'] = 'English' + if "data" in request.form: + args = json.loads(request.form["data"]) + if "doc_form" not in args: + args["doc_form"] = "text_model" + if "doc_language" not in args: + args["doc_language"] = "English" # get dataset info dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset is not exist.') - if not dataset.indexing_technique and not args.get('indexing_technique'): - raise ValueError('indexing_technique is required.') + raise ValueError("Dataset is not exist.") + if not dataset.indexing_technique and not args.get("indexing_technique"): + raise ValueError("indexing_technique is required.") # save file info - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() upload_file = FileService.upload_file(file, current_user) - data_source = { - 'type': 'upload_file', - 'info_list': { - 'file_info_list': { - 'file_ids': [upload_file.id] - } - } - } - args['data_source'] = data_source + data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} + args["data_source"] = data_source # validate args DocumentService.document_create_args_validate(args) @@ -204,63 +174,49 @@ def post(self, tenant_id, dataset_id): dataset=dataset, document_data=args, account=dataset.created_by_account, - dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, - created_from='api' + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) document = documents[0] - documents_and_batch_fields = { - 'document': marshal(document, document_fields), - 'batch': batch - } + documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} return documents_and_batch_fields, 200 class DocumentUpdateByFileApi(DatasetApiResource): """Resource for update documents.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by upload file.""" args = {} - if 'data' in request.form: - args = json.loads(request.form['data']) - if 'doc_form' not in args: - args['doc_form'] = 'text_model' - if 'doc_language' not in args: - args['doc_language'] = 'English' + if "data" in request.form: + args = json.loads(request.form["data"]) + if "doc_form" not in args: + args["doc_form"] = "text_model" + if "doc_language" not in args: + args["doc_language"] = "English" # get dataset info dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset is not exist.') - if 'file' in request.files: + raise ValueError("Dataset is not exist.") + if "file" in request.files: # save file info - file = request.files['file'] - + file = request.files["file"] if len(request.files) > 1: raise TooManyFilesError() upload_file = FileService.upload_file(file, current_user) - data_source = { - 'type': 'upload_file', - 'info_list': { - 'file_info_list': { - 'file_ids': [upload_file.id] - } - } - } - args['data_source'] = data_source + data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} + args["data_source"] = data_source # validate args - args['original_document_id'] = str(document_id) + args["original_document_id"] = str(document_id) DocumentService.document_create_args_validate(args) try: @@ -268,16 +224,13 @@ def post(self, tenant_id, dataset_id, document_id): dataset=dataset, document_data=args, account=dataset.created_by_account, - dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, - created_from='api' + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) document = documents[0] - documents_and_batch_fields = { - 'document': marshal(document, document_fields), - 'batch': batch - } + documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} return documents_and_batch_fields, 200 @@ -289,13 +242,10 @@ def delete(self, tenant_id, dataset_id, document_id): tenant_id = str(tenant_id) # get dataset info - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset is not exist.') + raise ValueError("Dataset is not exist.") document = DocumentService.get_document(dataset.id, document_id) @@ -311,44 +261,39 @@ def delete(self, tenant_id, dataset_id, document_id): # delete document DocumentService.delete_document(document) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Cannot delete document during indexing.') + raise DocumentIndexingError("Cannot delete document during indexing.") - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class DocumentListApi(DatasetApiResource): def get(self, tenant_id, dataset_id): dataset_id = str(dataset_id) tenant_id = str(tenant_id) - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - search = request.args.get('keyword', default=None, type=str) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + search = request.args.get("keyword", default=None, type=str) + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") - query = Document.query.filter_by( - dataset_id=str(dataset_id), tenant_id=tenant_id) + query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id) if search: - search = f'%{search}%' + search = f"%{search}%" query = query.filter(Document.name.like(search)) query = query.order_by(desc(Document.created_at)) - paginated_documents = query.paginate( - page=page, per_page=limit, max_per_page=100, error_out=False) + paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items response = { - 'data': marshal(documents, document_fields), - 'has_more': len(documents) == limit, - 'limit': limit, - 'total': paginated_documents.total, - 'page': page + "data": marshal(documents, document_fields), + "has_more": len(documents) == limit, + "limit": limit, + "total": paginated_documents.total, + "page": page, } return response @@ -360,38 +305,36 @@ def get(self, tenant_id, dataset_id, batch): batch = str(batch) tenant_id = str(tenant_id) # get dataset - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # get documents documents = DocumentService.get_batch_documents(dataset_id, batch) if not documents: - raise NotFound('Documents not found.') + raise NotFound("Documents not found.") documents_status = [] for document in documents: - completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() - total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments if document.is_paused: - document.indexing_status = 'paused' + document.indexing_status = "paused" documents_status.append(marshal(document, document_status_fields)) - data = { - 'data': documents_status - } + data = {"data": documents_status} return data -api.add_resource(DocumentAddByTextApi, '/datasets//document/create_by_text') -api.add_resource(DocumentAddByFileApi, '/datasets//document/create_by_file') -api.add_resource(DocumentUpdateByTextApi, '/datasets//documents//update_by_text') -api.add_resource(DocumentUpdateByFileApi, '/datasets//documents//update_by_file') -api.add_resource(DocumentDeleteApi, '/datasets//documents/') -api.add_resource(DocumentListApi, '/datasets//documents') -api.add_resource(DocumentIndexingStatusApi, '/datasets//documents//indexing-status') +api.add_resource(DocumentAddByTextApi, "/datasets//document/create_by_text") +api.add_resource(DocumentAddByFileApi, "/datasets//document/create_by_file") +api.add_resource(DocumentUpdateByTextApi, "/datasets//documents//update_by_text") +api.add_resource(DocumentUpdateByFileApi, "/datasets//documents//update_by_file") +api.add_resource(DocumentDeleteApi, "/datasets//documents/") +api.add_resource(DocumentListApi, "/datasets//documents") +api.add_resource(DocumentIndexingStatusApi, "/datasets//documents//indexing-status") diff --git a/api/controllers/service_api/dataset/error.py b/api/controllers/service_api/dataset/error.py index e77693b6c9495c..5ff5e08c7245e1 100644 --- a/api/controllers/service_api/dataset/error.py +++ b/api/controllers/service_api/dataset/error.py @@ -2,78 +2,78 @@ class NoFileUploadedError(BaseHTTPException): - error_code = 'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." code = 400 class FileTooLargeError(BaseHTTPException): - error_code = 'file_too_large' + error_code = "file_too_large" description = "File size exceeded. {message}" code = 413 class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 class HighQualityDatasetOnlyError(BaseHTTPException): - error_code = 'high_quality_dataset_only' + error_code = "high_quality_dataset_only" description = "Current operation only supports 'high-quality' datasets." code = 400 class DatasetNotInitializedError(BaseHTTPException): - error_code = 'dataset_not_initialized' + error_code = "dataset_not_initialized" description = "The dataset is still being initialized or indexing. Please wait a moment." code = 400 class ArchivedDocumentImmutableError(BaseHTTPException): - error_code = 'archived_document_immutable' + error_code = "archived_document_immutable" description = "The archived document is not editable." code = 403 class DatasetNameDuplicateError(BaseHTTPException): - error_code = 'dataset_name_duplicate' + error_code = "dataset_name_duplicate" description = "The dataset name already exists. Please modify your dataset name." code = 409 class InvalidActionError(BaseHTTPException): - error_code = 'invalid_action' + error_code = "invalid_action" description = "Invalid action." code = 400 class DocumentAlreadyFinishedError(BaseHTTPException): - error_code = 'document_already_finished' + error_code = "document_already_finished" description = "The document has been processed. Please refresh the page or go to the document details." code = 400 class DocumentIndexingError(BaseHTTPException): - error_code = 'document_indexing' + error_code = "document_indexing" description = "The document is being processed and cannot be edited." code = 400 class InvalidMetadataError(BaseHTTPException): - error_code = 'invalid_metadata' + error_code = "invalid_metadata" description = "The metadata content is incorrect. Please check and verify." code = 400 class DatasetInUseError(BaseHTTPException): - error_code = 'dataset_in_use' + error_code = "dataset_in_use" description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it." code = 409 diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 0849eb72ba8f74..5e10f3b48c3f9f 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -21,112 +21,102 @@ class SegmentApi(DatasetApiResource): """Resource for segments.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') - @cloud_edition_billing_knowledge_limit_check('add_segment', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") + @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") def post(self, tenant_id, dataset_id, document_id): """Create single segment.""" # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check document document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") # check embedding model setting - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": try: model_manager = ModelManager() model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # validate args parser = reqparse.RequestParser() - parser.add_argument('segments', type=list, required=False, nullable=True, location='json') + parser.add_argument("segments", type=list, required=False, nullable=True, location="json") args = parser.parse_args() - for args_item in args['segments']: - SegmentService.segment_create_args_validate(args_item, document) - segments = SegmentService.multi_create_segment(args['segments'], document, dataset) - return { - 'data': marshal(segments, segment_fields), - 'doc_form': document.doc_form - }, 200 + if args["segments"] is not None: + for args_item in args["segments"]: + SegmentService.segment_create_args_validate(args_item, document) + segments = SegmentService.multi_create_segment(args["segments"], document, dataset) + return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200 + else: + return {"error": "Segemtns is required"}, 400 def get(self, tenant_id, dataset_id, document_id): """Create single segment.""" # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check document document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") # check embedding model setting - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": try: model_manager = ModelManager() model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) parser = reqparse.RequestParser() - parser.add_argument('status', type=str, - action='append', default=[], location='args') - parser.add_argument('keyword', type=str, default=None, location='args') + parser.add_argument("status", type=str, action="append", default=[], location="args") + parser.add_argument("keyword", type=str, default=None, location="args") args = parser.parse_args() - status_list = args['status'] - keyword = args['keyword'] + status_list = args["status"] + keyword = args["keyword"] query = DocumentSegment.query.filter( - DocumentSegment.document_id == str(document_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id ) if status_list: query = query.filter(DocumentSegment.status.in_(status_list)) if keyword: - query = query.where(DocumentSegment.content.ilike(f'%{keyword}%')) + query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) total = query.count() segments = query.order_by(DocumentSegment.position).all() - return { - 'data': marshal(segments, segment_fields), - 'doc_form': document.doc_form, - 'total': total - }, 200 + return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form, "total": total}, 200 class DatasetSegmentApi(DatasetApiResource): @@ -134,48 +124,41 @@ def delete(self, tenant_id, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") # check segment segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") SegmentService.delete_segment(segment, document, dataset) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 - @cloud_edition_billing_resource_check('vector_space', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") def post(self, tenant_id, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') - if dataset.indexing_technique == 'high_quality': + raise NotFound("Document not found.") + if dataset.indexing_technique == "high_quality": # check embedding model setting try: model_manager = ModelManager() @@ -183,35 +166,34 @@ def post(self, tenant_id, dataset_id, document_id, segment_id): tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # check segment segment_id = str(segment_id) segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") # validate args parser = reqparse.RequestParser() - parser.add_argument('segment', type=dict, required=False, nullable=True, location='json') + parser.add_argument("segment", type=dict, required=False, nullable=True, location="json") args = parser.parse_args() - SegmentService.segment_create_args_validate(args['segment'], document) - segment = SegmentService.update_segment(args['segment'], segment, document, dataset) - return { - 'data': marshal(segment, segment_fields), - 'doc_form': document.doc_form - }, 200 + SegmentService.segment_create_args_validate(args["segment"], document) + segment = SegmentService.update_segment(args["segment"], segment, document, dataset) + return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 -api.add_resource(SegmentApi, '/datasets//documents//segments') -api.add_resource(DatasetSegmentApi, '/datasets//documents//segments/') +api.add_resource(SegmentApi, "/datasets//documents//segments") +api.add_resource( + DatasetSegmentApi, "/datasets//documents//segments/" +) diff --git a/api/controllers/service_api/index.py b/api/controllers/service_api/index.py index c910063ebd83d1..d24c4597e210fb 100644 --- a/api/controllers/service_api/index.py +++ b/api/controllers/service_api/index.py @@ -13,4 +13,4 @@ def get(self): } -api.add_resource(IndexApi, '/') +api.add_resource(IndexApi, "/") diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 819512edf059bb..b935b23ed645c6 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -21,9 +21,10 @@ class WhereisUserArg(Enum): """ Enum for whereis_user_arg. """ - QUERY = 'query' - JSON = 'json' - FORM = 'form' + + QUERY = "query" + JSON = "json" + FORM = "form" class FetchUserArg(BaseModel): @@ -35,13 +36,13 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio def decorator(view_func): @wraps(view_func) def decorated_view(*args, **kwargs): - api_token = validate_and_get_api_token('app') + api_token = validate_and_get_api_token("app") app_model = db.session.query(App).filter(App.id == api_token.app_id).first() if not app_model: raise Forbidden("The app no longer exists.") - if app_model.status != 'normal': + if app_model.status != "normal": raise Forbidden("The app's status is abnormal.") if not app_model.enable_api: @@ -51,15 +52,15 @@ def decorated_view(*args, **kwargs): if tenant.status == TenantStatus.ARCHIVE: raise Forbidden("The workspace's status is archived.") - kwargs['app_model'] = app_model + kwargs["app_model"] = app_model if fetch_user_arg: if fetch_user_arg.fetch_from == WhereisUserArg.QUERY: - user_id = request.args.get('user') + user_id = request.args.get("user") elif fetch_user_arg.fetch_from == WhereisUserArg.JSON: - user_id = request.get_json().get('user') + user_id = request.get_json().get("user") elif fetch_user_arg.fetch_from == WhereisUserArg.FORM: - user_id = request.form.get('user') + user_id = request.form.get("user") else: # use default-user user_id = None @@ -70,9 +71,10 @@ def decorated_view(*args, **kwargs): if user_id: user_id = str(user_id) - kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id) + kwargs["end_user"] = create_or_update_end_user_for_user_id(app_model, user_id) return view_func(*args, **kwargs) + return decorated_view if view is None: @@ -81,9 +83,7 @@ def decorated_view(*args, **kwargs): return decorator(view) -def cloud_edition_billing_resource_check(resource: str, - api_token_type: str, - error_msg: str = "You have reached the limit of your subscription."): +def cloud_edition_billing_resource_check(resource: str, api_token_type: str): def interceptor(view): def decorated(*args, **kwargs): api_token = validate_and_get_api_token(api_token_type) @@ -95,34 +95,36 @@ def decorated(*args, **kwargs): vector_space = features.vector_space documents_upload_quota = features.documents_upload_quota - if resource == 'members' and 0 < members.limit <= members.size: - raise Forbidden(error_msg) - elif resource == 'apps' and 0 < apps.limit <= apps.size: - raise Forbidden(error_msg) - elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size: - raise Forbidden(error_msg) - elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size: - raise Forbidden(error_msg) + if resource == "members" and 0 < members.limit <= members.size: + raise Forbidden("The number of members has reached the limit of your subscription.") + elif resource == "apps" and 0 < apps.limit <= apps.size: + raise Forbidden("The number of apps has reached the limit of your subscription.") + elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size: + raise Forbidden("The capacity of the vector space has reached the limit of your subscription.") + elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size: + raise Forbidden("The number of documents has reached the limit of your subscription.") else: return view(*args, **kwargs) return view(*args, **kwargs) + return decorated + return interceptor -def cloud_edition_billing_knowledge_limit_check(resource: str, - api_token_type: str, - error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."): +def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str): def interceptor(view): @wraps(view) def decorated(*args, **kwargs): api_token = validate_and_get_api_token(api_token_type) features = FeatureService.get_features(api_token.tenant_id) if features.billing.enabled: - if resource == 'add_segment': - if features.billing.subscription.plan == 'sandbox': - raise Forbidden(error_msg) + if resource == "add_segment": + if features.billing.subscription.plan == "sandbox": + raise Forbidden( + "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan." + ) else: return view(*args, **kwargs) @@ -132,17 +134,20 @@ def decorated(*args, **kwargs): return interceptor + def validate_dataset_token(view=None): def decorator(view): @wraps(view) def decorated(*args, **kwargs): - api_token = validate_and_get_api_token('dataset') - tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \ - .filter(Tenant.id == api_token.tenant_id) \ - .filter(TenantAccountJoin.tenant_id == Tenant.id) \ - .filter(TenantAccountJoin.role.in_(['owner'])) \ - .filter(Tenant.status == TenantStatus.NORMAL) \ - .one_or_none() # TODO: only owner information is required, so only one is returned. + api_token = validate_and_get_api_token("dataset") + tenant_account_join = ( + db.session.query(Tenant, TenantAccountJoin) + .filter(Tenant.id == api_token.tenant_id) + .filter(TenantAccountJoin.tenant_id == Tenant.id) + .filter(TenantAccountJoin.role.in_(["owner"])) + .filter(Tenant.status == TenantStatus.NORMAL) + .one_or_none() + ) # TODO: only owner information is required, so only one is returned. if tenant_account_join: tenant, ta = tenant_account_join account = Account.query.filter_by(id=ta.account_id).first() @@ -156,6 +161,7 @@ def decorated(*args, **kwargs): else: raise Unauthorized("Tenant does not exist.") return view(api_token.tenant_id, *args, **kwargs) + return decorated if view: @@ -170,20 +176,24 @@ def validate_and_get_api_token(scope=None): """ Validate and get API token. """ - auth_header = request.headers.get('Authorization') - if auth_header is None or ' ' not in auth_header: + auth_header = request.headers.get("Authorization") + if auth_header is None or " " not in auth_header: raise Unauthorized("Authorization header must be provided and start with 'Bearer'") auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != 'bearer': + if auth_scheme != "bearer": raise Unauthorized("Authorization scheme must be 'Bearer'") - api_token = db.session.query(ApiToken).filter( - ApiToken.token == auth_token, - ApiToken.type == scope, - ).first() + api_token = ( + db.session.query(ApiToken) + .filter( + ApiToken.token == auth_token, + ApiToken.type == scope, + ) + .first() + ) if not api_token: raise Unauthorized("Access token is invalid") @@ -199,23 +209,26 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] Create or update session terminal based on user ID. """ if not user_id: - user_id = 'DEFAULT-USER' + user_id = "DEFAULT-USER" - end_user = db.session.query(EndUser) \ + end_user = ( + db.session.query(EndUser) .filter( - EndUser.tenant_id == app_model.tenant_id, - EndUser.app_id == app_model.id, - EndUser.session_id == user_id, - EndUser.type == 'service_api' - ).first() + EndUser.tenant_id == app_model.tenant_id, + EndUser.app_id == app_model.id, + EndUser.session_id == user_id, + EndUser.type == "service_api", + ) + .first() + ) if end_user is None: end_user = EndUser( tenant_id=app_model.tenant_id, app_id=app_model.id, - type='service_api', - is_anonymous=True if user_id == 'DEFAULT-USER' else False, - session_id=user_id + type="service_api", + is_anonymous=True if user_id == "DEFAULT-USER" else False, + session_id=user_id, ) db.session.add(end_user) db.session.commit() diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index aa19bdc0349fdc..630b9468a71315 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -2,7 +2,7 @@ from libs.external_api import ExternalApi -bp = Blueprint('web', __name__, url_prefix='/api') +bp = Blueprint("web", __name__, url_prefix="/api") api = ExternalApi(bp) diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index f4db82552c983a..aabca93338f17c 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -10,33 +10,32 @@ class AppParameterApi(WebApiResource): """Resource for app variables.""" + variable_fields = { - 'key': fields.String, - 'name': fields.String, - 'description': fields.String, - 'type': fields.String, - 'default': fields.String, - 'max_length': fields.Integer, - 'options': fields.List(fields.String) + "key": fields.String, + "name": fields.String, + "description": fields.String, + "type": fields.String, + "default": fields.String, + "max_length": fields.Integer, + "options": fields.List(fields.String), } - system_parameters_fields = { - 'image_file_size_limit': fields.String - } + system_parameters_fields = {"image_file_size_limit": fields.String} parameters_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw, - 'suggested_questions_after_answer': fields.Raw, - 'speech_to_text': fields.Raw, - 'text_to_speech': fields.Raw, - 'retriever_resource': fields.Raw, - 'annotation_reply': fields.Raw, - 'more_like_this': fields.Raw, - 'user_input_form': fields.Raw, - 'sensitive_word_avoidance': fields.Raw, - 'file_upload': fields.Raw, - 'system_parameters': fields.Nested(system_parameters_fields) + "opening_statement": fields.String, + "suggested_questions": fields.Raw, + "suggested_questions_after_answer": fields.Raw, + "speech_to_text": fields.Raw, + "text_to_speech": fields.Raw, + "retriever_resource": fields.Raw, + "annotation_reply": fields.Raw, + "more_like_this": fields.Raw, + "user_input_form": fields.Raw, + "sensitive_word_avoidance": fields.Raw, + "file_upload": fields.Raw, + "system_parameters": fields.Nested(system_parameters_fields), } @marshal_with(parameters_fields) @@ -53,30 +52,35 @@ def get(self, app_model: App, end_user): app_model_config = app_model.app_model_config features_dict = app_model_config.to_dict() - user_input_form = features_dict.get('user_input_form', []) + user_input_form = features_dict.get("user_input_form", []) return { - 'opening_statement': features_dict.get('opening_statement'), - 'suggested_questions': features_dict.get('suggested_questions', []), - 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', - {"enabled": False}), - 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), - 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), - 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), - 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), - 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), - 'user_input_form': user_input_form, - 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', - {"enabled": False, "type": "", "configs": []}), - 'file_upload': features_dict.get('file_upload', {"image": { - "enabled": False, - "number_limits": 3, - "detail": "high", - "transfer_methods": ["remote_url", "local_file"] - }}), - 'system_parameters': { - 'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT - } + "opening_statement": features_dict.get("opening_statement"), + "suggested_questions": features_dict.get("suggested_questions", []), + "suggested_questions_after_answer": features_dict.get( + "suggested_questions_after_answer", {"enabled": False} + ), + "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), + "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), + "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), + "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}), + "more_like_this": features_dict.get("more_like_this", {"enabled": False}), + "user_input_form": user_input_form, + "sensitive_word_avoidance": features_dict.get( + "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} + ), + "file_upload": features_dict.get( + "file_upload", + { + "image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"], + } + }, + ), + "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT}, } @@ -86,5 +90,5 @@ def get(self, app_model: App, end_user): return AppService().get_app_meta(app_model) -api.add_resource(AppParameterApi, '/parameters') -api.add_resource(AppMeta, '/meta') \ No newline at end of file +api.add_resource(AppParameterApi, "/parameters") +api.add_resource(AppMeta, "/meta") diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 0e905f905a7a2f..d062d2893b7ab6 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -31,14 +31,10 @@ class AudioApi(WebApiResource): def post(self, app_model: App, end_user): - file = request.files['file'] + file = request.files["file"] try: - response = AudioService.transcript_asr( - app_model=app_model, - file=file, - end_user=end_user - ) + response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user) return response except services.errors.app_model_config.AppModelConfigBrokenError: @@ -70,34 +66,36 @@ def post(self, app_model: App, end_user): class TextApi(WebApiResource): def post(self, app_model: App, end_user): from flask_restful import reqparse + try: parser = reqparse.RequestParser() - parser.add_argument('message_id', type=str, required=False, location='json') - parser.add_argument('voice', type=str, location='json') - parser.add_argument('text', type=str, location='json') - parser.add_argument('streaming', type=bool, location='json') + parser.add_argument("message_id", type=str, required=False, location="json") + parser.add_argument("voice", type=str, location="json") + parser.add_argument("text", type=str, location="json") + parser.add_argument("streaming", type=bool, location="json") args = parser.parse_args() - message_id = args.get('message_id', None) - text = args.get('text', None) - if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] - and app_model.workflow - and app_model.workflow.features_dict): - text_to_speech = app_model.workflow.features_dict.get('text_to_speech') - voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') + message_id = args.get("message_id", None) + text = args.get("text", None) + if ( + app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + and app_model.workflow + and app_model.workflow.features_dict + ): + text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") else: try: - voice = args.get('voice') if args.get( - 'voice') else app_model.app_model_config.text_to_speech_dict.get('voice') + voice = ( + args.get("voice") + if args.get("voice") + else app_model.app_model_config.text_to_speech_dict.get("voice") + ) except Exception: voice = None response = AudioService.transcript_tts( - app_model=app_model, - message_id=message_id, - end_user=end_user.external_user_id, - voice=voice, - text=text + app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text ) return response @@ -127,5 +125,5 @@ def post(self, app_model: App, end_user): raise InternalServerError() -api.add_resource(AudioApi, '/audio-to-text') -api.add_resource(TextApi, '/text-to-audio') +api.add_resource(AudioApi, "/audio-to-text") +api.add_resource(TextApi, "/text-to-audio") diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 948d5fabb5328d..0837eedfb068a9 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -15,6 +15,7 @@ ProviderNotInitializeError, ProviderQuotaExceededError, ) +from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.web.wraps import WebApiResource from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom @@ -24,34 +25,30 @@ from libs.helper import uuid_value from models.model import AppMode from services.app_generate_service import AppGenerateService +from services.errors.llm import InvokeRateLimitError # define completion api for user class CompletionApi(WebApiResource): - def post(self, app_model, end_user): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, location='json', default='') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, location="json", default="") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' - args['auto_generate_name'] = False + streaming = args["response_mode"] == "streaming" + args["auto_generate_name"] = False try: response = AppGenerateService.generate( - app_model=app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.WEB_APP, - streaming=streaming + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=streaming ) return helper.compact_generate_response(response) @@ -79,12 +76,12 @@ def post(self, app_model, end_user): class CompletionStopApi(WebApiResource): def post(self, app_model, end_user, task_id): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ChatApi(WebApiResource): @@ -94,25 +91,21 @@ def post(self, app_model, end_user): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, required=True, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') - parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, required=True, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' - args['auto_generate_name'] = False + streaming = args["response_mode"] == "streaming" + args["auto_generate_name"] = False try: response = AppGenerateService.generate( - app_model=app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.WEB_APP, - streaming=streaming + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=streaming ) return helper.compact_generate_response(response) @@ -129,6 +122,8 @@ def post(self, app_model, end_user): raise ProviderQuotaExceededError() except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() + except InvokeRateLimitError as ex: + raise InvokeRateLimitHttpError(ex.description) except InvokeError as e: raise CompletionRequestError(e.description) except ValueError as e: @@ -146,10 +141,10 @@ def post(self, app_model, end_user, task_id): AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(CompletionApi, '/completion-messages') -api.add_resource(CompletionStopApi, '/completion-messages//stop') -api.add_resource(ChatApi, '/chat-messages') -api.add_resource(ChatStopApi, '/chat-messages//stop') +api.add_resource(CompletionApi, "/completion-messages") +api.add_resource(CompletionStopApi, "/completion-messages//stop") +api.add_resource(ChatApi, "/chat-messages") +api.add_resource(ChatStopApi, "/chat-messages//stop") diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index b83ea3a52596b1..6bbfa94c2756a5 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -15,7 +15,6 @@ class ConversationListApi(WebApiResource): - @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) @@ -23,23 +22,32 @@ def get(self, app_model, end_user): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args") + parser.add_argument( + "sort_by", + type=str, + choices=["created_at", "-created_at", "updated_at", "-updated_at"], + required=False, + default="-updated_at", + location="args", + ) args = parser.parse_args() pinned = None - if 'pinned' in args and args['pinned'] is not None: - pinned = True if args['pinned'] == 'true' else False + if "pinned" in args and args["pinned"] is not None: + pinned = True if args["pinned"] == "true" else False try: return WebConversationService.pagination_by_last_id( app_model=app_model, user=end_user, - last_id=args['last_id'], - limit=args['limit'], + last_id=args["last_id"], + limit=args["limit"], invoke_from=InvokeFrom.WEB_APP, pinned=pinned, + sort_by=args["sort_by"], ) except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @@ -62,7 +70,6 @@ def delete(self, app_model, end_user, c_id): class ConversationRenameApi(WebApiResource): - @marshal_with(simple_conversation_fields) def post(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) @@ -72,24 +79,17 @@ def post(self, app_model, end_user, c_id): conversation_id = str(c_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, location='json') - parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') + parser.add_argument("name", type=str, required=False, location="json") + parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") args = parser.parse_args() try: - return ConversationService.rename( - app_model, - conversation_id, - end_user, - args['name'], - args['auto_generate'] - ) + return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") class ConversationPinApi(WebApiResource): - def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: @@ -117,8 +117,8 @@ def patch(self, app_model, end_user, c_id): return {"result": "success"} -api.add_resource(ConversationRenameApi, '/conversations//name', endpoint='web_conversation_name') -api.add_resource(ConversationListApi, '/conversations') -api.add_resource(ConversationApi, '/conversations/') -api.add_resource(ConversationPinApi, '/conversations//pin') -api.add_resource(ConversationUnPinApi, '/conversations//unpin') +api.add_resource(ConversationRenameApi, "/conversations//name", endpoint="web_conversation_name") +api.add_resource(ConversationListApi, "/conversations") +api.add_resource(ConversationApi, "/conversations/") +api.add_resource(ConversationPinApi, "/conversations//pin") +api.add_resource(ConversationUnPinApi, "/conversations//unpin") diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py index bc87f510512a83..9fe5d08d54a12d 100644 --- a/api/controllers/web/error.py +++ b/api/controllers/web/error.py @@ -2,122 +2,134 @@ class AppUnavailableError(BaseHTTPException): - error_code = 'app_unavailable' + error_code = "app_unavailable" description = "App unavailable, please check your app configurations." code = 400 class NotCompletionAppError(BaseHTTPException): - error_code = 'not_completion_app' + error_code = "not_completion_app" description = "Please check if your Completion app mode matches the right API route." code = 400 class NotChatAppError(BaseHTTPException): - error_code = 'not_chat_app' + error_code = "not_chat_app" description = "Please check if your app mode matches the right API route." code = 400 class NotWorkflowAppError(BaseHTTPException): - error_code = 'not_workflow_app' + error_code = "not_workflow_app" description = "Please check if your Workflow app mode matches the right API route." code = 400 class ConversationCompletedError(BaseHTTPException): - error_code = 'conversation_completed' + error_code = "conversation_completed" description = "The conversation has ended. Please start a new conversation." code = 400 class ProviderNotInitializeError(BaseHTTPException): - error_code = 'provider_not_initialize' - description = "No valid model provider credentials found. " \ - "Please go to Settings -> Model Provider to complete your provider credentials." + error_code = "provider_not_initialize" + description = ( + "No valid model provider credentials found. " + "Please go to Settings -> Model Provider to complete your provider credentials." + ) code = 400 class ProviderQuotaExceededError(BaseHTTPException): - error_code = 'provider_quota_exceeded' - description = "Your quota for Dify Hosted OpenAI has been exhausted. " \ - "Please go to Settings -> Model Provider to complete your own provider credentials." + error_code = "provider_quota_exceeded" + description = ( + "Your quota for Dify Hosted OpenAI has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials." + ) code = 400 class ProviderModelCurrentlyNotSupportError(BaseHTTPException): - error_code = 'model_currently_not_support' + error_code = "model_currently_not_support" description = "Dify Hosted OpenAI trial currently not support the GPT-4 model." code = 400 class CompletionRequestError(BaseHTTPException): - error_code = 'completion_request_error' + error_code = "completion_request_error" description = "Completion request failed." code = 400 class AppMoreLikeThisDisabledError(BaseHTTPException): - error_code = 'app_more_like_this_disabled' + error_code = "app_more_like_this_disabled" description = "The 'More like this' feature is disabled. Please refresh your page." code = 403 class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException): - error_code = 'app_suggested_questions_after_answer_disabled' + error_code = "app_suggested_questions_after_answer_disabled" description = "The 'Suggested Questions After Answer' feature is disabled. Please refresh your page." code = 403 class NoAudioUploadedError(BaseHTTPException): - error_code = 'no_audio_uploaded' + error_code = "no_audio_uploaded" description = "Please upload your audio." code = 400 class AudioTooLargeError(BaseHTTPException): - error_code = 'audio_too_large' + error_code = "audio_too_large" description = "Audio size exceeded. {message}" code = 413 class UnsupportedAudioTypeError(BaseHTTPException): - error_code = 'unsupported_audio_type' + error_code = "unsupported_audio_type" description = "Audio type not allowed." code = 415 class ProviderNotSupportSpeechToTextError(BaseHTTPException): - error_code = 'provider_not_support_speech_to_text' + error_code = "provider_not_support_speech_to_text" description = "Provider not support speech to text." code = 400 class NoFileUploadedError(BaseHTTPException): - error_code = 'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." code = 400 class FileTooLargeError(BaseHTTPException): - error_code = 'file_too_large' + error_code = "file_too_large" description = "File size exceeded. {message}" code = 413 class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 class WebSSOAuthRequiredError(BaseHTTPException): - error_code = 'web_sso_auth_required' + error_code = "web_sso_auth_required" description = "Web SSO authentication required." code = 401 + + +class InvokeRateLimitError(BaseHTTPException): + """Raised when the Invoke returns rate limit error.""" + + error_code = "rate_limit_error" + description = "Rate Limit Error" + code = 429 diff --git a/api/controllers/web/feature.py b/api/controllers/web/feature.py index 69b38faaf655e8..0563ed22382e6b 100644 --- a/api/controllers/web/feature.py +++ b/api/controllers/web/feature.py @@ -9,4 +9,4 @@ def get(self): return FeatureService.get_system_features().model_dump() -api.add_resource(SystemFeatureApi, '/system-features') +api.add_resource(SystemFeatureApi, "/system-features") diff --git a/api/controllers/web/file.py b/api/controllers/web/file.py index ca83f6037a3cfa..253b1d511c95c8 100644 --- a/api/controllers/web/file.py +++ b/api/controllers/web/file.py @@ -10,14 +10,13 @@ class FileApi(WebApiResource): - @marshal_with(file_fields) def post(self, app_model, end_user): # get file from request - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: @@ -32,4 +31,4 @@ def post(self, app_model, end_user): return upload_file, 201 -api.add_resource(FileApi, '/files/upload') +api.add_resource(FileApi, "/files/upload") diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 865d2270ad8d10..56aaaa930a4a87 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -33,48 +33,46 @@ class MessageListApi(WebApiResource): - feedback_fields = { - 'rating': fields.String - } + feedback_fields = {"rating": fields.String} retriever_resource_fields = { - 'id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'dataset_id': fields.String, - 'dataset_name': fields.String, - 'document_id': fields.String, - 'document_name': fields.String, - 'data_source_type': fields.String, - 'segment_id': fields.String, - 'score': fields.Float, - 'hit_count': fields.Integer, - 'word_count': fields.Integer, - 'segment_position': fields.Integer, - 'index_node_hash': fields.String, - 'content': fields.String, - 'created_at': TimestampField + "id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "dataset_id": fields.String, + "dataset_name": fields.String, + "document_id": fields.String, + "document_name": fields.String, + "data_source_type": fields.String, + "segment_id": fields.String, + "score": fields.Float, + "hit_count": fields.Integer, + "word_count": fields.Integer, + "segment_position": fields.Integer, + "index_node_hash": fields.String, + "content": fields.String, + "created_at": TimestampField, } message_fields = { - 'id': fields.String, - 'conversation_id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String(attribute='re_sign_file_url_answer'), - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), - 'created_at': TimestampField, - 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)), - 'status': fields.String, - 'error': fields.String, + "id": fields.String, + "conversation_id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "answer": fields.String(attribute="re_sign_file_url_answer"), + "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), + "created_at": TimestampField, + "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "status": fields.String, + "error": fields.String, } message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_fields)), } @marshal_with(message_infinite_scroll_pagination_fields) @@ -84,14 +82,15 @@ def get(self, app_model, end_user): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') - parser.add_argument('first_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") + parser.add_argument("first_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() try: - return MessageService.pagination_by_first_id(app_model, end_user, - args['conversation_id'], args['first_id'], args['limit']) + return MessageService.pagination_by_first_id( + app_model, end_user, args["conversation_id"], args["first_id"], args["limit"] + ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.message.FirstMessageNotExistsError: @@ -103,29 +102,31 @@ def post(self, app_model, end_user, message_id): message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') + parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args['rating']) + MessageService.create_feedback(app_model, message_id, end_user, args["rating"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class MessageMoreLikeThisApi(WebApiResource): def get(self, app_model, end_user, message_id): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args') + parser.add_argument( + "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" + ) args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' + streaming = args["response_mode"] == "streaming" try: response = AppGenerateService.generate_more_like_this( @@ -133,7 +134,7 @@ def get(self, app_model, end_user, message_id): user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP, - streaming=streaming + streaming=streaming, ) return helper.compact_generate_response(response) @@ -166,10 +167,7 @@ def get(self, app_model, end_user, message_id): try: questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, - user=end_user, - message_id=message_id, - invoke_from=InvokeFrom.WEB_APP + app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP ) except MessageNotExistsError: raise NotFound("Message not found") @@ -189,10 +187,10 @@ def get(self, app_model, end_user, message_id): logging.exception("internal server error.") raise InternalServerError() - return {'data': questions} + return {"data": questions} -api.add_resource(MessageListApi, '/messages') -api.add_resource(MessageFeedbackApi, '/messages//feedbacks') -api.add_resource(MessageMoreLikeThisApi, '/messages//more-like-this') -api.add_resource(MessageSuggestedQuestionApi, '/messages//suggested-questions') +api.add_resource(MessageListApi, "/messages") +api.add_resource(MessageFeedbackApi, "/messages//feedbacks") +api.add_resource(MessageMoreLikeThisApi, "/messages//more-like-this") +api.add_resource(MessageSuggestedQuestionApi, "/messages//suggested-questions") diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index ccc8683a799dc5..a01ffd861230a5 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -9,37 +9,37 @@ from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site +from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService class PassportResource(Resource): """Base resource for passport.""" - def get(self): + def get(self): system_features = FeatureService.get_system_features() - if system_features.sso_enforced_for_web: - raise WebSSOAuthRequiredError() - - app_code = request.headers.get('X-App-Code') + app_code = request.headers.get("X-App-Code") if app_code is None: - raise Unauthorized('X-App-Code header is missing.') + raise Unauthorized("X-App-Code header is missing.") + + if system_features.sso_enforced_for_web: + app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) + if app_web_sso_enabled: + raise WebSSOAuthRequiredError() # get site from db and check if it is normal - site = db.session.query(Site).filter( - Site.code == app_code, - Site.status == 'normal' - ).first() + site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() if not site: raise NotFound() # get app from db and check if it is normal and enable_site app_model = db.session.query(App).filter(App.id == site.app_id).first() - if not app_model or app_model.status != 'normal' or not app_model.enable_site: + if not app_model or app_model.status != "normal" or not app_model.enable_site: raise NotFound() end_user = EndUser( tenant_id=app_model.tenant_id, app_id=app_model.id, - type='browser', + type="browser", is_anonymous=True, session_id=generate_session_id(), ) @@ -49,20 +49,20 @@ def get(self): payload = { "iss": site.app_id, - 'sub': 'Web API Passport', - 'app_id': site.app_id, - 'app_code': app_code, - 'end_user_id': end_user.id, + "sub": "Web API Passport", + "app_id": site.app_id, + "app_code": app_code, + "end_user_id": end_user.id, } tk = PassportService().issue(payload) return { - 'access_token': tk, + "access_token": tk, } -api.add_resource(PassportResource, '/passport') +api.add_resource(PassportResource, "/passport") def generate_session_id(): @@ -71,7 +71,6 @@ def generate_session_id(): """ while True: session_id = str(uuid.uuid4()) - existing_count = db.session.query(EndUser) \ - .filter(EndUser.session_id == session_id).count() + existing_count = db.session.query(EndUser).filter(EndUser.session_id == session_id).count() if existing_count == 0: return session_id diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index e17869ffdbf8e8..8253f5fc57d676 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -10,67 +10,65 @@ from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService -feedback_fields = { - 'rating': fields.String -} +feedback_fields = {"rating": fields.String} message_fields = { - 'id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String, - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'created_at': TimestampField + "id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "answer": fields.String, + "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "created_at": TimestampField, } class SavedMessageListApi(WebApiResource): saved_message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_fields)), } @marshal_with(saved_message_infinite_scroll_pagination_fields) def get(self, app_model, end_user): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - return SavedMessageService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit']) + return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"]) def post(self, app_model, end_user): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('message_id', type=uuid_value, required=True, location='json') + parser.add_argument("message_id", type=uuid_value, required=True, location="json") args = parser.parse_args() try: - SavedMessageService.save(app_model, end_user, args['message_id']) + SavedMessageService.save(app_model, end_user, args["message_id"]) except MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class SavedMessageApi(WebApiResource): def delete(self, app_model, end_user, message_id): message_id = str(message_id) - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() SavedMessageService.delete(app_model, end_user, message_id) - return {'result': 'success'} + return {"result": "success"} -api.add_resource(SavedMessageListApi, '/saved-messages') -api.add_resource(SavedMessageApi, '/saved-messages/') +api.add_resource(SavedMessageListApi, "/saved-messages") +api.add_resource(SavedMessageApi, "/saved-messages/") diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 99ec86e935e333..2b4d0e7630c6e7 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -1,4 +1,3 @@ - from flask_restful import fields, marshal_with from werkzeug.exceptions import Forbidden @@ -6,6 +5,7 @@ from controllers.web import api from controllers.web.wraps import WebApiResource from extensions.ext_database import db +from libs.helper import AppIconUrlField from models.account import TenantStatus from models.model import Site from services.feature_service import FeatureService @@ -15,39 +15,41 @@ class AppSiteApi(WebApiResource): """Resource for app sites.""" model_config_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw(attribute='suggested_questions_list'), - 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), - 'more_like_this': fields.Raw(attribute='more_like_this_dict'), - 'model': fields.Raw(attribute='model_dict'), - 'user_input_form': fields.Raw(attribute='user_input_form_list'), - 'pre_prompt': fields.String, + "opening_statement": fields.String, + "suggested_questions": fields.Raw(attribute="suggested_questions_list"), + "suggested_questions_after_answer": fields.Raw(attribute="suggested_questions_after_answer_dict"), + "more_like_this": fields.Raw(attribute="more_like_this_dict"), + "model": fields.Raw(attribute="model_dict"), + "user_input_form": fields.Raw(attribute="user_input_form_list"), + "pre_prompt": fields.String, } site_fields = { - 'title': fields.String, - 'chat_color_theme': fields.String, - 'chat_color_theme_inverted': fields.Boolean, - 'icon': fields.String, - 'icon_background': fields.String, - 'description': fields.String, - 'copyright': fields.String, - 'privacy_policy': fields.String, - 'custom_disclaimer': fields.String, - 'default_language': fields.String, - 'prompt_public': fields.Boolean, - 'show_workflow_steps': fields.Boolean, + "title": fields.String, + "chat_color_theme": fields.String, + "chat_color_theme_inverted": fields.Boolean, + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "description": fields.String, + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "default_language": fields.String, + "prompt_public": fields.Boolean, + "show_workflow_steps": fields.Boolean, } app_fields = { - 'app_id': fields.String, - 'end_user_id': fields.String, - 'enable_site': fields.Boolean, - 'site': fields.Nested(site_fields), - 'model_config': fields.Nested(model_config_fields, allow_null=True), - 'plan': fields.String, - 'can_replace_logo': fields.Boolean, - 'custom_config': fields.Raw(attribute='custom_config'), + "app_id": fields.String, + "end_user_id": fields.String, + "enable_site": fields.Boolean, + "site": fields.Nested(site_fields), + "model_config": fields.Nested(model_config_fields, allow_null=True), + "plan": fields.String, + "can_replace_logo": fields.Boolean, + "custom_config": fields.Raw(attribute="custom_config"), } @marshal_with(app_fields) @@ -67,7 +69,7 @@ def get(self, app_model, end_user): return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo) -api.add_resource(AppSiteApi, '/site') +api.add_resource(AppSiteApi, "/site") class AppSiteInfo: @@ -85,9 +87,13 @@ def __init__(self, tenant, app, site, end_user, can_replace_logo): if can_replace_logo: base_url = dify_config.FILES_URL - remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False) - replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None + remove_webapp_brand = tenant.custom_config_dict.get("remove_webapp_brand", False) + replace_webapp_logo = ( + f"{base_url}/files/workspaces/{tenant.id}/webapp-logo" + if tenant.custom_config_dict.get("replace_webapp_logo") + else None + ) self.custom_config = { - 'remove_webapp_brand': remove_webapp_brand, - 'replace_webapp_logo': replace_webapp_logo, + "remove_webapp_brand": remove_webapp_brand, + "replace_webapp_logo": replace_webapp_logo, } diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 77c468e4179334..55b0c3e2ab34c5 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -33,17 +33,13 @@ def post(self, app_model: App, end_user: EndUser): raise NotWorkflowAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') - parser.add_argument('files', type=list, required=False, location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("files", type=list, required=False, location="json") args = parser.parse_args() try: response = AppGenerateService.generate( - app_model=app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.WEB_APP, - streaming=True + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=True ) return helper.compact_generate_response(response) @@ -73,10 +69,8 @@ def post(self, app_model: App, end_user: EndUser, task_id: str): AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) - return { - "result": "success" - } + return {"result": "success"} -api.add_resource(WorkflowRunApi, '/workflows/run') -api.add_resource(WorkflowTaskStopApi, '/workflows/tasks//stop') +api.add_resource(WorkflowRunApi, "/workflows/run") +api.add_resource(WorkflowTaskStopApi, "/workflows/tasks//stop") diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index f5ab49d7e17290..93dc691d623949 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -8,6 +8,7 @@ from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site +from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService @@ -18,7 +19,9 @@ def decorated(*args, **kwargs): app_model, end_user = decode_jwt_token() return view(app_model, end_user, *args, **kwargs) + return decorated + if view: return decorator(view) return decorator @@ -26,56 +29,62 @@ def decorated(*args, **kwargs): def decode_jwt_token(): system_features = FeatureService.get_system_features() - + app_code = request.headers.get("X-App-Code") try: - auth_header = request.headers.get('Authorization') + auth_header = request.headers.get("Authorization") if auth_header is None: - raise Unauthorized('Authorization header is missing.') + raise Unauthorized("Authorization header is missing.") - if ' ' not in auth_header: - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") auth_scheme, tk = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != 'bearer': - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") decoded = PassportService().verify(tk) - app_code = decoded.get('app_code') - app_model = db.session.query(App).filter(App.id == decoded['app_id']).first() + app_code = decoded.get("app_code") + app_model = db.session.query(App).filter(App.id == decoded["app_id"]).first() site = db.session.query(Site).filter(Site.code == app_code).first() if not app_model: raise NotFound() if not app_code or not site: - raise BadRequest('Site URL is no longer valid.') + raise BadRequest("Site URL is no longer valid.") if app_model.enable_site is False: - raise BadRequest('Site is disabled.') - end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first() + raise BadRequest("Site is disabled.") + end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first() if not end_user: raise NotFound() - _validate_web_sso_token(decoded, system_features) + _validate_web_sso_token(decoded, system_features, app_code) return app_model, end_user except Unauthorized as e: if system_features.sso_enforced_for_web: - raise WebSSOAuthRequiredError() + app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) + if app_web_sso_enabled: + raise WebSSOAuthRequiredError() raise Unauthorized(e.description) -def _validate_web_sso_token(decoded, system_features): +def _validate_web_sso_token(decoded, system_features, app_code): + app_web_sso_enabled = False + # Check if SSO is enforced for web, and if the token source is not SSO, raise an error and redirect to SSO login if system_features.sso_enforced_for_web: - source = decoded.get('token_source') - if not source or source != 'sso': - raise WebSSOAuthRequiredError() + app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) + if app_web_sso_enabled: + source = decoded.get("token_source") + if not source or source != "sso": + raise WebSSOAuthRequiredError() # Check if SSO is not enforced for web, and if the token source is SSO, raise an error and redirect to normal passport login - if not system_features.sso_enforced_for_web: - source = decoded.get('token_source') - if source and source == 'sso': - raise Unauthorized('sso token expired.') + if not system_features.sso_enforced_for_web or not app_web_sso_enabled: + source = decoded.get("token_source") + if source and source == "sso": + raise Unauthorized("sso token expired.") class WebApiResource(Resource): diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 0ff9389793ac06..d8290ca608b0cb 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -64,15 +64,19 @@ def __init__(self, tenant_id: str, """ Agent runner :param tenant_id: tenant id + :param application_generate_entity: application generate entity + :param conversation: conversation :param app_config: app generate entity :param model_config: model config :param config: dataset config :param queue_manager: queue manager :param message: message :param user_id: user id - :param agent_llm_callback: agent llm callback - :param callback: callback :param memory: memory + :param prompt_messages: prompt messages + :param variables_pool: variables pool + :param db_variables: db variables + :param model_instance: model instance """ self.tenant_id = tenant_id self.application_generate_entity = application_generate_entity @@ -445,7 +449,7 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P try: tool_responses = json.loads(agent_thought.observation) except Exception as e: - tool_responses = { tool: agent_thought.observation for tool in tools } + tool_responses = dict.fromkeys(tools, agent_thought.observation) for tool in tools: # generate a uuid for tool call diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 06492bb12fbd06..89c948d2e29f5f 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -292,6 +292,8 @@ def _handle_invoke_action(self, action: AgentScratchpadUnit.Action, handle invoke action :param action: action :param tool_instances: tool instances + :param message_file_ids: message file ids + :param trace_manager: trace manager :return: observation, meta """ # action is tool call, invoke tool diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index ec17db5f06a30c..1a621d2090f759 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -93,6 +93,7 @@ def convert(cls, config: dict) -> Optional[DatasetEntity]: reranking_model=dataset_configs.get('reranking_model'), weights=dataset_configs.get('weights'), reranking_enabled=dataset_configs.get('reranking_enabled', True), + rerank_mode=dataset_configs.get('reranking_mode', 'reranking_model'), ) ) diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index 3eb006b46ec57b..15fa4d99fd7bfc 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -1,6 +1,6 @@ import re -from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity +from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType from core.external_data_tool.factory import ExternalDataToolFactory @@ -13,7 +13,7 @@ def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataV :param config: model config args """ external_data_variables = [] - variables = [] + variable_entities = [] # old external_data_tools external_data_tools = config.get('external_data_tools', []) @@ -30,50 +30,41 @@ def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataV ) # variables and external_data_tools - for variable in config.get('user_input_form', []): - typ = list(variable.keys())[0] - if typ == 'external_data_tool': - val = variable[typ] - if 'config' not in val: + for variables in config.get('user_input_form', []): + variable_type = list(variables.keys())[0] + if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL: + variable = variables[variable_type] + if 'config' not in variable: continue external_data_variables.append( ExternalDataVariableEntity( - variable=val['variable'], - type=val['type'], - config=val['config'] + variable=variable['variable'], + type=variable['type'], + config=variable['config'] ) ) - elif typ in [ - VariableEntity.Type.TEXT_INPUT.value, - VariableEntity.Type.PARAGRAPH.value, - VariableEntity.Type.NUMBER.value, + elif variable_type in [ + VariableEntityType.TEXT_INPUT, + VariableEntityType.PARAGRAPH, + VariableEntityType.NUMBER, + VariableEntityType.SELECT, ]: - variables.append( - VariableEntity( - type=VariableEntity.Type.value_of(typ), - variable=variable[typ].get('variable'), - description=variable[typ].get('description'), - label=variable[typ].get('label'), - required=variable[typ].get('required', False), - max_length=variable[typ].get('max_length'), - default=variable[typ].get('default'), - ) - ) - elif typ == VariableEntity.Type.SELECT.value: - variables.append( + variable = variables[variable_type] + variable_entities.append( VariableEntity( - type=VariableEntity.Type.SELECT, - variable=variable[typ].get('variable'), - description=variable[typ].get('description'), - label=variable[typ].get('label'), - required=variable[typ].get('required', False), - options=variable[typ].get('options'), - default=variable[typ].get('default'), + type=variable_type, + variable=variable.get('variable'), + description=variable.get('description'), + label=variable.get('label'), + required=variable.get('required', False), + max_length=variable.get('max_length'), + options=variable.get('options'), + default=variable.get('default'), ) ) - return variables, external_data_variables + return variable_entities, external_data_variables @classmethod def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: @@ -183,4 +174,4 @@ def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: d config=config ) - return config, ["external_data_tools"] \ No newline at end of file + return config, ["external_data_tools"] diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index a490ddd67089f4..bbb10d3d76429e 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -3,8 +3,9 @@ from pydantic import BaseModel +from core.file.file_obj import FileExtraConfig from core.model_runtime.entities.message_entities import PromptMessageRole -from models.model import AppMode +from models import AppMode class ModelConfigEntity(BaseModel): @@ -81,43 +82,29 @@ def value_of(cls, value: str) -> 'PromptType': advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None +class VariableEntityType(str, Enum): + TEXT_INPUT = "text-input" + SELECT = "select" + PARAGRAPH = "paragraph" + NUMBER = "number" + EXTERNAL_DATA_TOOL = "external-data-tool" + + class VariableEntity(BaseModel): """ Variable Entity. """ - class Type(Enum): - TEXT_INPUT = 'text-input' - SELECT = 'select' - PARAGRAPH = 'paragraph' - NUMBER = 'number' - - @classmethod - def value_of(cls, value: str) -> 'VariableEntity.Type': - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f'invalid variable type value {value}') variable: str label: str description: Optional[str] = None - type: Type + type: VariableEntityType required: bool = False max_length: Optional[int] = None options: Optional[list[str]] = None default: Optional[str] = None hint: Optional[str] = None - @property - def name(self) -> str: - return self.variable - class ExternalDataVariableEntity(BaseModel): """ @@ -200,11 +187,6 @@ class TracingConfigEntity(BaseModel): tracing_provider: str -class FileExtraConfig(BaseModel): - """ - File Upload Entity. - """ - image_config: Optional[dict[str, Any]] = None class AppAdditionalFeatures(BaseModel): @@ -256,4 +238,4 @@ class WorkflowUIBasedAppConfig(AppConfig): """ Workflow UI Based App Config Entity. """ - workflow_id: str \ No newline at end of file + workflow_id: str diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 86799fb1abe133..3da3c2eddb83f3 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -1,7 +1,7 @@ from collections.abc import Mapping from typing import Any, Optional -from core.app.app_config.entities import FileExtraConfig +from core.file.file_obj import FileExtraConfig class FileUploadConfigManager: diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 0141dbec58de6e..e7c9ebe09749de 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -4,10 +4,12 @@ import threading import uuid from collections.abc import Generator -from typing import Union +from typing import Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError +from sqlalchemy import select +from sqlalchemy.orm import Session import contexts from core.app.app_config.features.file_upload.manager import FileUploadConfigManager @@ -18,20 +20,45 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.app.entities.app_invoke_entities import ( + AdvancedChatAppGenerateEntity, + InvokeFrom, +) from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey from extensions.ext_database import db from models.account import Account from models.model import App, Conversation, EndUser, Message -from models.workflow import Workflow +from models.workflow import ConversationVariable, Workflow logger = logging.getLogger(__name__) class AdvancedChatAppGenerator(MessageBasedAppGenerator): + @overload + def generate( + self, app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[True] = True, + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[False] = False, + ) -> dict: ... + def generate( self, app_model: App, workflow: Workflow, @@ -39,7 +66,7 @@ def generate( args: dict, invoke_from: InvokeFrom, stream: bool = True, - ) -> Union[dict, Generator[dict, None, None]]: + ): """ Generate App response. @@ -66,8 +93,9 @@ def generate( # get conversation conversation = None - if args.get('conversation_id'): - conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + conversation_id = args.get('conversation_id') + if conversation_id: + conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user) # parse files files = args['files'] if args.get('files') else [] @@ -89,7 +117,8 @@ def generate( ) # get tracing instance - trace_manager = TraceQueueManager(app_id=app_model.id) + user_id = user.id if isinstance(user, Account) else user.session_id + trace_manager = TraceQueueManager(app_model.id, user_id) if invoke_from == InvokeFrom.DEBUGGER: # always enable retriever resource in debugger mode @@ -112,7 +141,6 @@ def generate( contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) return self._generate( - app_model=app_model, workflow=workflow, user=user, invoke_from=invoke_from, @@ -120,14 +148,13 @@ def generate( conversation=conversation, stream=stream ) - + def single_iteration_generate(self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, - stream: bool = True) \ - -> Union[dict, Generator[dict, None, None]]: + stream: bool = True): """ Generate App response. @@ -140,18 +167,19 @@ def single_iteration_generate(self, app_model: App, """ if not node_id: raise ValueError('node_id is required') - + if args.get('inputs') is None: raise ValueError('inputs is required') - + extras = { "auto_generate_conversation_name": False } # get conversation conversation = None - if args.get('conversation_id'): - conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + conversation_id = args.get('conversation_id') + if conversation_id: + conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user) # convert to app config app_config = AdvancedChatAppConfigManager.get_app_config( @@ -179,7 +207,6 @@ def single_iteration_generate(self, app_model: App, contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) return self._generate( - app_model=app_model, workflow=workflow, user=user, invoke_from=InvokeFrom.DEBUGGER, @@ -188,14 +215,13 @@ def single_iteration_generate(self, app_model: App, stream=stream ) - def _generate(self, app_model: App, + def _generate(self, *, workflow: Workflow, user: Union[Account, EndUser], invoke_from: InvokeFrom, application_generate_entity: AdvancedChatAppGenerateEntity, - conversation: Conversation = None, - stream: bool = True) \ - -> Union[dict, Generator[dict, None, None]]: + conversation: Conversation | None = None, + stream: bool = True): is_first_conversation = False if not conversation: is_first_conversation = True @@ -210,7 +236,7 @@ def _generate(self, app_model: App, # update conversation features conversation.override_model_configs = workflow.features db.session.commit() - db.session.refresh(conversation) + # db.session.refresh(conversation) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -222,15 +248,69 @@ def _generate(self, app_model: App, message_id=message.id ) + # Init conversation variables + stmt = select(ConversationVariable).where( + ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id + ) + with Session(db.engine) as session: + conversation_variables = session.scalars(stmt).all() + if not conversation_variables: + # Create conversation variables if they don't exist. + conversation_variables = [ + ConversationVariable.from_variable( + app_id=conversation.app_id, conversation_id=conversation.id, variable=variable + ) + for variable in workflow.conversation_variables + ] + session.add_all(conversation_variables) + # Convert database entities to variables. + conversation_variables = [item.to_variable() for item in conversation_variables] + + session.commit() + + # Increment dialogue count. + conversation.dialogue_count += 1 + + conversation_id = conversation.id + conversation_dialogue_count = conversation.dialogue_count + db.session.commit() + db.session.refresh(conversation) + + inputs = application_generate_entity.inputs + query = application_generate_entity.query + files = application_generate_entity.files + + user_id = None + if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() + if end_user: + user_id = end_user.session_id + else: + user_id = application_generate_entity.user_id + + # Create a variable pool. + system_inputs = { + SystemVariableKey.QUERY: query, + SystemVariableKey.FILES: files, + SystemVariableKey.CONVERSATION_ID: conversation_id, + SystemVariableKey.USER_ID: user_id, + SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count, + } + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=inputs, + environment_variables=workflow.environment_variables, + conversation_variables=conversation_variables, + ) + contexts.workflow_variable_pool.set(variable_pool) + # new thread worker_thread = threading.Thread(target=self._generate_worker, kwargs={ 'flask_app': current_app._get_current_object(), 'application_generate_entity': application_generate_entity, 'queue_manager': queue_manager, - 'conversation_id': conversation.id, 'message_id': message.id, - 'user': user, - 'context': contextvars.copy_context() + 'context': contextvars.copy_context(), }) worker_thread.start() @@ -243,7 +323,7 @@ def _generate(self, app_model: App, conversation=conversation, message=message, user=user, - stream=stream + stream=stream, ) return AdvancedChatAppGenerateResponseConverter.convert( @@ -254,9 +334,7 @@ def _generate(self, app_model: App, def _generate_worker(self, flask_app: Flask, application_generate_entity: AdvancedChatAppGenerateEntity, queue_manager: AppQueueManager, - conversation_id: str, message_id: str, - user: Account, context: contextvars.Context) -> None: """ Generate worker in a new thread. @@ -283,8 +361,7 @@ def _generate_worker(self, flask_app: Flask, user_id=application_generate_entity.user_id ) else: - # get conversation and message - conversation = self._get_conversation(conversation_id) + # get message message = self._get_message(message_id) # chatbot app @@ -292,7 +369,6 @@ def _generate_worker(self, flask_app: Flask, runner.run( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - conversation=conversation, message=message ) except GenerateTaskStoppedException: @@ -306,7 +382,7 @@ def _generate_worker(self, flask_app: Flask, logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if os.environ.get("DEBUG", "false").lower() == 'true': logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: @@ -315,14 +391,17 @@ def _generate_worker(self, flask_app: Flask, finally: db.session.close() - def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity, - workflow: Workflow, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - user: Union[Account, EndUser], - stream: bool = False) \ - -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: + def _handle_advanced_chat_response( + self, + *, + application_generate_entity: AdvancedChatAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool = False, + ) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: """ Handle response. :param application_generate_entity: application generate entity @@ -342,7 +421,7 @@ def _handle_advanced_chat_response(self, application_generate_entity: AdvancedCh conversation=conversation, message=message, user=user, - stream=stream + stream=stream, ) try: diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 18db0ab22d4ded..5dc03979cf3b4b 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -16,12 +16,10 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent from core.moderation.base import ModerationException from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.node_entities import SystemVariable from core.workflow.nodes.base_node import UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db -from models.model import App, Conversation, EndUser, Message -from models.workflow import Workflow +from models import App, Message, Workflow logger = logging.getLogger(__name__) @@ -31,10 +29,12 @@ class AdvancedChatAppRunner(AppRunner): AdvancedChat Application Runner """ - def run(self, application_generate_entity: AdvancedChatAppGenerateEntity, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message) -> None: + def run( + self, + application_generate_entity: AdvancedChatAppGenerateEntity, + queue_manager: AppQueueManager, + message: Message, + ) -> None: """ Run application :param application_generate_entity: application generate entity @@ -48,53 +48,43 @@ def run(self, application_generate_entity: AdvancedChatAppGenerateEntity, app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: - raise ValueError("App not found") + raise ValueError('App not found') workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) if not workflow: - raise ValueError("Workflow not initialized") + raise ValueError('Workflow not initialized') inputs = application_generate_entity.inputs query = application_generate_entity.query - files = application_generate_entity.files - - user_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() - if end_user: - user_id = end_user.session_id - else: - user_id = application_generate_entity.user_id # moderation if self.handle_input_moderation( - queue_manager=queue_manager, - app_record=app_record, - app_generate_entity=application_generate_entity, - inputs=inputs, - query=query, - message_id=message.id + queue_manager=queue_manager, + app_record=app_record, + app_generate_entity=application_generate_entity, + inputs=inputs, + query=query, + message_id=message.id, ): return # annotation reply if self.handle_annotation_reply( - app_record=app_record, - message=message, - query=query, - queue_manager=queue_manager, - app_generate_entity=application_generate_entity + app_record=app_record, + message=message, + query=query, + queue_manager=queue_manager, + app_generate_entity=application_generate_entity, ): return db.session.close() - workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback( - queue_manager=queue_manager, - workflow=workflow - )] + workflow_callbacks: list[WorkflowCallback] = [ + WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow) + ] - if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): + if bool(os.environ.get('DEBUG', 'False').lower() == 'true'): workflow_callbacks.append(WorkflowLoggingCallback()) # RUN WORKFLOW @@ -106,43 +96,29 @@ def run(self, application_generate_entity: AdvancedChatAppGenerateEntity, if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else UserFrom.END_USER, invoke_from=application_generate_entity.invoke_from, - user_inputs=inputs, - system_inputs={ - SystemVariable.QUERY: query, - SystemVariable.FILES: files, - SystemVariable.CONVERSATION_ID: conversation.id, - SystemVariable.USER_ID: user_id - }, callbacks=workflow_callbacks, - call_depth=application_generate_entity.call_depth + call_depth=application_generate_entity.call_depth, ) - def single_iteration_run(self, app_id: str, workflow_id: str, - queue_manager: AppQueueManager, - inputs: dict, node_id: str, user_id: str) -> None: + def single_iteration_run( + self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str + ) -> None: """ Single iteration run """ - app_record: App = db.session.query(App).filter(App.id == app_id).first() + app_record = db.session.query(App).filter(App.id == app_id).first() if not app_record: - raise ValueError("App not found") - + raise ValueError('App not found') + workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id) if not workflow: - raise ValueError("Workflow not initialized") - - workflow_callbacks = [WorkflowEventTriggerCallback( - queue_manager=queue_manager, - workflow=workflow - )] + raise ValueError('Workflow not initialized') + + workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)] workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.single_step_run_iteration_workflow_node( - workflow=workflow, - node_id=node_id, - user_id=user_id, - user_inputs=inputs, - callbacks=workflow_callbacks + workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks ) def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: @@ -150,22 +126,25 @@ def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: Get workflow """ # fetch workflow by workflow_id - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.id == workflow_id - ).first() + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id + ) + .first() + ) # return workflow return workflow def handle_input_moderation( - self, queue_manager: AppQueueManager, - app_record: App, - app_generate_entity: AdvancedChatAppGenerateEntity, - inputs: Mapping[str, Any], - query: str, - message_id: str + self, + queue_manager: AppQueueManager, + app_record: App, + app_generate_entity: AdvancedChatAppGenerateEntity, + inputs: Mapping[str, Any], + query: str, + message_id: str, ) -> bool: """ Handle input moderation @@ -192,17 +171,20 @@ def handle_input_moderation( queue_manager=queue_manager, text=str(e), stream=app_generate_entity.stream, - stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION + stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION, ) return True return False - def handle_annotation_reply(self, app_record: App, - message: Message, - query: str, - queue_manager: AppQueueManager, - app_generate_entity: AdvancedChatAppGenerateEntity) -> bool: + def handle_annotation_reply( + self, + app_record: App, + message: Message, + query: str, + queue_manager: AppQueueManager, + app_generate_entity: AdvancedChatAppGenerateEntity, + ) -> bool: """ Handle annotation reply :param app_record: app record @@ -217,29 +199,27 @@ def handle_annotation_reply(self, app_record: App, message=message, query=query, user_id=app_generate_entity.user_id, - invoke_from=app_generate_entity.invoke_from + invoke_from=app_generate_entity.invoke_from, ) if annotation_reply: queue_manager.publish( - QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), - PublishFrom.APPLICATION_MANAGER + QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER ) self._stream_output( queue_manager=queue_manager, text=annotation_reply.content, stream=app_generate_entity.stream, - stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY + stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY, ) return True return False - def _stream_output(self, queue_manager: AppQueueManager, - text: str, - stream: bool, - stopped_by: QueueStopEvent.StopBy) -> None: + def _stream_output( + self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy + ) -> None: """ Direct output :param queue_manager: application queue manager @@ -250,21 +230,10 @@ def _stream_output(self, queue_manager: AppQueueManager, if stream: index = 0 for token in text: - queue_manager.publish( - QueueTextChunkEvent( - text=token - ), PublishFrom.APPLICATION_MANAGER - ) + queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER) index += 1 time.sleep(0.01) else: - queue_manager.publish( - QueueTextChunkEvent( - text=text - ), PublishFrom.APPLICATION_MANAGER - ) + queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER) - queue_manager.publish( - QueueStopEvent(stopped_by=stopped_by), - PublishFrom.APPLICATION_MANAGER - ) + queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index a042d30e001aec..2b3596ded261e5 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -4,6 +4,7 @@ from collections.abc import Generator from typing import Any, Optional, Union, cast +import contexts from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -47,7 +48,8 @@ from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.node_entities import NodeType, SystemVariable +from core.workflow.entities.node_entities import NodeType +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk from events.message_event import message_was_created @@ -71,7 +73,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc _application_generate_entity: AdvancedChatAppGenerateEntity _workflow: Workflow _user: Union[Account, EndUser] - _workflow_system_variables: dict[SystemVariable, Any] + # Deprecated + _workflow_system_variables: dict[SystemVariableKey, Any] _iteration_nested_relations: dict[str, list[str]] def __init__( @@ -81,7 +84,7 @@ def __init__( conversation: Conversation, message: Message, user: Union[Account, EndUser], - stream: bool + stream: bool, ) -> None: """ Initialize AdvancedChatAppGenerateTaskPipeline. @@ -103,11 +106,12 @@ def __init__( self._workflow = workflow self._conversation = conversation self._message = message + # Deprecated self._workflow_system_variables = { - SystemVariable.QUERY: message.query, - SystemVariable.FILES: application_generate_entity.files, - SystemVariable.CONVERSATION_ID: conversation.id, - SystemVariable.USER_ID: user_id + SystemVariableKey.QUERY: message.query, + SystemVariableKey.FILES: application_generate_entity.files, + SystemVariableKey.CONVERSATION_ID: conversation.id, + SystemVariableKey.USER_ID: user_id, } self._task_state = AdvancedChatTaskState( @@ -244,7 +248,10 @@ def _process_stream_response( :return: """ for message in self._queue_manager.listen(): - if hasattr(message.event, 'metadata') and message.event.metadata.get('is_answer_previous_node', False) and publisher: + if (message.event + and getattr(message.event, 'metadata', None) + and message.event.metadata.get('is_answer_previous_node', False) + and publisher): publisher.publish(message=message) elif (hasattr(message.event, 'execution_metadata') and message.event.execution_metadata @@ -609,7 +616,9 @@ def _generate_stream_outputs_when_node_finished(self) -> Optional[Generator]: if route_chunk_node_id == 'sys': # system variable - value = self._workflow_system_variables.get(SystemVariable.value_of(value_selector[1])) + value = contexts.workflow_variable_pool.get().get(value_selector) + if value: + value = value.text elif route_chunk_node_id in self._iteration_nested_relations: # it's a iteration variable if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations: diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 53780bdfb003b2..daf37f4a50dd31 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -3,7 +3,7 @@ import threading import uuid from collections.abc import Generator -from typing import Any, Union +from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -28,6 +28,24 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): + @overload + def generate( + self, app_model: App, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[True] = True, + ) -> Generator[dict, None, None]: ... + + @overload + def generate( + self, app_model: App, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[False] = False, + ) -> dict: ... + def generate(self, app_model: App, user: Union[Account, EndUser], args: Any, diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 6f48aa23637f7a..9e331dff4d64e2 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -1,7 +1,7 @@ from collections.abc import Mapping from typing import Any, Optional -from core.app.app_config.entities import AppConfig, VariableEntity +from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType class BaseAppGenerator: @@ -9,29 +9,29 @@ def _get_cleaned_inputs(self, user_inputs: Optional[Mapping[str, Any]], app_conf user_inputs = user_inputs or {} # Filter input variables from form configuration, handle required fields, default values, and option values variables = app_config.variables - filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables} + filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables} filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()} return filtered_inputs def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): - user_input_value = inputs.get(var.name) + user_input_value = inputs.get(var.variable) if var.required and not user_input_value: - raise ValueError(f'{var.name} is required in input form') + raise ValueError(f'{var.variable} is required in input form') if not var.required and not user_input_value: # TODO: should we return None here if the default value is None? return var.default or '' if ( var.type in ( - VariableEntity.Type.TEXT_INPUT, - VariableEntity.Type.SELECT, - VariableEntity.Type.PARAGRAPH, + VariableEntityType.TEXT_INPUT, + VariableEntityType.SELECT, + VariableEntityType.PARAGRAPH, ) and user_input_value and not isinstance(user_input_value, str) ): - raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string") - if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str): + raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string") + if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str): # may raise ValueError if user_input_value is not a valid number try: if '.' in user_input_value: @@ -39,14 +39,14 @@ def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): else: return int(user_input_value) except ValueError: - raise ValueError(f"{var.name} in input form must be a valid number") - if var.type == VariableEntity.Type.SELECT: + raise ValueError(f"{var.variable} in input form must be a valid number") + if var.type == VariableEntityType.SELECT: options = var.options or [] if user_input_value not in options: - raise ValueError(f'{var.name} in input form must be one of the following: {options}') - elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH): + raise ValueError(f'{var.variable} in input form must be one of the following: {options}') + elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): if var.max_length and user_input_value and len(user_input_value) > var.max_length: - raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters') + raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters') return user_input_value diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 58c7d04b8348f8..2c5feaaaafb153 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -1,6 +1,6 @@ import time from collections.abc import Generator -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -14,7 +14,6 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature from core.external_data_tool.external_data_fetch import ExternalDataFetch -from core.file.file_obj import FileVar from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage @@ -27,13 +26,16 @@ from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform from models.model import App, AppMode, Message, MessageAnnotation +if TYPE_CHECKING: + from core.file.file_obj import FileVar + class AppRunner: def get_pre_calculate_rest_tokens(self, app_record: App, model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, inputs: dict[str, str], - files: list[FileVar], + files: list["FileVar"], query: Optional[str] = None) -> int: """ Get pre calculate rest tokens @@ -126,7 +128,7 @@ def organize_prompt_messages(self, app_record: App, model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, inputs: dict[str, str], - files: list[FileVar], + files: list["FileVar"], query: Optional[str] = None, context: Optional[str] = None, memory: Optional[TokenBufferMemory] = None) \ @@ -254,6 +256,7 @@ def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator], :param invoke_result: invoke result :param queue_manager: application queue manager :param stream: stream + :param agent: agent :return: """ if not stream: @@ -276,6 +279,7 @@ def _handle_invoke_result_direct(self, invoke_result: LLMResult, Handle invoke result direct :param invoke_result: invoke result :param queue_manager: application queue manager + :param agent: agent :return: """ queue_manager.publish( @@ -291,6 +295,7 @@ def _handle_invoke_result_stream(self, invoke_result: Generator, Handle invoke result :param invoke_result: invoke result :param queue_manager: application queue manager + :param agent: agent :return: """ model = None @@ -366,7 +371,7 @@ def moderation_for_inputs( message_id=message_id, trace_manager=app_generate_entity.trace_manager ) - + def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity, queue_manager: AppQueueManager, prompt_messages: list[PromptMessage]) -> bool: @@ -418,7 +423,7 @@ def fill_in_inputs_from_external_data_tools(self, tenant_id: str, inputs=inputs, query=query ) - + def query_app_annotations_to_reply(self, app_record: App, message: Message, query: str, diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 5b896e28455340..ab15928b744b61 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -3,7 +3,7 @@ import threading import uuid from collections.abc import Generator -from typing import Any, Union +from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -28,13 +28,31 @@ class ChatAppGenerator(MessageBasedAppGenerator): + @overload + def generate( + self, app_model: App, + user: Union[Account, EndUser], + args: Any, + invoke_from: InvokeFrom, + stream: Literal[True] = True, + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, app_model: App, + user: Union[Account, EndUser], + args: Any, + invoke_from: InvokeFrom, + stream: Literal[False] = False, + ) -> dict: ... + def generate( self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True, - ) -> Union[dict, Generator[dict, None, None]]: + ) -> Union[dict, Generator[str, None, None]]: """ Generate App response. diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index c4e1caf65a9679..c0b13b40fdb0d6 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -3,7 +3,7 @@ import threading import uuid from collections.abc import Generator -from typing import Any, Union +from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -30,12 +30,30 @@ class CompletionAppGenerator(MessageBasedAppGenerator): + @overload + def generate( + self, app_model: App, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[True] = True, + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, app_model: App, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[False] = False, + ) -> dict: ... + def generate(self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True) \ - -> Union[dict, Generator[dict, None, None]]: + -> Union[dict, Generator[str, None, None]]: """ Generate App response. @@ -203,7 +221,7 @@ def generate_more_like_this(self, app_model: App, user: Union[Account, EndUser], invoke_from: InvokeFrom, stream: bool = True) \ - -> Union[dict, Generator[dict, None, None]]: + -> Union[dict, Generator[str, None, None]]: """ Generate App response. diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index c5cd6864020b33..fceed95b91e636 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -1,6 +1,7 @@ import json import logging from collections.abc import Generator +from datetime import datetime, timezone from typing import Optional, Union from sqlalchemy import and_ @@ -36,17 +37,17 @@ class MessageBasedAppGenerator(BaseAppGenerator): def _handle_response( - self, application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity, - AdvancedChatAppGenerateEntity - ], - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - user: Union[Account, EndUser], - stream: bool = False, + self, application_generate_entity: Union[ + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity + ], + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool = False, ) -> Union[ ChatbotAppBlockingResponse, CompletionAppBlockingResponse, @@ -138,6 +139,7 @@ def _init_generate_records(self, """ Initialize generate records :param application_generate_entity: application generate entity + :conversation conversation :return: """ app_config = application_generate_entity.app_config @@ -192,6 +194,9 @@ def _init_generate_records(self, db.session.add(conversation) db.session.commit() db.session.refresh(conversation) + else: + conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + db.session.commit() message = Message( app_id=app_config.app_id, @@ -258,7 +263,7 @@ def _get_conversation_introduction(self, application_generate_entity: AppGenerat return introduction - def _get_conversation(self, conversation_id: str) -> Conversation: + def _get_conversation(self, conversation_id: str): """ Get conversation by conversation id :param conversation_id: conversation id @@ -270,6 +275,9 @@ def _get_conversation(self, conversation_id: str) -> Conversation: .first() ) + if not conversation: + raise ConversationNotExistsError() + return conversation def _get_message(self, message_id: str) -> Message: diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index df40aec154a856..26bb6c0f4f70b8 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -4,7 +4,7 @@ import threading import uuid from collections.abc import Generator -from typing import Union +from typing import Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -32,6 +32,26 @@ class WorkflowAppGenerator(BaseAppGenerator): + @overload + def generate( + self, app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[True] = True, + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[False] = False, + ) -> dict: ... + def generate( self, app_model: App, workflow: Workflow, @@ -107,7 +127,7 @@ def _generate( application_generate_entity: WorkflowAppGenerateEntity, invoke_from: InvokeFrom, stream: bool = True, - ) -> Union[dict, Generator[dict, None, None]]: + ) -> Union[dict, Generator[str, None, None]]: """ Generate App response. diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 24f4a83217a239..e388d0184b83b1 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -11,7 +11,8 @@ WorkflowAppGenerateEntity, ) from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.node_entities import SystemVariable +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db @@ -26,8 +27,7 @@ class WorkflowAppRunner: Workflow Application Runner """ - def run(self, application_generate_entity: WorkflowAppGenerateEntity, - queue_manager: AppQueueManager) -> None: + def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None: """ Run application :param application_generate_entity: application generate entity @@ -47,25 +47,36 @@ def run(self, application_generate_entity: WorkflowAppGenerateEntity, app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: - raise ValueError("App not found") + raise ValueError('App not found') workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) if not workflow: - raise ValueError("Workflow not initialized") + raise ValueError('Workflow not initialized') inputs = application_generate_entity.inputs files = application_generate_entity.files db.session.close() - workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback( - queue_manager=queue_manager, - workflow=workflow - )] + workflow_callbacks: list[WorkflowCallback] = [ + WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow) + ] - if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): + if bool(os.environ.get('DEBUG', 'False').lower() == 'true'): workflow_callbacks.append(WorkflowLoggingCallback()) + # Create a variable pool. + system_inputs = { + SystemVariableKey.FILES: files, + SystemVariableKey.USER_ID: user_id, + } + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=inputs, + environment_variables=workflow.environment_variables, + conversation_variables=[], + ) + # RUN WORKFLOW workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.run_workflow( @@ -75,44 +86,33 @@ def run(self, application_generate_entity: WorkflowAppGenerateEntity, if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else UserFrom.END_USER, invoke_from=application_generate_entity.invoke_from, - user_inputs=inputs, - system_inputs={ - SystemVariable.FILES: files, - SystemVariable.USER_ID: user_id - }, callbacks=workflow_callbacks, - call_depth=application_generate_entity.call_depth + call_depth=application_generate_entity.call_depth, + variable_pool=variable_pool, ) - def single_iteration_run(self, app_id: str, workflow_id: str, - queue_manager: AppQueueManager, - inputs: dict, node_id: str, user_id: str) -> None: + def single_iteration_run( + self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str + ) -> None: """ Single iteration run """ - app_record: App = db.session.query(App).filter(App.id == app_id).first() + app_record = db.session.query(App).filter(App.id == app_id).first() if not app_record: - raise ValueError("App not found") - + raise ValueError('App not found') + if not app_record.workflow_id: - raise ValueError("Workflow not initialized") + raise ValueError('Workflow not initialized') workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id) if not workflow: - raise ValueError("Workflow not initialized") - - workflow_callbacks = [WorkflowEventTriggerCallback( - queue_manager=queue_manager, - workflow=workflow - )] + raise ValueError('Workflow not initialized') + + workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)] workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.single_step_run_iteration_workflow_node( - workflow=workflow, - node_id=node_id, - user_id=user_id, - user_inputs=inputs, - callbacks=workflow_callbacks + workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks ) def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: @@ -120,11 +120,13 @@ def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: Get workflow """ # fetch workflow by workflow_id - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.id == workflow_id - ).first() + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id + ) + .first() + ) # return workflow return workflow diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 2b4362150fc7e7..de8542d7b9f859 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -42,7 +42,8 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.node_entities import NodeType, SystemVariable +from core.workflow.entities.node_entities import NodeType +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.end.end_node import EndNode from extensions.ext_database import db from models.account import Account @@ -66,7 +67,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa _user: Union[Account, EndUser] _task_state: WorkflowTaskState _application_generate_entity: WorkflowAppGenerateEntity - _workflow_system_variables: dict[SystemVariable, Any] + _workflow_system_variables: dict[SystemVariableKey, Any] _iteration_nested_relations: dict[str, list[str]] def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, @@ -91,8 +92,8 @@ def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, self._workflow = workflow self._workflow_system_variables = { - SystemVariable.FILES: application_generate_entity.files, - SystemVariable.USER_ID: user_id + SystemVariableKey.FILES: application_generate_entity.files, + SystemVariableKey.USER_ID: user_id } self._task_state = WorkflowTaskState( @@ -519,7 +520,7 @@ def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]: """ nodes = graph.get('nodes') - iteration_ids = [node.get('id') for node in nodes + iteration_ids = [node.get('id') for node in nodes if node.get('data', {}).get('type') in [ NodeType.ITERATION.value, NodeType.LOOP.value, @@ -530,4 +531,3 @@ def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]: node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id ] for iteration_id in iteration_ids } - \ No newline at end of file diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 9a861c29e2634c..6a1ab230416d0c 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -166,4 +166,4 @@ class SingleIterationRunEntity(BaseModel): node_id: str inputs: dict - single_iteration_run: Optional[SingleIterationRunEntity] = None \ No newline at end of file + single_iteration_run: Optional[SingleIterationRunEntity] = None diff --git a/api/core/app/segments/__init__.py b/api/core/app/segments/__init__.py index d5cd0a589cc38a..7de06dfb9639fd 100644 --- a/api/core/app/segments/__init__.py +++ b/api/core/app/segments/__init__.py @@ -1,7 +1,7 @@ from .segment_group import SegmentGroup from .segments import ( ArrayAnySegment, - FileSegment, + ArraySegment, FloatSegment, IntegerSegment, NoneSegment, @@ -12,11 +12,9 @@ from .types import SegmentType from .variables import ( ArrayAnyVariable, - ArrayFileVariable, ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, - FileVariable, FloatVariable, IntegerVariable, NoneVariable, @@ -31,7 +29,6 @@ 'FloatVariable', 'ObjectVariable', 'SecretVariable', - 'FileVariable', 'StringVariable', 'ArrayAnyVariable', 'Variable', @@ -44,10 +41,9 @@ 'FloatSegment', 'ObjectSegment', 'ArrayAnySegment', - 'FileSegment', 'StringSegment', 'ArrayStringVariable', 'ArrayNumberVariable', 'ArrayObjectVariable', - 'ArrayFileVariable', + 'ArraySegment', ] diff --git a/api/core/app/segments/exc.py b/api/core/app/segments/exc.py new file mode 100644 index 00000000000000..d15d6d500ffa4a --- /dev/null +++ b/api/core/app/segments/exc.py @@ -0,0 +1,2 @@ +class VariableError(Exception): + pass diff --git a/api/core/app/segments/factory.py b/api/core/app/segments/factory.py index f62e44bf07d2f7..e6e9ce97747ce1 100644 --- a/api/core/app/segments/factory.py +++ b/api/core/app/segments/factory.py @@ -1,11 +1,11 @@ from collections.abc import Mapping from typing import Any -from core.file.file_obj import FileVar +from configs import dify_config +from .exc import VariableError from .segments import ( ArrayAnySegment, - FileSegment, FloatSegment, IntegerSegment, NoneSegment, @@ -15,11 +15,9 @@ ) from .types import SegmentType from .variables import ( - ArrayFileVariable, ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, - FileVariable, FloatVariable, IntegerVariable, ObjectVariable, @@ -29,39 +27,37 @@ ) -def build_variable_from_mapping(m: Mapping[str, Any], /) -> Variable: - if (value_type := m.get('value_type')) is None: - raise ValueError('missing value type') - if not m.get('name'): - raise ValueError('missing name') - if (value := m.get('value')) is None: - raise ValueError('missing value') +def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: + if (value_type := mapping.get('value_type')) is None: + raise VariableError('missing value type') + if not mapping.get('name'): + raise VariableError('missing name') + if (value := mapping.get('value')) is None: + raise VariableError('missing value') match value_type: case SegmentType.STRING: - return StringVariable.model_validate(m) + result = StringVariable.model_validate(mapping) case SegmentType.SECRET: - return SecretVariable.model_validate(m) + result = SecretVariable.model_validate(mapping) case SegmentType.NUMBER if isinstance(value, int): - return IntegerVariable.model_validate(m) + result = IntegerVariable.model_validate(mapping) case SegmentType.NUMBER if isinstance(value, float): - return FloatVariable.model_validate(m) + result = FloatVariable.model_validate(mapping) case SegmentType.NUMBER if not isinstance(value, float | int): - raise ValueError(f'invalid number value {value}') - case SegmentType.FILE: - return FileVariable.model_validate(m) + raise VariableError(f'invalid number value {value}') case SegmentType.OBJECT if isinstance(value, dict): - return ObjectVariable.model_validate( - {**m, 'value': {k: build_variable_from_mapping(v) for k, v in value.items()}} - ) + result = ObjectVariable.model_validate(mapping) case SegmentType.ARRAY_STRING if isinstance(value, list): - return ArrayStringVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]}) + result = ArrayStringVariable.model_validate(mapping) case SegmentType.ARRAY_NUMBER if isinstance(value, list): - return ArrayNumberVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]}) + result = ArrayNumberVariable.model_validate(mapping) case SegmentType.ARRAY_OBJECT if isinstance(value, list): - return ArrayObjectVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]}) - case SegmentType.ARRAY_FILE if isinstance(value, list): - return ArrayFileVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]}) - raise ValueError(f'not supported value type {value_type}') + result = ArrayObjectVariable.model_validate(mapping) + case _: + raise VariableError(f'not supported value type {value_type}') + if result.size > dify_config.MAX_VARIABLE_SIZE: + raise VariableError(f'variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}') + return result def build_segment(value: Any, /) -> Segment: @@ -74,13 +70,7 @@ def build_segment(value: Any, /) -> Segment: if isinstance(value, float): return FloatSegment(value=value) if isinstance(value, dict): - # TODO: Limit the depth of the object - obj = {k: build_segment(v) for k, v in value.items()} - return ObjectSegment(value=obj) + return ObjectSegment(value=value) if isinstance(value, list): - # TODO: Limit the depth of the array - elements = [build_segment(v) for v in value] - return ArrayAnySegment(value=elements) - if isinstance(value, FileVar): - return FileSegment(value=value) + return ArrayAnySegment(value=value) raise ValueError(f'not supported value {value}') diff --git a/api/core/app/segments/segments.py b/api/core/app/segments/segments.py index 4227f154e6aaac..5c713cac6747f9 100644 --- a/api/core/app/segments/segments.py +++ b/api/core/app/segments/segments.py @@ -1,11 +1,10 @@ import json +import sys from collections.abc import Mapping, Sequence from typing import Any from pydantic import BaseModel, ConfigDict, field_validator -from core.file.file_obj import FileVar - from .types import SegmentType @@ -37,6 +36,10 @@ def log(self) -> str: def markdown(self) -> str: return str(self.value) + @property + def size(self) -> int: + return sys.getsizeof(self.value) + def to_object(self) -> Any: return self.value @@ -73,68 +76,54 @@ class IntegerSegment(Segment): value: int -class FileSegment(Segment): - value_type: SegmentType = SegmentType.FILE - # TODO: embed FileVar in this model. - value: FileVar - @property - def markdown(self) -> str: - return self.value.to_markdown() class ObjectSegment(Segment): value_type: SegmentType = SegmentType.OBJECT - value: Mapping[str, Segment] + value: Mapping[str, Any] @property def text(self) -> str: - # TODO: Process variables. return json.dumps(self.model_dump()['value'], ensure_ascii=False) @property def log(self) -> str: - # TODO: Process variables. return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) @property def markdown(self) -> str: - # TODO: Use markdown code block return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) - def to_object(self): - return {k: v.to_object() for k, v in self.value.items()} - class ArraySegment(Segment): @property def markdown(self) -> str: - return '\n'.join(['- ' + item.markdown for item in self.value]) - - def to_object(self): - return [v.to_object() for v in self.value] + items = [] + for item in self.value: + if hasattr(item, 'to_markdown'): + items.append(item.to_markdown()) + else: + items.append(str(item)) + return '\n'.join(items) class ArrayAnySegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_ANY - value: Sequence[Segment] + value: Sequence[Any] class ArrayStringSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_STRING - value: Sequence[StringSegment] + value: Sequence[str] class ArrayNumberSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_NUMBER - value: Sequence[FloatSegment | IntegerSegment] + value: Sequence[float | int] class ArrayObjectSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_OBJECT - value: Sequence[ObjectSegment] - + value: Sequence[Mapping[str, Any]] -class ArrayFileSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_FILE - value: Sequence[FileSegment] diff --git a/api/core/app/segments/types.py b/api/core/app/segments/types.py index a371058ef52bac..cdd2b0b4b09191 100644 --- a/api/core/app/segments/types.py +++ b/api/core/app/segments/types.py @@ -10,8 +10,6 @@ class SegmentType(str, Enum): ARRAY_STRING = 'array[string]' ARRAY_NUMBER = 'array[number]' ARRAY_OBJECT = 'array[object]' - ARRAY_FILE = 'array[file]' OBJECT = 'object' - FILE = 'file' GROUP = 'group' diff --git a/api/core/app/segments/variables.py b/api/core/app/segments/variables.py index ac26e165425c3a..8fef707fcf298b 100644 --- a/api/core/app/segments/variables.py +++ b/api/core/app/segments/variables.py @@ -4,11 +4,9 @@ from .segments import ( ArrayAnySegment, - ArrayFileSegment, ArrayNumberSegment, ArrayObjectSegment, ArrayStringSegment, - FileSegment, FloatSegment, IntegerSegment, NoneSegment, @@ -44,10 +42,6 @@ class IntegerVariable(IntegerSegment, Variable): pass -class FileVariable(FileSegment, Variable): - pass - - class ObjectVariable(ObjectSegment, Variable): pass @@ -68,9 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable): pass -class ArrayFileVariable(ArrayFileSegment, Variable): - pass - class SecretVariable(StringVariable): value_type: SegmentType = SegmentType.SECRET diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index c9644c7d4cf1c0..8d91a507a9e8ee 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -48,7 +48,8 @@ ) from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName +from core.ops.entities.trace_entity import TraceTaskName +from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser from events.message_event import message_was_created diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index b4859edbd965d6..4935c43ac437e4 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -22,7 +22,8 @@ from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage from core.file.file_obj import FileVar from core.model_runtime.utils.encoders import jsonable_encoder -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName +from core.ops.entities.trace_entity import TraceTaskName +from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.tools.tool_manager import ToolManager from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType from core.workflow.nodes.tool.entities import ToolNodeData @@ -40,6 +41,7 @@ WorkflowRunStatus, WorkflowRunTriggeredFrom, ) +from services.workflow_service import WorkflowService class WorkflowCycleManage(WorkflowIterationCycleManage): @@ -97,7 +99,6 @@ def _init_workflow_run(self, workflow: Workflow, def _workflow_run_success( self, workflow_run: WorkflowRun, - start_at: float, total_tokens: int, total_steps: int, outputs: Optional[str] = None, @@ -107,7 +108,6 @@ def _workflow_run_success( """ Workflow run success :param workflow_run: workflow run - :param start_at: start time :param total_tokens: total tokens :param total_steps: total steps :param outputs: outputs @@ -116,7 +116,7 @@ def _workflow_run_success( """ workflow_run.status = WorkflowRunStatus.SUCCEEDED.value workflow_run.outputs = outputs - workflow_run.elapsed_time = time.perf_counter() - start_at + workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id) workflow_run.total_tokens = total_tokens workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) @@ -139,7 +139,6 @@ def _workflow_run_success( def _workflow_run_failed( self, workflow_run: WorkflowRun, - start_at: float, total_tokens: int, total_steps: int, status: WorkflowRunStatus, @@ -150,7 +149,6 @@ def _workflow_run_failed( """ Workflow run failed :param workflow_run: workflow run - :param start_at: start time :param total_tokens: total tokens :param total_steps: total steps :param status: status @@ -159,7 +157,7 @@ def _workflow_run_failed( """ workflow_run.status = status.value workflow_run.error = error - workflow_run.elapsed_time = time.perf_counter() - start_at + workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id) workflow_run.total_tokens = total_tokens workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) @@ -542,7 +540,6 @@ def _handle_workflow_finished( if isinstance(event, QueueStopEvent): workflow_run = self._workflow_run_failed( workflow_run=workflow_run, - start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, status=WorkflowRunStatus.STOPPED, @@ -565,7 +562,6 @@ def _handle_workflow_finished( elif isinstance(event, QueueWorkflowFailedEvent): workflow_run = self._workflow_run_failed( workflow_run=workflow_run, - start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, status=WorkflowRunStatus.FAILED, @@ -583,7 +579,6 @@ def _handle_workflow_finished( workflow_run = self._workflow_run_success( workflow_run=workflow_run, - start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, outputs=outputs, diff --git a/api/core/app/task_pipeline/workflow_cycle_state_manager.py b/api/core/app/task_pipeline/workflow_cycle_state_manager.py index 545f31fddfaedb..bd98c82720f039 100644 --- a/api/core/app/task_pipeline/workflow_cycle_state_manager.py +++ b/api/core/app/task_pipeline/workflow_cycle_state_manager.py @@ -2,7 +2,7 @@ from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState -from core.workflow.entities.node_entities import SystemVariable +from core.workflow.enums import SystemVariableKey from models.account import Account from models.model import EndUser from models.workflow import Workflow @@ -13,4 +13,4 @@ class WorkflowCycleStateManager: _workflow: Workflow _user: Union[Account, EndUser] _task_state: Union[AdvancedChatTaskState, WorkflowTaskState] - _workflow_system_variables: dict[SystemVariable, Any] \ No newline at end of file + _workflow_system_variables: dict[SystemVariableKey, Any] diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index 03f8244bab212e..578996574739a8 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -4,7 +4,8 @@ from pydantic import BaseModel -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName +from core.ops.entities.trace_entity import TraceTaskName +from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.tools.entities.tool_entities import ToolInvokeMessage _TEXT_COLOR_MAPPING = { diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index f3cf54a58ed211..778ef2e1ac42ad 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, ConfigDict +from constants import HIDDEN_VALUE from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity from core.entities.provider_entities import ( CustomConfiguration, @@ -202,7 +203,7 @@ def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict for key, value in credentials.items(): if key in provider_credential_secret_variables: # if send [__HIDDEN__] in secret input, it will be same as original value - if value == '[__HIDDEN__]' and key in original_credentials: + if value == HIDDEN_VALUE and key in original_credentials: credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) credentials = model_provider_factory.provider_credentials_validate( @@ -345,7 +346,7 @@ def custom_model_credentials_validate(self, model_type: ModelType, model: str, c for key, value in credentials.items(): if key in provider_credential_secret_variables: # if send [__HIDDEN__] in secret input, it will be same as original value - if value == '[__HIDDEN__]' and key in original_credentials: + if value == HIDDEN_VALUE and key in original_credentials: credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) credentials = model_provider_factory.model_credentials_validate( diff --git a/api/core/errors/error.py b/api/core/errors/error.py index 859a747c12157f..53323a2eeb67ce 100644 --- a/api/core/errors/error.py +++ b/api/core/errors/error.py @@ -43,3 +43,8 @@ class ModelCurrentlyNotSupportError(Exception): Custom exception raised when the model not support """ description = "Model Currently Not Support" + + +class InvokeRateLimitError(Exception): + """Raised when the Invoke returns rate limit error.""" + description = "Rate Limit Error" diff --git a/api/core/file/file_obj.py b/api/core/file/file_obj.py index 268ef5df867988..3959f4b4a0bb61 100644 --- a/api/core/file/file_obj.py +++ b/api/core/file/file_obj.py @@ -1,14 +1,19 @@ import enum -from typing import Optional +from typing import Any, Optional from pydantic import BaseModel -from core.app.app_config.entities import FileExtraConfig from core.file.tool_file_parser import ToolFileParser from core.file.upload_file_parser import UploadFileParser from core.model_runtime.entities.message_entities import ImagePromptMessageContent from extensions.ext_database import db -from models.model import UploadFile + + +class FileExtraConfig(BaseModel): + """ + File Upload Entity. + """ + image_config: Optional[dict[str, Any]] = None class FileType(enum.Enum): @@ -114,6 +119,7 @@ def prompt_message_content(self) -> ImagePromptMessageContent: ) def _get_data(self, force_url: bool = False) -> Optional[str]: + from models.model import UploadFile if self.type == FileType.IMAGE: if self.transfer_method == FileTransferMethod.REMOTE_URL: return self.url diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py index 7b2f8217f9231f..085ff07cfde921 100644 --- a/api/core/file/message_file_parser.py +++ b/api/core/file/message_file_parser.py @@ -1,10 +1,11 @@ +import re from collections.abc import Mapping, Sequence from typing import Any, Union +from urllib.parse import parse_qs, urlparse import requests -from core.app.app_config.entities import FileExtraConfig -from core.file.file_obj import FileBelongsTo, FileTransferMethod, FileType, FileVar +from core.file.file_obj import FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType, FileVar from extensions.ext_database import db from models.account import Account from models.model import EndUser, MessageFile, UploadFile @@ -98,7 +99,7 @@ def validate_and_transform_files_arg(self, files: Sequence[Mapping[str, Any]], f # return all file objs return new_files - def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig) -> list[FileVar]: + def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig): """ transform message files @@ -143,7 +144,7 @@ def _to_file_objs(self, files: list[Union[dict, MessageFile]], return type_file_objs - def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig) -> FileVar: + def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig): """ transform file to file obj @@ -186,6 +187,30 @@ def _check_image_remote_url(self, url): "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" } + def is_s3_presigned_url(url): + try: + parsed_url = urlparse(url) + if 'amazonaws.com' not in parsed_url.netloc: + return False + query_params = parse_qs(parsed_url.query) + required_params = ['Signature', 'Expires'] + for param in required_params: + if param not in query_params: + return False + if not query_params['Expires'][0].isdigit(): + return False + signature = query_params['Signature'][0] + if not re.match(r'^[A-Za-z0-9+/]+={0,2}$', signature): + return False + return True + except Exception: + return False + + if is_s3_presigned_url(url): + response = requests.get(url, headers=headers, allow_redirects=True) + if response.status_code in {200, 304}: + return True, "" + response = requests.head(url, headers=headers, allow_redirects=True) if response.status_code in {200, 304}: return True, "" diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index afb2bbbbf317cd..4662ebb47a56a5 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -1,15 +1,13 @@ import logging -import time from enum import Enum from threading import Lock -from typing import Literal, Optional +from typing import Optional -from httpx import get, post +from httpx import Timeout, post from pydantic import BaseModel from yarl import URL from configs import dify_config -from core.helper.code_executor.entities import CodeDependency from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer @@ -17,12 +15,6 @@ logger = logging.getLogger(__name__) -# Code Executor -CODE_EXECUTION_ENDPOINT = dify_config.CODE_EXECUTION_ENDPOINT -CODE_EXECUTION_API_KEY = dify_config.CODE_EXECUTION_API_KEY - -CODE_EXECUTION_TIMEOUT = (10, 60) - class CodeExecutionException(Exception): pass @@ -66,18 +58,17 @@ class CodeExecutor: def execute_code(cls, language: CodeLanguage, preload: str, - code: str, - dependencies: Optional[list[CodeDependency]] = None) -> str: + code: str) -> str: """ Execute code :param language: code language :param code: code :return: """ - url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'run' + url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / 'v1' / 'sandbox' / 'run' headers = { - 'X-Api-Key': CODE_EXECUTION_API_KEY + 'X-Api-Key': dify_config.CODE_EXECUTION_API_KEY } data = { @@ -87,11 +78,13 @@ def execute_code(cls, 'enable_network': True } - if dependencies: - data['dependencies'] = [dependency.model_dump() for dependency in dependencies] - try: - response = post(str(url), json=data, headers=headers, timeout=CODE_EXECUTION_TIMEOUT) + response = post(str(url), json=data, headers=headers, + timeout=Timeout( + connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT, + read=dify_config.CODE_EXECUTION_READ_TIMEOUT, + write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT, + pool=None)) if response.status_code == 503: raise CodeExecutionException('Code execution service is unavailable') elif response.status_code != 200: @@ -102,7 +95,7 @@ def execute_code(cls, raise CodeExecutionException('Failed to execute code, which is likely a network issue,' ' please check if the sandbox service is running.' f' ( Error: {str(e)} )') - + try: response = response.json() except: @@ -110,16 +103,16 @@ def execute_code(cls, if (code := response.get('code')) != 0: raise CodeExecutionException(f"Got error code: {code}. Got error msg: {response.get('message')}") - + response = CodeExecutionResponse(**response) - + if response.data.error: raise CodeExecutionException(response.data.error) - - return response.data.stdout + + return response.data.stdout or '' @classmethod - def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict: + def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict) -> dict: """ Execute code :param language: code language @@ -131,67 +124,11 @@ def execute_workflow_code_template(cls, language: CodeLanguage, code: str, input if not template_transformer: raise CodeExecutionException(f'Unsupported language {language}') - runner, preload, dependencies = template_transformer.transform_caller(code, inputs, dependencies) + runner, preload = template_transformer.transform_caller(code, inputs) try: - response = cls.execute_code(language, preload, runner, dependencies) + response = cls.execute_code(language, preload, runner) except CodeExecutionException as e: raise e return template_transformer.transform_response(response) - - @classmethod - def list_dependencies(cls, language: str) -> list[CodeDependency]: - if language not in cls.supported_dependencies_languages: - return [] - - with cls.dependencies_cache_lock: - if language in cls.dependencies_cache: - # check expiration - dependencies = cls.dependencies_cache[language] - if dependencies['expiration'] > time.time(): - return dependencies['data'] - # remove expired cache - del cls.dependencies_cache[language] - - dependencies = cls._get_dependencies(language) - with cls.dependencies_cache_lock: - cls.dependencies_cache[language] = { - 'data': dependencies, - 'expiration': time.time() + 60 - } - - return dependencies - - @classmethod - def _get_dependencies(cls, language: Literal['python3']) -> list[CodeDependency]: - """ - List dependencies - """ - url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'dependencies' - - headers = { - 'X-Api-Key': CODE_EXECUTION_API_KEY - } - - running_language = cls.code_language_to_running_language.get(language) - if isinstance(running_language, Enum): - running_language = running_language.value - - data = { - 'language': running_language, - } - - try: - response = get(str(url), params=data, headers=headers, timeout=CODE_EXECUTION_TIMEOUT) - if response.status_code != 200: - raise Exception(f'Failed to list dependencies, got status code {response.status_code}, please check if the sandbox service is running') - response = response.json() - dependencies = response.get('data', {}).get('dependencies', []) - return [ - CodeDependency(**dependency) for dependency in dependencies - if dependency.get('name') not in Python3TemplateTransformer.get_standard_packages() - ] - except Exception as e: - logger.exception(f'Failed to list dependencies: {e}') - return [] \ No newline at end of file diff --git a/api/core/helper/code_executor/code_node_provider.py b/api/core/helper/code_executor/code_node_provider.py index 761c0e2b2524fa..3f099b7ac5bbb4 100644 --- a/api/core/helper/code_executor/code_node_provider.py +++ b/api/core/helper/code_executor/code_node_provider.py @@ -2,8 +2,6 @@ from pydantic import BaseModel -from core.helper.code_executor.code_executor import CodeExecutor - class CodeNodeProvider(BaseModel): @staticmethod @@ -23,10 +21,6 @@ def get_default_code(cls) -> str: """ pass - @classmethod - def get_default_available_packages(cls) -> list[dict]: - return [p.model_dump() for p in CodeExecutor.list_dependencies(cls.get_language())] - @classmethod def get_default_config(cls) -> dict: return { @@ -50,6 +44,5 @@ def get_default_config(cls) -> dict: "children": None } } - }, - "available_dependencies": cls.get_default_available_packages(), + } } diff --git a/api/core/helper/code_executor/entities.py b/api/core/helper/code_executor/entities.py deleted file mode 100644 index cc10288521ad6b..00000000000000 --- a/api/core/helper/code_executor/entities.py +++ /dev/null @@ -1,6 +0,0 @@ -from pydantic import BaseModel - - -class CodeDependency(BaseModel): - name: str - version: str diff --git a/api/core/helper/code_executor/jinja2/jinja2_formatter.py b/api/core/helper/code_executor/jinja2/jinja2_formatter.py index 63f48a56e2a47f..f1e5da584c660d 100644 --- a/api/core/helper/code_executor/jinja2/jinja2_formatter.py +++ b/api/core/helper/code_executor/jinja2/jinja2_formatter.py @@ -3,7 +3,7 @@ class Jinja2Formatter: @classmethod - def format(cls, template: str, inputs: str) -> str: + def format(cls, template: str, inputs: dict) -> str: """ Format template :param template: template diff --git a/api/core/helper/code_executor/jinja2/jinja2_transformer.py b/api/core/helper/code_executor/jinja2/jinja2_transformer.py index a8f8095d52b6d5..b8cb29600e106e 100644 --- a/api/core/helper/code_executor/jinja2/jinja2_transformer.py +++ b/api/core/helper/code_executor/jinja2/jinja2_transformer.py @@ -1,14 +1,9 @@ from textwrap import dedent -from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer from core.helper.code_executor.template_transformer import TemplateTransformer class Jinja2TemplateTransformer(TemplateTransformer): - @classmethod - def get_standard_packages(cls) -> set[str]: - return {'jinja2'} | Python3TemplateTransformer.get_standard_packages() - @classmethod def transform_response(cls, response: str) -> dict: """ diff --git a/api/core/helper/code_executor/python3/python3_code_provider.py b/api/core/helper/code_executor/python3/python3_code_provider.py index efcb8a9d1ebccf..923724b49d8a2d 100644 --- a/api/core/helper/code_executor/python3/python3_code_provider.py +++ b/api/core/helper/code_executor/python3/python3_code_provider.py @@ -13,7 +13,7 @@ def get_language() -> str: def get_default_code(cls) -> str: return dedent( """ - def main(arg1: int, arg2: int) -> dict: + def main(arg1: str, arg2: str) -> dict: return { "result": arg1 + arg2, } diff --git a/api/core/helper/code_executor/python3/python3_transformer.py b/api/core/helper/code_executor/python3/python3_transformer.py index 4a5fa3509325d3..75a5a44d086c3c 100644 --- a/api/core/helper/code_executor/python3/python3_transformer.py +++ b/api/core/helper/code_executor/python3/python3_transformer.py @@ -4,30 +4,6 @@ class Python3TemplateTransformer(TemplateTransformer): - @classmethod - def get_standard_packages(cls) -> set[str]: - return { - 'base64', - 'binascii', - 'collections', - 'datetime', - 'functools', - 'hashlib', - 'hmac', - 'itertools', - 'json', - 'math', - 'operator', - 'os', - 'random', - 're', - 'string', - 'sys', - 'time', - 'traceback', - 'uuid', - } - @classmethod def get_runner_script(cls) -> str: runner_script = dedent(f""" diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index da7ef469d91abf..cf66558b658f57 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -2,9 +2,6 @@ import re from abc import ABC, abstractmethod from base64 import b64encode -from typing import Optional - -from core.helper.code_executor.entities import CodeDependency class TemplateTransformer(ABC): @@ -13,12 +10,7 @@ class TemplateTransformer(ABC): _result_tag: str = '<>' @classmethod - def get_standard_packages(cls) -> set[str]: - return set() - - @classmethod - def transform_caller(cls, code: str, inputs: dict, - dependencies: Optional[list[CodeDependency]] = None) -> tuple[str, str, list[CodeDependency]]: + def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]: """ Transform code to python runner :param code: code @@ -28,14 +20,7 @@ def transform_caller(cls, code: str, inputs: dict, runner_script = cls.assemble_runner_script(code, inputs) preload_script = cls.get_preload_script() - packages = dependencies or [] - standard_packages = cls.get_standard_packages() - for package in standard_packages: - if package not in packages: - packages.append(CodeDependency(name=package, version='')) - packages = list({dep.name: dep for dep in packages if dep.name}.values()) - - return runner_script, preload_script, packages + return runner_script, preload_script @classmethod def extract_result_str_from_response(cls, response: str) -> str: diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index bf87a842c00bf4..5e5deb86b47e54 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -2,7 +2,6 @@ from extensions.ext_database import db from libs import rsa -from models.account import Tenant def obfuscated_token(token: str): @@ -14,6 +13,7 @@ def obfuscated_token(token: str): def encrypt_token(tenant_id: str, token: str): + from models.account import Tenant if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()): raise ValueError(f'Tenant with id {tenant_id} not found') encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) diff --git a/api/core/helper/position_helper.py b/api/core/helper/position_helper.py index dd1534c791b313..8cf184ac44c5ee 100644 --- a/api/core/helper/position_helper.py +++ b/api/core/helper/position_helper.py @@ -3,6 +3,7 @@ from collections.abc import Callable from typing import Any +from configs import dify_config from core.tools.utils.yaml_utils import load_yaml_file @@ -19,6 +20,87 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> return {name: index for index, name in enumerate(positions)} +def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]: + """ + Get the mapping for tools from name to index from a YAML file. + :param folder_path: + :param file_name: the YAML file name, default to '_position.yaml' + :return: a dict with name as key and index as value + """ + position_map = get_position_map(folder_path, file_name=file_name) + + return pin_position_map( + position_map, + pin_list=dify_config.POSITION_TOOL_PINS_LIST, + ) + + +def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]: + """ + Get the mapping for providers from name to index from a YAML file. + :param folder_path: + :param file_name: the YAML file name, default to '_position.yaml' + :return: a dict with name as key and index as value + """ + position_map = get_position_map(folder_path, file_name=file_name) + return pin_position_map( + position_map, + pin_list=dify_config.POSITION_PROVIDER_PINS_LIST, + ) + + +def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]: + """ + Pin the items in the pin list to the beginning of the position map. + Overall logic: exclude > include > pin + :param position_map: the position map to be sorted and filtered + :param pin_list: the list of pins to be put at the beginning + :return: the sorted position map + """ + positions = sorted(original_position_map.keys(), key=lambda x: original_position_map[x]) + + # Add pins to position map + position_map = {name: idx for idx, name in enumerate(pin_list)} + + # Add remaining positions to position map + start_idx = len(position_map) + for name in positions: + if name not in position_map: + position_map[name] = start_idx + start_idx += 1 + + return position_map + + +def is_filtered( + include_set: set[str], + exclude_set: set[str], + data: Any, + name_func: Callable[[Any], str], +) -> bool: + """ + Chcek if the object should be filtered out. + Overall logic: exclude > include > pin + :param include_set: the set of names to be included + :param exclude_set: the set of names to be excluded + :param name_func: the function to get the name of the object + :param data: the data to be filtered + :return: True if the object should be filtered out, False otherwise + """ + if not data: + return False + if not include_set and not exclude_set: + return False + + name = name_func(data) + + if name in exclude_set: # exclude_set is prioritized + return True + if include_set and name not in include_set: # filter out only if include_set is not empty + return True + return False + + def sort_by_position_map( position_map: dict[str, int], data: list[Any], diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index 5f7fec58337ee9..ddcd7512861da9 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -58,7 +58,8 @@ def init_app(self, app: Flask) -> None: self.moderation_config = self.init_moderation_config(config) - def init_azure_openai(self, app_config: Config) -> HostingProvider: + @staticmethod + def init_azure_openai(app_config: Config) -> HostingProvider: quota_unit = QuotaUnit.TIMES if app_config.get("HOSTED_AZURE_OPENAI_ENABLED"): credentials = { @@ -145,7 +146,8 @@ def init_openai(self, app_config: Config) -> HostingProvider: quota_unit=quota_unit, ) - def init_anthropic(self, app_config: Config) -> HostingProvider: + @staticmethod + def init_anthropic(app_config: Config) -> HostingProvider: quota_unit = QuotaUnit.TOKENS quotas = [] @@ -180,7 +182,8 @@ def init_anthropic(self, app_config: Config) -> HostingProvider: quota_unit=quota_unit, ) - def init_minimax(self, app_config: Config) -> HostingProvider: + @staticmethod + def init_minimax(app_config: Config) -> HostingProvider: quota_unit = QuotaUnit.TOKENS if app_config.get("HOSTED_MINIMAX_ENABLED"): quotas = [FreeHostingQuota()] @@ -197,7 +200,8 @@ def init_minimax(self, app_config: Config) -> HostingProvider: quota_unit=quota_unit, ) - def init_spark(self, app_config: Config) -> HostingProvider: + @staticmethod + def init_spark(app_config: Config) -> HostingProvider: quota_unit = QuotaUnit.TOKENS if app_config.get("HOSTED_SPARK_ENABLED"): quotas = [FreeHostingQuota()] @@ -214,7 +218,8 @@ def init_spark(self, app_config: Config) -> HostingProvider: quota_unit=quota_unit, ) - def init_zhipuai(self, app_config: Config) -> HostingProvider: + @staticmethod + def init_zhipuai(app_config: Config) -> HostingProvider: quota_unit = QuotaUnit.TOKENS if app_config.get("HOSTED_ZHIPUAI_ENABLED"): quotas = [FreeHostingQuota()] @@ -231,7 +236,8 @@ def init_zhipuai(self, app_config: Config) -> HostingProvider: quota_unit=quota_unit, ) - def init_moderation_config(self, app_config: Config) -> HostedModerationConfig: + @staticmethod + def init_moderation_config(app_config: Config) -> HostedModerationConfig: if app_config.get("HOSTED_MODERATION_ENABLED") \ and app_config.get("HOSTED_MODERATION_PROVIDERS"): return HostedModerationConfig( diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index b20c6ed187f20b..062666ac6a463f 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -411,7 +411,8 @@ def _extract(self, index_processor: BaseIndexProcessor, dataset_document: Datase return text_docs - def filter_string(self, text): + @staticmethod + def filter_string(text): text = re.sub(r'<\|', '<', text) text = re.sub(r'\|>', '>', text) text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text) @@ -419,7 +420,8 @@ def filter_string(self, text): text = re.sub('\uFFFE', '', text) return text - def _get_splitter(self, processing_rule: DatasetProcessRule, + @staticmethod + def _get_splitter(processing_rule: DatasetProcessRule, embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: """ Get the NodeParser object according to the processing rule. @@ -611,7 +613,8 @@ def _split_to_documents_for_estimate(self, text_docs: list[Document], splitter: return all_documents - def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str: + @staticmethod + def _document_clean(text: str, processing_rule: DatasetProcessRule) -> str: """ Clean the document text according to the processing rules. """ @@ -640,7 +643,8 @@ def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str return text - def format_split_text(self, text): + @staticmethod + def format_split_text(text): regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" matches = re.findall(regex, text, re.UNICODE) @@ -700,10 +704,12 @@ def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset, DatasetDocument.tokens: tokens, DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at, + DatasetDocument.error: None, } ) - def _process_keyword_index(self, flask_app, dataset_id, document_id, documents): + @staticmethod + def _process_keyword_index(flask_app, dataset_id, document_id, documents): with flask_app.app_context(): dataset = Dataset.query.filter_by(id=dataset_id).first() if not dataset: @@ -714,6 +720,7 @@ def _process_keyword_index(self, flask_app, dataset_id, document_id, documents): document_ids = [document.metadata['doc_id'] for document in documents] db.session.query(DocumentSegment).filter( DocumentSegment.document_id == document_id, + DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id.in_(document_ids), DocumentSegment.status == "indexing" ).update({ @@ -745,6 +752,7 @@ def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, d document_ids = [document.metadata['doc_id'] for document in chunk_documents] db.session.query(DocumentSegment).filter( DocumentSegment.document_id == dataset_document.id, + DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(document_ids), DocumentSegment.status == "indexing" ).update({ @@ -757,13 +765,15 @@ def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, d return tokens - def _check_document_paused_status(self, document_id: str): + @staticmethod + def _check_document_paused_status(document_id: str): indexing_cache_key = 'document_{}_is_paused'.format(document_id) result = redis_client.get(indexing_cache_key) if result: raise DocumentIsPausedException() - def _update_document_index_status(self, document_id: str, after_indexing_status: str, + @staticmethod + def _update_document_index_status(document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None) -> None: """ Update the document indexing status. @@ -785,14 +795,16 @@ def _update_document_index_status(self, document_id: str, after_indexing_status: DatasetDocument.query.filter_by(id=document_id).update(update_params) db.session.commit() - def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None: + @staticmethod + def _update_segments_by_document(dataset_document_id: str, update_params: dict) -> None: """ Update the document segment by document id. """ DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) db.session.commit() - def batch_add_segments(self, segments: list[DocumentSegment], dataset: Dataset): + @staticmethod + def batch_add_segments(segments: list[DocumentSegment], dataset: Dataset): """ Batch add segments index processing """ diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 0b5029460a03ff..8c13b4a45cbe6c 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -14,7 +14,8 @@ from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName +from core.ops.entities.trace_entity import TraceTaskName +from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from core.prompt.utils.prompt_template_parser import PromptTemplateParser diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 8e99ad3dec040b..7b1a7ada5b2225 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -44,7 +44,8 @@ def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> No credentials=self.credentials ) - def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict: + @staticmethod + def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str) -> dict: """ Fetch credentials from provider model bundle :param provider_model_bundle: provider model bundle @@ -63,7 +64,8 @@ def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBun return credentials - def _get_load_balancing_manager(self, configuration: ProviderConfiguration, + @staticmethod + def _get_load_balancing_manager(configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict) -> Optional["LBModelManager"]: @@ -271,9 +273,8 @@ def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Option :param content_text: text content to be translated :param tenant_id: user tenant id - :param user: unique user id :param voice: model timbre - :param streaming: output is streaming + :param user: unique user id :return: text for given audio file """ if not isinstance(self.model_type_instance, TTSModel): @@ -369,6 +370,15 @@ def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelTyp return ModelInstance(provider_model_bundle, model) + def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]: + """ + Return first provider and the first model in the provider + :param tenant_id: tenant id + :param model_type: model type + :return: provider name, model name + """ + return self._provider_manager.get_first_provider_first_model(tenant_id, model_type) + def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance: """ Get default model instance @@ -401,6 +411,10 @@ def __init__(self, tenant_id: str, managed_credentials: Optional[dict] = None) -> None: """ Load balancing model manager + :param tenant_id: tenant_id + :param provider: provider + :param model_type: model_type + :param model: model name :param load_balancing_configs: all load balancing configurations :param managed_credentials: credentials if load balancing configuration name is __inherit__ """ @@ -499,13 +513,12 @@ def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool: config.id ) - res = redis_client.exists(cooldown_cache_key) res = cast(bool, res) return res - @classmethod - def get_config_in_cooldown_and_ttl(cls, tenant_id: str, + @staticmethod + def get_config_in_cooldown_and_ttl(tenant_id: str, provider: str, model_type: ModelType, model: str, diff --git a/api/core/model_runtime/entities/defaults.py b/api/core/model_runtime/entities/defaults.py index 87fe4f681ce5c7..d2076bf74a3cde 100644 --- a/api/core/model_runtime/entities/defaults.py +++ b/api/core/model_runtime/entities/defaults.py @@ -1,4 +1,3 @@ - from core.model_runtime.entities.model_entities import DefaultParameterName PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { @@ -94,5 +93,16 @@ }, 'required': False, 'options': ['JSON', 'XML'], - } -} \ No newline at end of file + }, + DefaultParameterName.JSON_SCHEMA: { + 'label': { + 'en_US': 'JSON Schema', + }, + 'type': 'text', + 'help': { + 'en_US': 'Set a response json schema will ensure LLM to adhere it.', + 'zh_Hans': '设置返回的json schema,llm将按照它返回', + }, + 'required': False, + }, +} diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 3d471787bbef8e..c257ce63d27926 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -95,6 +95,7 @@ class DefaultParameterName(Enum): FREQUENCY_PENALTY = "frequency_penalty" MAX_TOKENS = "max_tokens" RESPONSE_FORMAT = "response_format" + JSON_SCHEMA = "json_schema" @classmethod def value_of(cls, value: Any) -> 'DefaultParameterName': @@ -118,6 +119,7 @@ class ParameterType(Enum): INT = "int" STRING = "string" BOOLEAN = "boolean" + TEXT = "text" class ModelPropertyKey(Enum): diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 0de216bf896fc2..716bb63566c372 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -151,9 +151,9 @@ def predefined_models(self) -> list[AIModelEntity]: os.path.join(provider_model_type_path, model_schema_yaml) for model_schema_yaml in os.listdir(provider_model_type_path) if not model_schema_yaml.startswith('__') - and not model_schema_yaml.startswith('_') - and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml)) - and model_schema_yaml.endswith('.yaml') + and not model_schema_yaml.startswith('_') + and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml)) + and model_schema_yaml.endswith('.yaml') ] # get _position.yaml file path diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 02ba0c9410f937..cfc8942c794f36 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -185,7 +185,7 @@ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_message stream=stream, user=user ) - + model_parameters.pop("response_format") stop = stop or [] stop.extend(["\n```", "```\n"]) @@ -249,10 +249,10 @@ def new_generator(): prompt_messages=prompt_messages, input_generator=new_generator() ) - + return response - def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage], + def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage], input_generator: Generator[LLMResultChunk, None, None] ) -> Generator[LLMResultChunk, None, None]: """ @@ -310,7 +310,7 @@ def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[Pr ) ) - def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list, + def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list, input_generator: Generator[LLMResultChunk, None, None]) \ -> Generator[LLMResultChunk, None, None]: """ @@ -470,7 +470,7 @@ def _invoke(self, model: str, credentials: dict, :return: full response or stream response chunk generator result """ raise NotImplementedError - + @abstractmethod def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: @@ -792,6 +792,13 @@ def _validate_and_filter_model_parameters(self, model: str, model_parameters: di if not isinstance(parameter_value, str): raise ValueError(f"Model Parameter {parameter_name} should be string.") + # validate options + if parameter_rule.options and parameter_value not in parameter_rule.options: + raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.") + elif parameter_rule.type == ParameterType.TEXT: + if not isinstance(parameter_value, str): + raise ValueError(f"Model Parameter {parameter_name} should be text.") + # validate options if parameter_rule.options and parameter_value not in parameter_rule.options: raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.") diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/core/model_runtime/model_providers/__base/tts_model.py index 086a189246d123..64e85d2c119ee8 100644 --- a/api/core/model_runtime/model_providers/__base/tts_model.py +++ b/api/core/model_runtime/model_providers/__base/tts_model.py @@ -1,18 +1,16 @@ -import hashlib import logging import re -import subprocess -import uuid from abc import abstractmethod from typing import Optional from pydantic import ConfigDict from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.model_providers.__base.ai_model import AIModel logger = logging.getLogger(__name__) + + class TTSModel(AIModel): """ Model class for ttstext model. @@ -37,8 +35,6 @@ def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: st :return: translated audio file """ try: - logger.info(f"Invoke TTS model: {model} , invoke content : {content_text}") - self._is_ffmpeg_installed() return self._invoke(model=model, credentials=credentials, user=user, content_text=content_text, voice=voice, tenant_id=tenant_id) except Exception as e: @@ -75,7 +71,8 @@ def get_tts_model_voices(self, model: str, credentials: dict, language: Optional if model_schema and ModelPropertyKey.VOICES in model_schema.model_properties: voices = model_schema.model_properties[ModelPropertyKey.VOICES] if language: - return [{'name': d['name'], 'value': d['mode']} for d in voices if language and language in d.get('language')] + return [{'name': d['name'], 'value': d['mode']} for d in voices if + language and language in d.get('language')] else: return [{'name': d['name'], 'value': d['mode']} for d in voices] @@ -146,28 +143,3 @@ def _split_text_into_sentences(org_text, max_length=2000, pattern=r'[。.!?]'): if one_sentence != '': result.append(one_sentence) return result - - @staticmethod - def _is_ffmpeg_installed(): - try: - output = subprocess.check_output("ffmpeg -version", shell=True) - if "ffmpeg version" in output.decode("utf-8"): - return True - else: - raise InvokeBadRequestError("ffmpeg is not installed, " - "details: https://docs.dify.ai/getting-started/install-self-hosted" - "/install-faq#id-14.-what-to-do-if-this-error-occurs-in-text-to-speech") - except Exception: - raise InvokeBadRequestError("ffmpeg is not installed, " - "details: https://docs.dify.ai/getting-started/install-self-hosted" - "/install-faq#id-14.-what-to-do-if-this-error-occurs-in-text-to-speech") - - # Todo: To improve the streaming function - @staticmethod - def _get_file_name(file_content: str) -> str: - hash_object = hashlib.sha256(file_content.encode()) - hex_digest = hash_object.hexdigest() - - namespace_uuid = uuid.UUID('a5da6ef9-b303-596f-8e88-bf8fa40f4b31') - unique_uuid = uuid.uuid5(namespace_uuid, hex_digest) - return str(unique_uuid) diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml index b4e024a81ec7bb..d10314ba039e63 100644 --- a/api/core/model_runtime/model_providers/_position.yaml +++ b/api/core/model_runtime/model_providers/_position.yaml @@ -36,3 +36,4 @@ - hunyuan - siliconflow - perfxcloud +- zhinao diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index 19ce401999c50f..81be1a06a7cd0e 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -1,6 +1,6 @@ import base64 +import io import json -import mimetypes from collections.abc import Generator from typing import Optional, Union, cast @@ -18,6 +18,7 @@ ) from anthropic.types.beta.tools import ToolsBetaMessage from httpx import Timeout +from PIL import Image from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta @@ -462,7 +463,8 @@ def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tupl # fetch image data from url try: image_content = requests.get(message_content.data).content - mime_type, _ = mimetypes.guess_type(message_content.data) + with Image.open(io.BytesIO(image_content)) as img: + mime_type = f"image/{img.format.lower()}" base64_data = base64.b64encode(image_content).decode('utf-8') except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/__init__.py b/api/core/model_runtime/model_providers/azure_ai_studio/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_l_en.png new file mode 100644 index 00000000000000..4b941654a78c15 Binary files /dev/null and b/api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_l_en.png differ diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_s_en.png b/api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_s_en.png new file mode 100644 index 00000000000000..ca3043dc8dcb19 Binary files /dev/null and b/api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_s_en.png differ diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.py b/api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.py new file mode 100644 index 00000000000000..75d21d1ce9875e --- /dev/null +++ b/api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.py @@ -0,0 +1,17 @@ +import logging + +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class AzureAIStudioProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + pass diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.yaml b/api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.yaml new file mode 100644 index 00000000000000..9e17ba088480db --- /dev/null +++ b/api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.yaml @@ -0,0 +1,65 @@ +provider: azure_ai_studio +label: + zh_Hans: Azure AI Studio + en_US: Azure AI Studio +icon_small: + en_US: icon_s_en.png +icon_large: + en_US: icon_l_en.png +description: + en_US: Azure AI Studio + zh_Hans: Azure AI Studio +background: "#93c5fd" +help: + title: + en_US: How to deploy customized model on Azure AI Studio + zh_Hans: 如何在Azure AI Studio上的私有化部署的模型 + url: + en_US: https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models + zh_Hans: https://learn.microsoft.com/zh-cn/azure/ai-studio/how-to/deploy-models +supported_model_types: + - llm + - rerank +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: endpoint + label: + en_US: Azure AI Studio Endpoint + type: text-input + required: true + placeholder: + zh_Hans: 请输入你的Azure AI Studio推理端点 + en_US: 'Enter your API Endpoint, eg: https://example.com' + - variable: api_key + required: true + label: + en_US: API Key + zh_Hans: API Key + type: secret-input + placeholder: + en_US: Enter your Azure AI Studio API Key + zh_Hans: 在此输入您的 Azure AI Studio API Key + show_on: + - variable: __model_type + value: llm + - variable: jwt_token + required: true + label: + en_US: JWT Token + zh_Hans: JWT令牌 + type: secret-input + placeholder: + en_US: Enter your Azure AI Studio JWT Token + zh_Hans: 在此输入您的 Azure AI Studio 推理 API Key + show_on: + - variable: __model_type + value: rerank diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/llm/__init__.py b/api/core/model_runtime/model_providers/azure_ai_studio/llm/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/llm/llm.py b/api/core/model_runtime/model_providers/azure_ai_studio/llm/llm.py new file mode 100644 index 00000000000000..42eae6c1e50fed --- /dev/null +++ b/api/core/model_runtime/model_providers/azure_ai_studio/llm/llm.py @@ -0,0 +1,334 @@ +import logging +from collections.abc import Generator +from typing import Any, Optional, Union + +from azure.ai.inference import ChatCompletionsClient +from azure.ai.inference.models import StreamingChatCompletionsUpdate +from azure.core.credentials import AzureKeyCredential +from azure.core.exceptions import ( + ClientAuthenticationError, + DecodeError, + DeserializationError, + HttpResponseError, + ResourceExistsError, + ResourceModifiedError, + ResourceNotFoundError, + ResourceNotModifiedError, + SerializationError, + ServiceRequestError, + ServiceResponseError, +) + +from core.model_runtime.callbacks.base_callback import Callback +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + I18nObject, + ModelType, + ParameterRule, + ParameterType, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + +logger = logging.getLogger(__name__) + + +class AzureAIStudioLargeLanguageModel(LargeLanguageModel): + """ + Model class for Azure AI Studio large language model. + """ + + client: Any = None + + from azure.ai.inference.models import StreamingChatCompletionsUpdate + + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + """ + Invoke large language model + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param tools: tools for tool calling + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream response chunk generator result + """ + + if not self.client: + endpoint = credentials.get("endpoint") + api_key = credentials.get("api_key") + self.client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(api_key)) + + messages = [{"role": msg.role.value, "content": msg.content} for msg in prompt_messages] + + payload = { + "messages": messages, + "max_tokens": model_parameters.get("max_tokens", 4096), + "temperature": model_parameters.get("temperature", 0), + "top_p": model_parameters.get("top_p", 1), + "stream": stream, + } + + if stop: + payload["stop"] = stop + + if tools: + payload["tools"] = [tool.model_dump() for tool in tools] + + try: + response = self.client.complete(**payload) + + if stream: + return self._handle_stream_response(response, model, prompt_messages) + else: + return self._handle_non_stream_response(response, model, prompt_messages, credentials) + except Exception as e: + raise self._transform_invoke_error(e) + + def _handle_stream_response(self, response, model: str, prompt_messages: list[PromptMessage]) -> Generator: + for chunk in response: + if isinstance(chunk, StreamingChatCompletionsUpdate): + if chunk.choices: + delta = chunk.choices[0].delta + if delta.content: + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=delta.content, tool_calls=[]), + ), + ) + + def _handle_non_stream_response( + self, response, model: str, prompt_messages: list[PromptMessage], credentials: dict + ) -> LLMResult: + assistant_text = response.choices[0].message.content + assistant_prompt_message = AssistantPromptMessage(content=assistant_text) + usage = self._calc_response_usage( + model, credentials, response.usage.prompt_tokens, response.usage.completion_tokens + ) + result = LLMResult(model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage) + + if hasattr(response, "system_fingerprint"): + result.system_fingerprint = response.system_fingerprint + + return result + + def _invoke_result_generator( + self, + model: str, + result: Generator, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Generator: + """ + Invoke result generator + + :param result: result generator + :return: result generator + """ + callbacks = callbacks or [] + prompt_message = AssistantPromptMessage(content="") + usage = None + system_fingerprint = None + real_model = model + + try: + for chunk in result: + if isinstance(chunk, dict): + content = chunk["choices"][0]["message"]["content"] + usage = chunk["usage"] + chunk = LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=content, tool_calls=[]), + ), + system_fingerprint=chunk.get("system_fingerprint"), + ) + + yield chunk + + self._trigger_new_chunk_callbacks( + chunk=chunk, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + callbacks=callbacks, + ) + + prompt_message.content += chunk.delta.message.content + real_model = chunk.model + if hasattr(chunk.delta, "usage"): + usage = chunk.delta.usage + + if chunk.system_fingerprint: + system_fingerprint = chunk.system_fingerprint + except Exception as e: + raise self._transform_invoke_error(e) + + self._trigger_after_invoke_callbacks( + model=model, + result=LLMResult( + model=real_model, + prompt_messages=prompt_messages, + message=prompt_message, + usage=usage if usage else LLMUsage.empty_usage(), + system_fingerprint=system_fingerprint, + ), + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + callbacks=callbacks, + ) + + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param tools: tools for tool calling + :return: + """ + # Implement token counting logic here + # Might need to use a tokenizer specific to the Azure AI Studio model + return 0 + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + endpoint = credentials.get("endpoint") + api_key = credentials.get("api_key") + client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(api_key)) + client.get_model_info() + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [ + ServiceRequestError, + ], + InvokeServerUnavailableError: [ + ServiceResponseError, + ], + InvokeAuthorizationError: [ + ClientAuthenticationError, + ], + InvokeBadRequestError: [ + HttpResponseError, + DecodeError, + ResourceExistsError, + ResourceNotFoundError, + ResourceModifiedError, + ResourceNotModifiedError, + SerializationError, + DeserializationError, + ], + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + Used to define customizable model schema + """ + rules = [ + ParameterRule( + name="temperature", + type=ParameterType.FLOAT, + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), + ), + ParameterRule( + name="top_p", + type=ParameterType.FLOAT, + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), + ), + ParameterRule( + name="max_tokens", + type=ParameterType.INT, + use_template="max_tokens", + min=1, + default=512, + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), + ] + + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.LLM, + features=[], + model_properties={}, + parameter_rules=rules, + ) + + return entity diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/rerank/__init__.py b/api/core/model_runtime/model_providers/azure_ai_studio/rerank/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/rerank/rerank.py b/api/core/model_runtime/model_providers/azure_ai_studio/rerank/rerank.py new file mode 100644 index 00000000000000..6ed7ab277cde02 --- /dev/null +++ b/api/core/model_runtime/model_providers/azure_ai_studio/rerank/rerank.py @@ -0,0 +1,164 @@ +import json +import logging +import os +import ssl +import urllib.request +from typing import Optional + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + +logger = logging.getLogger(__name__) + + +class AzureRerankModel(RerankModel): + """ + Model class for Azure AI Studio rerank model. + """ + + def _allow_self_signed_https(self, allowed): + # bypass the server certificate verification on client side + if allowed and not os.environ.get("PYTHONHTTPSVERIFY", "") and getattr(ssl, "_create_unverified_context", None): + ssl._create_default_https_context = ssl._create_unverified_context + + def _azure_rerank(self, query_input: str, docs: list[str], endpoint: str, api_key: str): + # self._allow_self_signed_https(True) # Enable if using self-signed certificate + + data = {"inputs": query_input, "docs": docs} + + body = json.dumps(data).encode("utf-8") + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + + req = urllib.request.Request(endpoint, body, headers) + + try: + with urllib.request.urlopen(req) as response: + result = response.read() + return json.loads(result) + except urllib.error.HTTPError as error: + logger.error(f"The request failed with status code: {error.code}") + logger.error(error.info()) + logger.error(error.read().decode("utf8", "ignore")) + raise + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id + :return: rerank result + """ + try: + if len(docs) == 0: + return RerankResult(model=model, docs=[]) + + endpoint = credentials.get("endpoint") + api_key = credentials.get("jwt_token") + + if not endpoint or not api_key: + raise ValueError("Azure endpoint and API key must be provided in credentials") + + result = self._azure_rerank(query, docs, endpoint, api_key) + logger.info(f"Azure rerank result: {result}") + + rerank_documents = [] + for idx, (doc, score_dict) in enumerate(zip(docs, result)): + score = score_dict["score"] + rerank_document = RerankDocument(index=idx, text=doc, score=score) + + if score_threshold is None or score >= score_threshold: + rerank_documents.append(rerank_document) + + rerank_documents.sort(key=lambda x: x.score, reverse=True) + + if top_n: + rerank_documents = rerank_documents[:top_n] + + return RerankResult(model=model, docs=rerank_documents) + + except Exception as e: + logger.exception(f"Exception in Azure rerank: {e}") + raise + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8, + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [urllib.error.URLError], + InvokeServerUnavailableError: [urllib.error.HTTPError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError, json.JSONDecodeError], + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.RERANK, + model_properties={}, + parameter_rules=[], + ) + + return entity diff --git a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py index 3d2bac1c310277..f9ddd86f68a9ec 100644 --- a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py @@ -70,7 +70,7 @@ def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: st # doc: https://platform.openai.com/docs/guides/text-to-speech credentials_kwargs = self._to_credential_kwargs(credentials) client = AzureOpenAI(**credentials_kwargs) - # max font is 4096,there is 3500 limit for each request + # max length is 4096 characters, there is 3500 limit for each request max_length = 3500 if len(content_text) > max_length: sentences = self._split_text_into_sentences(content_text, max_length=max_length) diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index ff34a116c78e05..3f7266f6002025 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -1,8 +1,8 @@ # standard import import base64 +import io import json import logging -import mimetypes from collections.abc import Generator from typing import Optional, Union, cast @@ -17,6 +17,7 @@ ServiceNotInRegionError, UnknownServiceError, ) +from PIL.Image import Image # local import from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta @@ -379,8 +380,11 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: if not message_content.data.startswith("data:"): # fetch image data from url try: - image_content = requests.get(message_content.data).content - mime_type, _ = mimetypes.guess_type(message_content.data) + url = message_content.data + image_content = requests.get(url).content + with Image.open(io.BytesIO(image_content)) as img: + mime_type = f"image/{img.format.lower()}" + base64_data = base64.b64encode(image_content).decode('utf-8') except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") else: diff --git a/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml index 68325765244388..6588a4b5e01468 100644 --- a/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml +++ b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml @@ -5,6 +5,8 @@ label: model_type: llm features: - agent-thought + - multi-tool-call + - stream-tool-call model_properties: mode: chat context_size: 128000 diff --git a/api/core/model_runtime/model_providers/deepseek/llm/deepseek-coder.yaml b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-coder.yaml index 4da75b9aa30ad2..caafeadadd999e 100644 --- a/api/core/model_runtime/model_providers/deepseek/llm/deepseek-coder.yaml +++ b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-coder.yaml @@ -5,6 +5,8 @@ label: model_type: llm features: - agent-thought + - multi-tool-call + - stream-tool-call model_properties: mode: chat context_size: 128000 diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index ebcd0af35b2138..84241fb6c877a4 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -1,7 +1,7 @@ import base64 +import io import json import logging -import mimetypes from collections.abc import Generator from typing import Optional, Union, cast @@ -12,6 +12,7 @@ import requests from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory from google.generativeai.types.content_types import to_part +from PIL import Image from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -371,7 +372,8 @@ def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType: # fetch image data from url try: image_content = requests.get(message_content.data).content - mime_type, _ = mimetypes.guess_type(message_content.data) + with Image.open(io.BytesIO(image_content)) as img: + mime_type = f"image/{img.format.lower()}" base64_data = base64.b64encode(image_content).decode('utf-8') except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") diff --git a/api/core/model_runtime/model_providers/groq/llm/llama3-70b-8192.yaml b/api/core/model_runtime/model_providers/groq/llm/llama3-70b-8192.yaml index 98655a4c9fce7c..91d0e307657911 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llama3-70b-8192.yaml +++ b/api/core/model_runtime/model_providers/groq/llm/llama3-70b-8192.yaml @@ -19,7 +19,7 @@ parameter_rules: min: 1 max: 8192 pricing: - input: '0.05' - output: '0.1' + input: '0.59' + output: '0.79' unit: '0.000001' currency: USD diff --git a/api/core/model_runtime/model_providers/groq/llm/llama3-8b-8192.yaml b/api/core/model_runtime/model_providers/groq/llm/llama3-8b-8192.yaml index d85bb7709bc877..b6154f761f4b40 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llama3-8b-8192.yaml +++ b/api/core/model_runtime/model_providers/groq/llm/llama3-8b-8192.yaml @@ -19,7 +19,7 @@ parameter_rules: min: 1 max: 8192 pricing: - input: '0.59' - output: '0.79' + input: '0.05' + output: '0.08' unit: '0.000001' currency: USD diff --git a/api/core/model_runtime/model_providers/huggingface_tei/__init__.py b/api/core/model_runtime/model_providers/huggingface_tei/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py b/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py new file mode 100644 index 00000000000000..94544662503974 --- /dev/null +++ b/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py @@ -0,0 +1,11 @@ +import logging + +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class HuggingfaceTeiProvider(ModelProvider): + + def validate_provider_credentials(self, credentials: dict) -> None: + pass diff --git a/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.yaml b/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.yaml new file mode 100644 index 00000000000000..f3a912d84d23d3 --- /dev/null +++ b/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.yaml @@ -0,0 +1,36 @@ +provider: huggingface_tei +label: + en_US: Text Embedding Inference +description: + en_US: A blazing fast inference solution for text embeddings models. + zh_Hans: 用于文本嵌入模型的超快速推理解决方案。 +background: "#FFF8DC" +help: + title: + en_US: How to deploy Text Embedding Inference + zh_Hans: 如何部署 Text Embedding Inference + url: + en_US: https://github.com/huggingface/text-embeddings-inference +supported_model_types: + - text-embedding + - rerank +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: server_url + label: + zh_Hans: 服务器URL + en_US: Server url + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入Text Embedding Inference的服务器地址,如 http://192.168.1.100:8080 + en_US: Enter the url of your Text Embedding Inference, e.g. http://192.168.1.100:8080 diff --git a/api/core/model_runtime/model_providers/huggingface_tei/rerank/__init__.py b/api/core/model_runtime/model_providers/huggingface_tei/rerank/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py b/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py new file mode 100644 index 00000000000000..34013426de5b77 --- /dev/null +++ b/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py @@ -0,0 +1,137 @@ +from typing import Optional + +import httpx + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel +from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper + + +class HuggingfaceTeiRerankModel(RerankModel): + """ + Model class for Text Embedding Inference rerank model. + """ + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id + :return: rerank result + """ + if len(docs) == 0: + return RerankResult(model=model, docs=[]) + server_url = credentials['server_url'] + + if server_url.endswith('/'): + server_url = server_url[:-1] + + try: + results = TeiHelper.invoke_rerank(server_url, query, docs) + + rerank_documents = [] + for result in results: + rerank_document = RerankDocument( + index=result['index'], + text=result['text'], + score=result['score'], + ) + if score_threshold is None or result['score'] >= score_threshold: + rerank_documents.append(rerank_document) + if top_n is not None and len(rerank_documents) >= top_n: + break + + return RerankResult(model=model, docs=rerank_documents) + except httpx.HTTPStatusError as e: + raise InvokeServerUnavailableError(str(e)) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + server_url = credentials['server_url'] + extra_args = TeiHelper.get_tei_extra_parameter(server_url, model) + if extra_args.model_type != 'reranker': + raise CredentialsValidateFailedError('Current model is not a rerank model') + + credentials['context_size'] = extra_args.max_input_length + + self.invoke( + model=model, + credentials=credentials, + query='Whose kasumi', + docs=[ + 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', + 'Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ', + 'and she leads a team named PopiParty.', + ], + score_threshold=0.8, + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.RERANK, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)), + }, + parameter_rules=[], + ) + + return entity diff --git a/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py b/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py new file mode 100644 index 00000000000000..2aa785c89d27e6 --- /dev/null +++ b/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py @@ -0,0 +1,183 @@ +from threading import Lock +from time import time +from typing import Optional + +import httpx +from requests.adapters import HTTPAdapter +from requests.exceptions import ConnectionError, MissingSchema, Timeout +from requests.sessions import Session +from yarl import URL + + +class TeiModelExtraParameter: + model_type: str + max_input_length: int + max_client_batch_size: int + + def __init__(self, model_type: str, max_input_length: int, max_client_batch_size: Optional[int] = None) -> None: + self.model_type = model_type + self.max_input_length = max_input_length + self.max_client_batch_size = max_client_batch_size + + +cache = {} +cache_lock = Lock() + + +class TeiHelper: + @staticmethod + def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter: + TeiHelper._clean_cache() + with cache_lock: + if model_name not in cache: + cache[model_name] = { + 'expires': time() + 300, + 'value': TeiHelper._get_tei_extra_parameter(server_url), + } + return cache[model_name]['value'] + + @staticmethod + def _clean_cache() -> None: + try: + with cache_lock: + expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()] + for model_uid in expired_keys: + del cache[model_uid] + except RuntimeError as e: + pass + + @staticmethod + def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter: + """ + get tei model extra parameter like model_type, max_input_length, max_batch_requests + """ + + url = str(URL(server_url) / 'info') + + # this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3 + session = Session() + session.mount('http://', HTTPAdapter(max_retries=3)) + session.mount('https://', HTTPAdapter(max_retries=3)) + + try: + response = session.get(url, timeout=10) + except (MissingSchema, ConnectionError, Timeout) as e: + raise RuntimeError(f'get tei model extra parameter failed, url: {url}, error: {e}') + if response.status_code != 200: + raise RuntimeError( + f'get tei model extra parameter failed, status code: {response.status_code}, response: {response.text}' + ) + + response_json = response.json() + + model_type = response_json.get('model_type', {}) + if len(model_type.keys()) < 1: + raise RuntimeError('model_type is empty') + model_type = list(model_type.keys())[0] + if model_type not in ['embedding', 'reranker']: + raise RuntimeError(f'invalid model_type: {model_type}') + + max_input_length = response_json.get('max_input_length', 512) + max_client_batch_size = response_json.get('max_client_batch_size', 1) + + return TeiModelExtraParameter( + model_type=model_type, + max_input_length=max_input_length, + max_client_batch_size=max_client_batch_size + ) + + @staticmethod + def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]: + """ + Invoke tokenize endpoint + + Example response: + [ + [ + { + "id": 0, + "text": "", + "special": true, + "start": null, + "stop": null + }, + { + "id": 7704, + "text": "str", + "special": false, + "start": 0, + "stop": 3 + }, + < MORE TOKENS > + ] + ] + + :param server_url: server url + :param texts: texts to tokenize + """ + resp = httpx.post( + f'{server_url}/tokenize', + json={'inputs': texts}, + ) + resp.raise_for_status() + return resp.json() + + @staticmethod + def invoke_embeddings(server_url: str, texts: list[str]) -> dict: + """ + Invoke embeddings endpoint + + Example response: + { + "object": "list", + "data": [ + { + "object": "embedding", + "embedding": [...], + "index": 0 + } + ], + "model": "MODEL_NAME", + "usage": { + "prompt_tokens": 3, + "total_tokens": 3 + } + } + + :param server_url: server url + :param texts: texts to embed + """ + # Use OpenAI compatible API here, which has usage tracking + resp = httpx.post( + f'{server_url}/v1/embeddings', + json={'input': texts}, + ) + resp.raise_for_status() + return resp.json() + + @staticmethod + def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]: + """ + Invoke rerank endpoint + + Example response: + [ + { + "index": 0, + "text": "Deep Learning is ...", + "score": 0.9950755 + } + ] + + :param server_url: server url + :param texts: texts to rerank + :param candidates: candidates to rerank + """ + params = {'query': query, 'texts': docs, 'return_text': True} + + response = httpx.post( + server_url + '/rerank', + json=params, + ) + response.raise_for_status() + return response.json() diff --git a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/__init__.py b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py new file mode 100644 index 00000000000000..6897b87f6d7525 --- /dev/null +++ b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py @@ -0,0 +1,204 @@ +import time +from typing import Optional + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper + + +class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): + """ + Model class for Text Embedding Inference text embedding model. + """ + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: + """ + Invoke text embedding model + + credentials should be like: + { + 'server_url': 'server url', + 'model_uid': 'model uid', + } + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :return: embeddings result + """ + server_url = credentials['server_url'] + + if server_url.endswith('/'): + server_url = server_url[:-1] + + + # get model properties + context_size = self._get_context_size(model, credentials) + max_chunks = self._get_max_chunks(model, credentials) + + inputs = [] + indices = [] + used_tokens = 0 + + # get tokenized results from TEI + batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts) + + for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)): + + # Check if the number of tokens is larger than the context size + num_tokens = len(tokenize_result) + + if num_tokens >= context_size: + # Find the best cutoff point + pre_special_token_count = 0 + for token in tokenize_result: + if token['special']: + pre_special_token_count += 1 + else: + break + rest_special_token_count = len([token for token in tokenize_result if token['special']]) - pre_special_token_count + + # Calculate the cutoff point, leave 20 extra space to avoid exceeding the limit + token_cutoff = context_size - rest_special_token_count - 20 + + # Find the cutoff index + cutpoint_token = tokenize_result[token_cutoff] + cutoff = cutpoint_token['start'] + + inputs.append(text[0: cutoff]) + else: + inputs.append(text) + indices += [i] + + batched_embeddings = [] + _iter = range(0, len(inputs), max_chunks) + + try: + used_tokens = 0 + for i in _iter: + iter_texts = inputs[i : i + max_chunks] + results = TeiHelper.invoke_embeddings(server_url, iter_texts) + embeddings = results['data'] + embeddings = [embedding['embedding'] for embedding in embeddings] + batched_embeddings.extend(embeddings) + + usage = results['usage'] + used_tokens += usage['total_tokens'] + except RuntimeError as e: + raise InvokeServerUnavailableError(str(e)) + + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) + + result = TextEmbeddingResult(model=model, embeddings=batched_embeddings, usage=usage) + + return result + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + num_tokens = 0 + server_url = credentials['server_url'] + + if server_url.endswith('/'): + server_url = server_url[:-1] + + batch_tokens = TeiHelper.invoke_tokenize(server_url, texts) + num_tokens = sum(len(tokens) for tokens in batch_tokens) + return num_tokens + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + server_url = credentials['server_url'] + extra_args = TeiHelper.get_tei_extra_parameter(server_url, model) + print(extra_args) + if extra_args.model_type != 'embedding': + raise CredentialsValidateFailedError('Current model is not a embedding model') + + credentials['context_size'] = extra_args.max_input_length + credentials['max_chunks'] = extra_args.max_client_batch_size + self._invoke(model=model, credentials=credentials, texts=['ping']) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + return { + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], + } + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: input tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at, + ) + + return usage + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.TEXT_EMBEDDING, + model_properties={ + ModelPropertyKey.MAX_CHUNKS: int(credentials.get('max_chunks', 1)), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)), + }, + parameter_rules=[], + ) + + return entity diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py index 6d22f9d2d66622..0bdf6ec005056b 100644 --- a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py @@ -214,7 +214,7 @@ def _handle_stream_chat_response(self, model, credentials, prompt_messages, resp def _handle_chat_response(self, credentials, model, prompt_messages, response): usage = self._calc_response_usage(model, credentials, response.Usage.PromptTokens, response.Usage.CompletionTokens) - assistant_prompt_message = PromptMessage(role="assistant") + assistant_prompt_message = AssistantPromptMessage() assistant_prompt_message.content = response.Choices[0].Message.Content result = LLMResult( model=model, diff --git a/api/core/model_runtime/model_providers/jina/rerank/jina-reranker-v2-base-multilingual.yaml b/api/core/model_runtime/model_providers/jina/rerank/jina-reranker-v2-base-multilingual.yaml index acf576719c03c2..e6af62107eaa08 100644 --- a/api/core/model_runtime/model_providers/jina/rerank/jina-reranker-v2-base-multilingual.yaml +++ b/api/core/model_runtime/model_providers/jina/rerank/jina-reranker-v2-base-multilingual.yaml @@ -1,4 +1,4 @@ model: jina-reranker-v2-base-multilingual model_type: rerank model_properties: - context_size: 8192 + context_size: 1024 diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index b1660afafb12e4..e2d17e32575920 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict from core.helper.module_import_helper import load_single_subclass_from_source -from core.helper.position_helper import get_position_map, sort_to_dict_by_position_map +from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from core.model_runtime.model_providers.__base.model_provider import ModelProvider @@ -234,7 +234,7 @@ def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]: ] # get _position.yaml file path - position_map = get_position_map(model_providers_path) + position_map = get_provider_position_map(model_providers_path) # traverse all model_provider_dir_paths model_providers: list[ModelProviderExtension] = [] diff --git a/api/core/model_runtime/model_providers/moonshot/llm/llm.py b/api/core/model_runtime/model_providers/moonshot/llm/llm.py index 17cf65dc3adf70..c233596637fa21 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/llm.py +++ b/api/core/model_runtime/model_providers/moonshot/llm/llm.py @@ -84,7 +84,8 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode def _add_custom_parameters(self, credentials: dict) -> None: credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.moonshot.cn/v1' + if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": + credentials['endpoint_url'] = 'https://api.moonshot.cn/v1' def _add_function_call(self, model: str, credentials: dict) -> None: model_schema = self.get_model_schema(model, credentials) diff --git a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml index 0d2e51c47f9dc6..1078e84c59dca0 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml +++ b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml @@ -21,6 +21,18 @@ parameter_rules: default: 1024 min: 1 max: 128000 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.06' output: '0.06' diff --git a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml index 9ff537014a86c6..9c739d05019f5f 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml +++ b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml @@ -21,6 +21,18 @@ parameter_rules: default: 1024 min: 1 max: 32000 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.024' output: '0.024' diff --git a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml index 0f308d36766752..187a86999e170d 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml +++ b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml @@ -21,6 +21,18 @@ parameter_rules: default: 512 min: 1 max: 8192 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.012' output: '0.012' diff --git a/api/core/model_runtime/model_providers/moonshot/moonshot.yaml b/api/core/model_runtime/model_providers/moonshot/moonshot.yaml index 34c802c2a7afdc..41e9c2e8086e83 100644 --- a/api/core/model_runtime/model_providers/moonshot/moonshot.yaml +++ b/api/core/model_runtime/model_providers/moonshot/moonshot.yaml @@ -31,6 +31,14 @@ provider_credential_schema: placeholder: zh_Hans: 在此输入您的 API Key en_US: Enter your API Key + - variable: endpoint_url + label: + en_US: API Base + type: text-input + required: false + placeholder: + zh_Hans: Base URL, 如:https://api.moonshot.cn/v1 + en_US: Base URL, e.g. https://api.moonshot.cn/v1 model_credential_schema: model: label: diff --git a/api/core/model_runtime/model_providers/nvidia/llm/_position.yaml b/api/core/model_runtime/model_providers/nvidia/llm/_position.yaml index 6cc197b70b7031..ad01d430d61c79 100644 --- a/api/core/model_runtime/model_providers/nvidia/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/nvidia/llm/_position.yaml @@ -10,5 +10,8 @@ - mistralai/mistral-large - mistralai/mixtral-8x7b-instruct-v0.1 - mistralai/mixtral-8x22b-instruct-v0.1 +- nvidia/nemotron-4-340b-instruct +- microsoft/phi-3-medium-128k-instruct +- microsoft/phi-3-mini-128k-instruct - fuyu-8b - snowflake/arctic diff --git a/api/core/model_runtime/model_providers/nvidia/llm/llm.py b/api/core/model_runtime/model_providers/nvidia/llm/llm.py index 494b7374f5ae28..bc42eaca658bac 100644 --- a/api/core/model_runtime/model_providers/nvidia/llm/llm.py +++ b/api/core/model_runtime/model_providers/nvidia/llm/llm.py @@ -34,8 +34,10 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): 'meta/llama-3.1-8b-instruct': '', 'meta/llama-3.1-70b-instruct': '', 'meta/llama-3.1-405b-instruct': '', - 'google/recurrentgemma-2b': '' - + 'google/recurrentgemma-2b': '', + 'nvidia/nemotron-4-340b-instruct': '', + 'microsoft/phi-3-medium-128k-instruct':'', + 'microsoft/phi-3-mini-128k-instruct':'' } def _invoke(self, model: str, credentials: dict, diff --git a/api/core/model_runtime/model_providers/nvidia/llm/nemotron-4-340b-instruct.yaml b/api/core/model_runtime/model_providers/nvidia/llm/nemotron-4-340b-instruct.yaml new file mode 100644 index 00000000000000..e5537cd2fd9dc8 --- /dev/null +++ b/api/core/model_runtime/model_providers/nvidia/llm/nemotron-4-340b-instruct.yaml @@ -0,0 +1,36 @@ +model: nvidia/nemotron-4-340b-instruct +label: + zh_Hans: nvidia/nemotron-4-340b-instruct + en_US: nvidia/nemotron-4-340b-instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + min: 0 + max: 1 + default: 0.5 + - name: top_p + use_template: top_p + min: 0 + max: 1 + default: 1 + - name: max_tokens + use_template: max_tokens + min: 1 + max: 4096 + default: 1024 + - name: frequency_penalty + use_template: frequency_penalty + min: -2 + max: 2 + default: 0 + - name: presence_penalty + use_template: presence_penalty + min: -2 + max: 2 + default: 0 diff --git a/api/core/model_runtime/model_providers/nvidia/llm/phi-3-medium-128k-instruct.yaml b/api/core/model_runtime/model_providers/nvidia/llm/phi-3-medium-128k-instruct.yaml new file mode 100644 index 00000000000000..0c5538d1350613 --- /dev/null +++ b/api/core/model_runtime/model_providers/nvidia/llm/phi-3-medium-128k-instruct.yaml @@ -0,0 +1,36 @@ +model: microsoft/phi-3-medium-128k-instruct +label: + zh_Hans: microsoft/phi-3-medium-128k-instruct + en_US: microsoft/phi-3-medium-128k-instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + min: 0 + max: 1 + default: 0.5 + - name: top_p + use_template: top_p + min: 0 + max: 1 + default: 1 + - name: max_tokens + use_template: max_tokens + min: 1 + max: 4096 + default: 1024 + - name: frequency_penalty + use_template: frequency_penalty + min: -2 + max: 2 + default: 0 + - name: presence_penalty + use_template: presence_penalty + min: -2 + max: 2 + default: 0 diff --git a/api/core/model_runtime/model_providers/nvidia/llm/phi-3-mini-128k-instruct.yaml b/api/core/model_runtime/model_providers/nvidia/llm/phi-3-mini-128k-instruct.yaml new file mode 100644 index 00000000000000..1eb1c51d01157c --- /dev/null +++ b/api/core/model_runtime/model_providers/nvidia/llm/phi-3-mini-128k-instruct.yaml @@ -0,0 +1,36 @@ +model: microsoft/phi-3-mini-128k-instruct +label: + zh_Hans: microsoft/phi-3-mini-128k-instruct + en_US: microsoft/phi-3-mini-128k-instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + min: 0 + max: 1 + default: 0.5 + - name: top_p + use_template: top_p + min: 0 + max: 1 + default: 1 + - name: max_tokens + use_template: max_tokens + min: 1 + max: 4096 + default: 1024 + - name: frequency_penalty + use_template: frequency_penalty + min: -2 + max: 2 + default: 0 + - name: presence_penalty + use_template: presence_penalty + min: -2 + max: 2 + default: 0 diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py index 069de9acec0a92..9e26d35afc9437 100644 --- a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py @@ -72,7 +72,7 @@ def _invoke(self, model: str, credentials: dict, num_tokens = self._get_num_tokens_by_gpt2(text) if num_tokens >= context_size: - cutoff = int(len(text) * (np.floor(context_size / num_tokens))) + cutoff = int(np.floor(len(text) * (context_size / num_tokens))) # if num tokens is larger than context length, only use the start inputs.append(text[0: cutoff]) else: diff --git a/api/core/model_runtime/model_providers/openai/llm/_position.yaml b/api/core/model_runtime/model_providers/openai/llm/_position.yaml index 91b9215829b00b..ac7313aaa1bf0b 100644 --- a/api/core/model_runtime/model_providers/openai/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/_position.yaml @@ -1,6 +1,8 @@ - gpt-4 - gpt-4o - gpt-4o-2024-05-13 +- gpt-4o-2024-08-06 +- chatgpt-4o-latest - gpt-4o-mini - gpt-4o-mini-2024-07-18 - gpt-4-turbo diff --git a/api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml b/api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml new file mode 100644 index 00000000000000..98e236650c9e73 --- /dev/null +++ b/api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml @@ -0,0 +1,44 @@ +model: chatgpt-4o-latest +label: + zh_Hans: chatgpt-4o-latest + en_US: chatgpt-4o-latest +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call + - vision +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 16384 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '2.50' + output: '10.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml index d6338c3d194204..6eb15e6c0df396 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml @@ -37,7 +37,7 @@ parameter_rules: - text - json_object pricing: - input: '0.001' - output: '0.002' + input: '0.0005' + output: '0.0015' unit: '0.001' currency: USD diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml new file mode 100644 index 00000000000000..7e430c51a710fc --- /dev/null +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml @@ -0,0 +1,47 @@ +model: gpt-4o-2024-08-06 +label: + zh_Hans: gpt-4o-2024-08-06 + en_US: gpt-4o-2024-08-06 +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call + - vision +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 16384 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object + - json_schema + - name: json_schema + use_template: json_schema +pricing: + input: '2.50' + output: '10.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini-2024-07-18.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini-2024-07-18.yaml index 6f23e0647d6eec..03e28772e607fc 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini-2024-07-18.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini-2024-07-18.yaml @@ -37,6 +37,9 @@ parameter_rules: options: - text - json_object + - json_schema + - name: json_schema + use_template: json_schema pricing: input: '0.15' output: '0.60' diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml index b97fbf8aabcae4..23dcf85085e123 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml @@ -37,6 +37,9 @@ parameter_rules: options: - text - json_object + - json_schema + - name: json_schema + use_template: json_schema pricing: input: '0.15' output: '0.60' diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index aae2729bdfb042..06135c958463e8 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -1,3 +1,4 @@ +import json import logging from collections.abc import Generator from typing import Optional, Union, cast @@ -544,13 +545,18 @@ def _chat_generate(self, model: str, credentials: dict, response_format = model_parameters.get("response_format") if response_format: - if response_format == "json_object": - response_format = {"type": "json_object"} + if response_format == "json_schema": + json_schema = model_parameters.get("json_schema") + if not json_schema: + raise ValueError("Must define JSON Schema when the response format is json_schema") + try: + schema = json.loads(json_schema) + except: + raise ValueError(f"not currect json_schema format: {json_schema}") + model_parameters.pop("json_schema") + model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema} else: - response_format = {"type": "text"} - - model_parameters["response_format"] = response_format - + model_parameters["response_format"] = {"type": response_format} extra_model_kwargs = {} @@ -922,11 +928,14 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. - Official documentation: https://github.com/openai/openai-cookbook/blob/ - main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" + Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" if model.startswith('ft:'): model = model.split(':')[1] + # Currently, we can use gpt4o to calculate chatgpt-4o-latest's token. + if model == "chatgpt-4o-latest": + model = "gpt-4o" + try: encoding = tiktoken.encoding_for_model(model) except KeyError: @@ -946,7 +955,7 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], raise NotImplementedError( f"get_num_tokens_from_messages() is not presently implemented " f"for model {model}." - "See https://github.com/openai/openai-python/blob/main/chatml.md for " + "See https://platform.openai.com/docs/advanced-usage/managing-tokens for " "information on how messages are converted to tokens." ) num_tokens = 0 diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index e5cc884b6de315..6279125f46ffa0 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -150,9 +150,9 @@ def validate_credentials(self, model: str, credentials: dict) -> None: except json.JSONDecodeError as e: raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') - if (completion_type is LLMMode.CHAT and json_result['object'] == ''): + if (completion_type is LLMMode.CHAT and json_result.get('object','') == ''): json_result['object'] = 'chat.completion' - elif (completion_type is LLMMode.COMPLETION and json_result['object'] == ''): + elif (completion_type is LLMMode.COMPLETION and json_result.get('object','') == ''): json_result['object'] = 'text_completion' if (completion_type is LLMMode.CHAT @@ -428,7 +428,7 @@ def get_tool_call(tool_call_id: str): if new_tool_call.function.arguments: tool_call.function.arguments += new_tool_call.function.arguments - finish_reason = 'Unknown' + finish_reason = None # The default value of finish_reason is None for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter): chunk = chunk.strip() @@ -437,6 +437,8 @@ def get_tool_call(tool_call_id: str): if chunk.startswith(':'): continue decoded_chunk = chunk.strip().lstrip('data: ').lstrip() + if decoded_chunk == '[DONE]': # Some provider returns "data: [DONE]" + continue try: chunk_json = json.loads(decoded_chunk) @@ -647,7 +649,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: O else: raise ValueError(f"Got unknown type {message}") - if message.name: + if message.name and message_dict.get("role", "") != "tool": message_dict["name"] = message.name return message_dict diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml index 69bed9603902a6..88c76fe16ef733 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml +++ b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml @@ -7,6 +7,7 @@ description: supported_model_types: - llm - text-embedding + - speech2text configurate_methods: - customizable-model model_credential_schema: @@ -61,6 +62,22 @@ model_credential_schema: zh_Hans: 模型上下文长度 en_US: Model context size required: true + show_on: + - variable: __model_type + value: llm + type: text-input + default: '4096' + placeholder: + zh_Hans: 在此输入您的模型上下文长度 + en_US: Enter your Model context size + - variable: context_size + label: + zh_Hans: 模型上下文长度 + en_US: Model context size + required: true + show_on: + - variable: __model_type + value: text-embedding type: text-input default: '4096' placeholder: diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/__init__.py b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py new file mode 100644 index 00000000000000..00702ba9367cf4 --- /dev/null +++ b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py @@ -0,0 +1,63 @@ +from typing import IO, Optional +from urllib.parse import urljoin + +import requests + +from core.model_runtime.errors.invoke import InvokeBadRequestError +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat + + +class OAICompatSpeech2TextModel(_CommonOAI_API_Compat, Speech2TextModel): + """ + Model class for OpenAI Compatible Speech to text model. + """ + + def _invoke( + self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None + ) -> str: + """ + Invoke speech2text model + + :param model: model name + :param credentials: model credentials + :param file: audio file + :param user: unique user id + :return: text for given audio file + """ + headers = {} + + api_key = credentials.get("api_key") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" + endpoint_url = urljoin(endpoint_url, "audio/transcriptions") + + payload = {"model": model} + files = [("file", file)] + response = requests.post(endpoint_url, headers=headers, data=payload, files=files) + + if response.status_code != 200: + raise InvokeBadRequestError(response.text) + response_data = response.json() + return response_data["text"] + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + audio_file_path = self._get_demo_file_path() + + with open(audio_file_path, "rb") as audio_file: + self._invoke(model, credentials, audio_file) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py index 3467cd6dfd97f9..363054b084a69c 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py @@ -76,7 +76,7 @@ def _invoke(self, model: str, credentials: dict, num_tokens = self._get_num_tokens_by_gpt2(text) if num_tokens >= context_size: - cutoff = int(len(text) * (np.floor(context_size / num_tokens))) + cutoff = int(np.floor(len(text) * (context_size / num_tokens))) # if num tokens is larger than context length, only use the start inputs.append(text[0: cutoff]) else: diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-2024-08-06.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-2024-08-06.yaml new file mode 100644 index 00000000000000..cf2de0f73a0b84 --- /dev/null +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-2024-08-06.yaml @@ -0,0 +1,44 @@ +model: gpt-4o-2024-08-06 +label: + zh_Hans: gpt-4o-2024-08-06 + en_US: gpt-4o-2024-08-06 +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call + - vision +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 16384 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '2.50' + output: '10.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Llama3-Chinese_v2.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Llama3-Chinese_v2.yaml new file mode 100644 index 00000000000000..87712874b9f5c6 --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Llama3-Chinese_v2.yaml @@ -0,0 +1,61 @@ +model: Llama3-Chinese_v2 +label: + en_US: Llama3-Chinese_v2 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.5 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. +pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3-70B-Instruct-GPTQ-Int4.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3-70B-Instruct-GPTQ-Int4.yaml new file mode 100644 index 00000000000000..f16f3de60b8d30 --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3-70B-Instruct-GPTQ-Int4.yaml @@ -0,0 +1,61 @@ +model: Meta-Llama-3-70B-Instruct-GPTQ-Int4 +label: + en_US: Meta-Llama-3-70B-Instruct-GPTQ-Int4 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 1024 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.5 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. +pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3-8B-Instruct.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3-8B-Instruct.yaml new file mode 100644 index 00000000000000..21267c240bc0e5 --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3-8B-Instruct.yaml @@ -0,0 +1,61 @@ +model: Meta-Llama-3-8B-Instruct +label: + en_US: Meta-Llama-3-8B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.5 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. +pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3.1-405B-Instruct-AWQ-INT4.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3.1-405B-Instruct-AWQ-INT4.yaml new file mode 100644 index 00000000000000..80c7ec40f22cb9 --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3.1-405B-Instruct-AWQ-INT4.yaml @@ -0,0 +1,61 @@ +model: Meta-Llama-3.1-405B-Instruct-AWQ-INT4 +label: + en_US: Meta-Llama-3.1-405B-Instruct-AWQ-INT4 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 410960 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.5 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. +pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3.1-8B-Instruct.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3.1-8B-Instruct.yaml new file mode 100644 index 00000000000000..bbab46344c3a12 --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3.1-8B-Instruct.yaml @@ -0,0 +1,61 @@ +model: Meta-Llama-3.1-8B-Instruct +label: + en_US: Meta-Llama-3.1-8B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.1 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. +pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen-14B-Chat-Int4.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen-14B-Chat-Int4.yaml index af6fb91cd9463c..ec6d9bcc149fdf 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen-14B-Chat-Int4.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen-14B-Chat-Int4.yaml @@ -55,7 +55,8 @@ parameter_rules: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. pricing: - input: '0.000' - output: '0.000' - unit: '0.000' + input: "0.000" + output: "0.000" + unit: "0.000" currency: RMB +deprecated: true diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-110B-Chat-GPTQ-Int4.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-110B-Chat-GPTQ-Int4.yaml index 4ab9a800557337..b561a53039a78e 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-110B-Chat-GPTQ-Int4.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-110B-Chat-GPTQ-Int4.yaml @@ -55,7 +55,8 @@ parameter_rules: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. pricing: - input: '0.000' - output: '0.000' - unit: '0.000' + input: "0.000" + output: "0.000" + unit: "0.000" currency: RMB +deprecated: true diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-72B-Chat-GPTQ-Int4.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-72B-Chat-GPTQ-Int4.yaml index 4a8b1cf4797476..841dd97f353b04 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-72B-Chat-GPTQ-Int4.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-72B-Chat-GPTQ-Int4.yaml @@ -6,7 +6,7 @@ features: - agent-thought model_properties: mode: chat - context_size: 8192 + context_size: 2048 parameter_rules: - name: temperature use_template: temperature @@ -55,7 +55,7 @@ parameter_rules: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. pricing: - input: '0.000' - output: '0.000' - unit: '0.000' + input: "0.000" + output: "0.000" + unit: "0.000" currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-7B.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-7B.yaml index b076504493fe81..33d5d12b223d2b 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-7B.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-7B.yaml @@ -6,7 +6,7 @@ features: - agent-thought model_properties: mode: completion - context_size: 8192 + context_size: 32768 parameter_rules: - name: temperature use_template: temperature @@ -55,7 +55,7 @@ parameter_rules: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. pricing: - input: '0.000' - output: '0.000' - unit: '0.000' + input: "0.000" + output: "0.000" + unit: "0.000" currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-72B-Instruct-GPTQ-Int4.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-72B-Instruct-GPTQ-Int4.yaml index e24a69fe635626..62255cc7d2e566 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-72B-Instruct-GPTQ-Int4.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-72B-Instruct-GPTQ-Int4.yaml @@ -8,12 +8,12 @@ features: - stream-tool-call model_properties: mode: chat - context_size: 8192 + context_size: 2048 parameter_rules: - name: temperature use_template: temperature type: float - default: 0.3 + default: 0.7 min: 0.0 max: 2.0 help: @@ -57,7 +57,7 @@ parameter_rules: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. pricing: - input: '0.000' - output: '0.000' - unit: '0.000' + input: "0.000" + output: "0.000" + unit: "0.000" currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-72B-Instruct.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-72B-Instruct.yaml new file mode 100644 index 00000000000000..cea6560295267a --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-72B-Instruct.yaml @@ -0,0 +1,61 @@ +model: Qwen2-72B-Instruct +label: + en_US: Qwen2-72B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.5 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. +pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-7B.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-7B.yaml index e3d804729d3990..2f3f1f0225d712 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-7B.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-7B.yaml @@ -8,7 +8,7 @@ features: - stream-tool-call model_properties: mode: completion - context_size: 8192 + context_size: 32768 parameter_rules: - name: temperature use_template: temperature @@ -57,7 +57,7 @@ parameter_rules: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. pricing: - input: '0.000' - output: '0.000' - unit: '0.000' + input: "0.000" + output: "0.000" + unit: "0.000" currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/_position.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/_position.yaml index b95f6bdc1b3288..2c9eac0e49a4d7 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/_position.yaml @@ -1,6 +1,15 @@ +- Meta-Llama-3.1-405B-Instruct-AWQ-INT4 +- Meta-Llama-3.1-8B-Instruct +- Meta-Llama-3-70B-Instruct-GPTQ-Int4 +- Meta-Llama-3-8B-Instruct - Qwen2-72B-Instruct-GPTQ-Int4 +- Qwen2-72B-Instruct - Qwen2-7B -- Qwen1.5-110B-Chat-GPTQ-Int4 +- Qwen-14B-Chat-Int4 - Qwen1.5-72B-Chat-GPTQ-Int4 - Qwen1.5-7B -- Qwen-14B-Chat-Int4 +- Qwen1.5-110B-Chat-GPTQ-Int4 +- deepseek-v2-chat +- deepseek-v2-lite-chat +- Llama3-Chinese_v2 +- chatglm3-6b diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/chatglm3-6b.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/chatglm3-6b.yaml new file mode 100644 index 00000000000000..f9c26b7f9002c7 --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/chatglm3-6b.yaml @@ -0,0 +1,61 @@ +model: chatglm3-6b +label: + en_US: chatglm3-6b +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.5 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. +pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/deepseek-v2-chat.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/deepseek-v2-chat.yaml new file mode 100644 index 00000000000000..078922ef95b087 --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/deepseek-v2-chat.yaml @@ -0,0 +1,61 @@ +model: deepseek-v2-chat +label: + en_US: deepseek-v2-chat +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.5 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. +pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/deepseek-v2-lite-chat.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/deepseek-v2-lite-chat.yaml new file mode 100644 index 00000000000000..4ff3af7b517761 --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/deepseek-v2-lite-chat.yaml @@ -0,0 +1,61 @@ +model: deepseek-v2-lite-chat +label: + en_US: deepseek-v2-lite-chat +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 2048 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.5 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. +pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/BAAI-bge-large-en-v1.5.yaml b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/BAAI-bge-large-en-v1.5.yaml new file mode 100644 index 00000000000000..5756fb3d149d4c --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/BAAI-bge-large-en-v1.5.yaml @@ -0,0 +1,4 @@ +model: BAAI/bge-large-en-v1.5 +model_type: text-embedding +model_properties: + context_size: 32768 diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/BAAI-bge-large-zh-v1.5.yaml b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/BAAI-bge-large-zh-v1.5.yaml new file mode 100644 index 00000000000000..4204ab2860248b --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/BAAI-bge-large-zh-v1.5.yaml @@ -0,0 +1,4 @@ +model: BAAI/bge-large-zh-v1.5 +model_type: text-embedding +model_properties: + context_size: 32768 diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py index 5a99ad301f36fa..11d57e3749a8f1 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py @@ -79,7 +79,7 @@ def _invoke(self, model: str, credentials: dict, num_tokens = self._get_num_tokens_by_gpt2(text) if num_tokens >= context_size: - cutoff = int(len(text) * (np.floor(context_size / num_tokens))) + cutoff = int(np.floor(len(text) * (context_size / num_tokens))) # if num tokens is larger than context length, only use the start inputs.append(text[0: cutoff]) else: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/_position.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/_position.yaml index 20bb0790c2234e..c2f0eb05360327 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/_position.yaml @@ -1,8 +1,20 @@ -- deepseek-v2-chat -- qwen2-72b-instruct -- qwen2-57b-a14b-instruct -- qwen2-7b-instruct -- yi-1.5-34b-chat -- yi-1.5-9b-chat -- yi-1.5-6b-chat -- glm4-9B-chat +- Qwen/Qwen2-72B-Instruct +- Qwen/Qwen2-57B-A14B-Instruct +- Qwen/Qwen2-7B-Instruct +- Qwen/Qwen2-1.5B-Instruct +- 01-ai/Yi-1.5-34B-Chat +- 01-ai/Yi-1.5-9B-Chat-16K +- 01-ai/Yi-1.5-6B-Chat +- THUDM/glm-4-9b-chat +- deepseek-ai/DeepSeek-V2-Chat +- deepseek-ai/DeepSeek-Coder-V2-Instruct +- internlm/internlm2_5-7b-chat +- google/gemma-2-27b-it +- google/gemma-2-9b-it +- meta-llama/Meta-Llama-3-70B-Instruct +- meta-llama/Meta-Llama-3-8B-Instruct +- meta-llama/Meta-Llama-3.1-405B-Instruct +- meta-llama/Meta-Llama-3.1-70B-Instruct +- meta-llama/Meta-Llama-3.1-8B-Instruct +- mistralai/Mixtral-8x7B-Instruct-v0.1 +- mistralai/Mistral-7B-Instruct-v0.2 diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2-chat.yaml index 3926568db6f2e9..caa6508b5ed2a2 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2-chat.yaml @@ -1,4 +1,4 @@ -model: deepseek-ai/deepseek-v2-chat +model: deepseek-ai/DeepSeek-V2-Chat label: en_US: deepseek-ai/DeepSeek-V2-Chat model_type: llm diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-27b-it.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-27b-it.yaml new file mode 100644 index 00000000000000..2840e3dcf4b113 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-27b-it.yaml @@ -0,0 +1,30 @@ +model: google/gemma-2-27b-it +label: + en_US: google/gemma-2-27b-it +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8196 +parameter_rules: + - name: temperature + use_template: temperature + - name: max_tokens + use_template: max_tokens + type: int + default: 512 + min: 1 + max: 4096 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + - name: frequency_penalty + use_template: frequency_penalty +pricing: + input: '1.26' + output: '1.26' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-9b-it.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-9b-it.yaml new file mode 100644 index 00000000000000..d7e19b46f6d6f2 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-9b-it.yaml @@ -0,0 +1,30 @@ +model: google/gemma-2-9b-it +label: + en_US: google/gemma-2-9b-it +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8196 +parameter_rules: + - name: temperature + use_template: temperature + - name: max_tokens + use_template: max_tokens + type: int + default: 512 + min: 1 + max: 4096 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + - name: frequency_penalty + use_template: frequency_penalty +pricing: + input: '0' + output: '0' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/glm4-9b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/glm4-9b-chat.yaml index d6a4b21b66d968..9b32a024774d06 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/glm4-9b-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/glm4-9b-chat.yaml @@ -1,4 +1,4 @@ -model: zhipuai/glm4-9B-chat +model: THUDM/glm-4-9b-chat label: en_US: THUDM/glm-4-9b-chat model_type: llm @@ -24,7 +24,7 @@ parameter_rules: - name: frequency_penalty use_template: frequency_penalty pricing: - input: '0.6' - output: '0.6' + input: '0' + output: '0' unit: '0.000001' currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-7b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-7b-chat.yaml new file mode 100644 index 00000000000000..73ad4480aa2968 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-7b-chat.yaml @@ -0,0 +1,30 @@ +model: internlm/internlm2_5-7b-chat +label: + en_US: internlm/internlm2_5-7b-chat +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: max_tokens + use_template: max_tokens + type: int + default: 512 + min: 1 + max: 4096 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + - name: frequency_penalty + use_template: frequency_penalty +pricing: + input: '0' + output: '0' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-70b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-70b-instruct.yaml new file mode 100644 index 00000000000000..9993d781ac8959 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-70b-instruct.yaml @@ -0,0 +1,30 @@ +model: meta-llama/Meta-Llama-3-70B-Instruct +label: + en_US: meta-llama/Meta-Llama-3-70B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: max_tokens + use_template: max_tokens + type: int + default: 512 + min: 1 + max: 4096 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + - name: frequency_penalty + use_template: frequency_penalty +pricing: + input: '4.13' + output: '4.13' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-8b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-8b-instruct.yaml new file mode 100644 index 00000000000000..60e3764789e1f5 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-8b-instruct.yaml @@ -0,0 +1,30 @@ +model: meta-llama/Meta-Llama-3-8B-Instruct +label: + en_US: meta-llama/Meta-Llama-3-8B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: max_tokens + use_template: max_tokens + type: int + default: 512 + min: 1 + max: 4096 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + - name: frequency_penalty + use_template: frequency_penalty +pricing: + input: '0' + output: '0' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-405b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-405b-instruct.yaml new file mode 100644 index 00000000000000..f992660aa2e66f --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-405b-instruct.yaml @@ -0,0 +1,30 @@ +model: meta-llama/Meta-Llama-3.1-405B-Instruct +label: + en_US: meta-llama/Meta-Llama-3.1-405B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: max_tokens + use_template: max_tokens + type: int + default: 512 + min: 1 + max: 4096 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + - name: frequency_penalty + use_template: frequency_penalty +pricing: + input: '21' + output: '21' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-70b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-70b-instruct.yaml new file mode 100644 index 00000000000000..1c69d63a400219 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-70b-instruct.yaml @@ -0,0 +1,30 @@ +model: meta-llama/Meta-Llama-3.1-70B-Instruct +label: + en_US: meta-llama/Meta-Llama-3.1-70B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: max_tokens + use_template: max_tokens + type: int + default: 512 + min: 1 + max: 4096 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + - name: frequency_penalty + use_template: frequency_penalty +pricing: + input: '4.13' + output: '4.13' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-8b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-8b-instruct.yaml new file mode 100644 index 00000000000000..a97002a5ca3658 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-8b-instruct.yaml @@ -0,0 +1,30 @@ +model: meta-llama/Meta-Llama-3.1-8B-Instruct +label: + en_US: meta-llama/Meta-Llama-3.1-8B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: max_tokens + use_template: max_tokens + type: int + default: 512 + min: 1 + max: 4096 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + - name: frequency_penalty + use_template: frequency_penalty +pricing: + input: '0' + output: '0' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/mistral-7b-instruct-v0.2.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/mistral-7b-instruct-v0.2.yaml new file mode 100644 index 00000000000000..27664eab6c817a --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/mistral-7b-instruct-v0.2.yaml @@ -0,0 +1,30 @@ +model: mistralai/Mistral-7B-Instruct-v0.2 +label: + en_US: mistralai/Mistral-7B-Instruct-v0.2 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: max_tokens + use_template: max_tokens + type: int + default: 512 + min: 1 + max: 4096 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + - name: frequency_penalty + use_template: frequency_penalty +pricing: + input: '0' + output: '0' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/mistral-8x7b-instruct-v0.1.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/mistral-8x7b-instruct-v0.1.yaml new file mode 100644 index 00000000000000..fd7aada42848aa --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/mistral-8x7b-instruct-v0.1.yaml @@ -0,0 +1,30 @@ +model: mistralai/Mixtral-8x7B-Instruct-v0.1 +label: + en_US: mistralai/Mixtral-8x7B-Instruct-v0.1 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: max_tokens + use_template: max_tokens + type: int + default: 512 + min: 1 + max: 4096 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + - name: frequency_penalty + use_template: frequency_penalty +pricing: + input: '1.26' + output: '1.26' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-1.5b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-1.5b-instruct.yaml new file mode 100644 index 00000000000000..f6c976af8e7b6e --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-1.5b-instruct.yaml @@ -0,0 +1,30 @@ +model: Qwen/Qwen2-1.5B-Instruct +label: + en_US: Qwen/Qwen2-1.5B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: max_tokens + use_template: max_tokens + type: int + default: 512 + min: 1 + max: 4096 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + - name: frequency_penalty + use_template: frequency_penalty +pricing: + input: '0' + output: '0' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-57b-a14b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-57b-a14b-instruct.yaml index 39624dc5b97c3d..a996e919ea9f27 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-57b-a14b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-57b-a14b-instruct.yaml @@ -1,4 +1,4 @@ -model: alibaba/Qwen2-57B-A14B-Instruct +model: Qwen/Qwen2-57B-A14B-Instruct label: en_US: Qwen/Qwen2-57B-A14B-Instruct model_type: llm diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-72b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-72b-instruct.yaml index fb7ff6cb14dbd7..a6e2c22dac87c0 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-72b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-72b-instruct.yaml @@ -1,4 +1,4 @@ -model: alibaba/Qwen2-72B-Instruct +model: Qwen/Qwen2-72B-Instruct label: en_US: Qwen/Qwen2-72B-Instruct model_type: llm diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-7b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-7b-instruct.yaml index efda4abbd9965e..d8bea5e12927e7 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-7b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-7b-instruct.yaml @@ -1,4 +1,4 @@ -model: alibaba/Qwen2-7B-Instruct +model: Qwen/Qwen2-7B-Instruct label: en_US: Qwen/Qwen2-7B-Instruct model_type: llm @@ -24,7 +24,7 @@ parameter_rules: - name: frequency_penalty use_template: frequency_penalty pricing: - input: '0.35' - output: '0.35' + input: '0' + output: '0' unit: '0.000001' currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-6b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-6b-chat.yaml index 38cd4197d4c3dd..fe4c8b4b3e0350 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-6b-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-6b-chat.yaml @@ -24,7 +24,7 @@ parameter_rules: - name: frequency_penalty use_template: frequency_penalty pricing: - input: '0.35' - output: '0.35' + input: '0' + output: '0' unit: '0.000001' currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-9b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-9b-chat.yaml index 042eeea81a48f0..c61f0dc53fe6ec 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-9b-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-9b-chat.yaml @@ -1,4 +1,4 @@ -model: 01-ai/Yi-1.5-9B-Chat +model: 01-ai/Yi-1.5-9B-Chat-16K label: en_US: 01-ai/Yi-1.5-9B-Chat-16K model_type: llm @@ -24,7 +24,7 @@ parameter_rules: - name: frequency_penalty use_template: frequency_penalty pricing: - input: '0.42' - output: '0.42' + input: '0' + output: '0' unit: '0.000001' currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/rerank/__init__.py b/api/core/model_runtime/model_providers/siliconflow/rerank/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/siliconflow/rerank/bce-reranker-base_v1.yaml b/api/core/model_runtime/model_providers/siliconflow/rerank/bce-reranker-base_v1.yaml new file mode 100644 index 00000000000000..ff3635bfeb1273 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/rerank/bce-reranker-base_v1.yaml @@ -0,0 +1,4 @@ +model: netease-youdao/bce-reranker-base_v1 +model_type: rerank +model_properties: + context_size: 512 diff --git a/api/core/model_runtime/model_providers/siliconflow/rerank/bge-reranker-v2-m3.yaml b/api/core/model_runtime/model_providers/siliconflow/rerank/bge-reranker-v2-m3.yaml new file mode 100644 index 00000000000000..807f531b084892 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/rerank/bge-reranker-v2-m3.yaml @@ -0,0 +1,4 @@ +model: BAAI/bge-reranker-v2-m3 +model_type: rerank +model_properties: + context_size: 8192 diff --git a/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py b/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py new file mode 100644 index 00000000000000..683591581638e4 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py @@ -0,0 +1,87 @@ +from typing import Optional + +import httpx + +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + + +class SiliconflowRerankModel(RerankModel): + + def _invoke(self, model: str, credentials: dict, query: str, docs: list[str], + score_threshold: Optional[float] = None, top_n: Optional[int] = None, + user: Optional[str] = None) -> RerankResult: + if len(docs) == 0: + return RerankResult(model=model, docs=[]) + + base_url = credentials.get('base_url', 'https://api.siliconflow.cn/v1') + if base_url.endswith('/'): + base_url = base_url[:-1] + try: + response = httpx.post( + base_url + '/rerank', + json={ + "model": model, + "query": query, + "documents": docs, + "top_n": top_n, + "return_documents": True + }, + headers={"Authorization": f"Bearer {credentials.get('api_key')}"} + ) + response.raise_for_status() + results = response.json() + + rerank_documents = [] + for result in results['results']: + rerank_document = RerankDocument( + index=result['index'], + text=result['document']['text'], + score=result['relevance_score'], + ) + if score_threshold is None or result['relevance_score'] >= score_threshold: + rerank_documents.append(rerank_document) + + return RerankResult(model=model, docs=rerank_documents) + except httpx.HTTPStatusError as e: + raise InvokeServerUnavailableError(str(e)) + + def validate_credentials(self, model: str, credentials: dict) -> None: + try: + + self._invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8 + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + """ + return { + InvokeConnectionError: [httpx.ConnectError], + InvokeServerUnavailableError: [httpx.RemoteProtocolError], + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError] + } \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/siliconflow/siliconflow.py b/api/core/model_runtime/model_providers/siliconflow/siliconflow.py index a53f16c929728e..dd0eea362a5f83 100644 --- a/api/core/model_runtime/model_providers/siliconflow/siliconflow.py +++ b/api/core/model_runtime/model_providers/siliconflow/siliconflow.py @@ -6,6 +6,7 @@ logger = logging.getLogger(__name__) + class SiliconflowProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: diff --git a/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml b/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml index cf44c185d59a21..c46a891604c480 100644 --- a/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml @@ -12,9 +12,12 @@ help: en_US: Get your API Key from SiliconFlow zh_Hans: 从 SiliconFlow 获取 API Key url: - en_US: https://cloud.siliconflow.cn/keys + en_US: https://cloud.siliconflow.cn/account/ak supported_model_types: - llm + - text-embedding + - rerank + - speech2text configurate_methods: - predefined-model provider_credential_schema: diff --git a/api/core/model_runtime/model_providers/siliconflow/speech2text/__init__.py b/api/core/model_runtime/model_providers/siliconflow/speech2text/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/siliconflow/speech2text/sense-voice-small.yaml b/api/core/model_runtime/model_providers/siliconflow/speech2text/sense-voice-small.yaml new file mode 100644 index 00000000000000..deceaf60f4f017 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/speech2text/sense-voice-small.yaml @@ -0,0 +1,5 @@ +model: iic/SenseVoiceSmall +model_type: speech2text +model_properties: + file_upload_limit: 1 + supported_file_extensions: mp3,wav diff --git a/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py b/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py new file mode 100644 index 00000000000000..6ad3cab5873c69 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py @@ -0,0 +1,32 @@ +from typing import IO, Optional + +from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import OAICompatSpeech2TextModel + + +class SiliconflowSpeech2TextModel(OAICompatSpeech2TextModel): + """ + Model class for Siliconflow Speech to text model. + """ + + def _invoke( + self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None + ) -> str: + """ + Invoke speech2text model + + :param model: model name + :param credentials: model credentials + :param file: audio file + :param user: unique user id + :return: text for given audio file + """ + self._add_custom_parameters(credentials) + return super()._invoke(model, credentials, file) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials) + return super().validate_credentials(model, credentials) + + @classmethod + def _add_custom_parameters(cls, credentials: dict) -> None: + credentials["endpoint_url"] = "https://api.siliconflow.cn/v1" diff --git a/api/core/model_runtime/model_providers/siliconflow/text_embedding/bce-embedding-base-v1.yaml b/api/core/model_runtime/model_providers/siliconflow/text_embedding/bce-embedding-base-v1.yaml new file mode 100644 index 00000000000000..710fbc04f6ad12 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/text_embedding/bce-embedding-base-v1.yaml @@ -0,0 +1,5 @@ +model: netease-youdao/bce-embedding-base_v1 +model_type: text-embedding +model_properties: + context_size: 512 + max_chunks: 1 diff --git a/api/core/model_runtime/model_providers/siliconflow/text_embedding/bge-large-en-v1.5.yaml b/api/core/model_runtime/model_providers/siliconflow/text_embedding/bge-large-en-v1.5.yaml new file mode 100644 index 00000000000000..84f69b41a08c13 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/text_embedding/bge-large-en-v1.5.yaml @@ -0,0 +1,5 @@ +model: BAAI/bge-large-en-v1.5 +model_type: text-embedding +model_properties: + context_size: 512 + max_chunks: 1 diff --git a/api/core/model_runtime/model_providers/siliconflow/text_embedding/bge-large-zh-v1.5.yaml b/api/core/model_runtime/model_providers/siliconflow/text_embedding/bge-large-zh-v1.5.yaml new file mode 100644 index 00000000000000..5248375d0b507e --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/text_embedding/bge-large-zh-v1.5.yaml @@ -0,0 +1,5 @@ +model: BAAI/bge-large-zh-v1.5 +model_type: text-embedding +model_properties: + context_size: 512 + max_chunks: 1 diff --git a/api/core/model_runtime/model_providers/siliconflow/text_embedding/bge-m3.yaml b/api/core/model_runtime/model_providers/siliconflow/text_embedding/bge-m3.yaml new file mode 100644 index 00000000000000..f0b12dd420ab2b --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/text_embedding/bge-m3.yaml @@ -0,0 +1,5 @@ +model: BAAI/bge-m3 +model_type: text-embedding +model_properties: + context_size: 8192 + max_chunks: 1 diff --git a/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py new file mode 100644 index 00000000000000..c58765cecb9a69 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py @@ -0,0 +1,29 @@ +from typing import Optional + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import ( + OAICompatEmbeddingModel, +) + + +class SiliconflowTextEmbeddingModel(OAICompatEmbeddingModel): + """ + Model class for Siliconflow text embedding model. + """ + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials) + super().validate_credentials(model, credentials) + + def _invoke(self, model: str, credentials: dict, + texts: list[str], user: Optional[str] = None) \ + -> TextEmbeddingResult: + self._add_custom_parameters(credentials) + return super()._invoke(model, credentials, texts, user) + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + self._add_custom_parameters(credentials) + return super().get_num_tokens(model, credentials, texts) + + @classmethod + def _add_custom_parameters(cls, credentials: dict) -> None: + credentials['endpoint_url'] = 'https://api.siliconflow.cn/v1' \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/stepfun/llm/_position.yaml b/api/core/model_runtime/model_providers/stepfun/llm/_position.yaml index b34433e1d4d150..2bb0c703f4aa39 100644 --- a/api/core/model_runtime/model_providers/stepfun/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/stepfun/llm/_position.yaml @@ -2,5 +2,7 @@ - step-1-32k - step-1-128k - step-1-256k +- step-1-flash +- step-2-16k - step-1v-8k - step-1v-32k diff --git a/api/core/model_runtime/model_providers/stepfun/llm/step-1-flash.yaml b/api/core/model_runtime/model_providers/stepfun/llm/step-1-flash.yaml new file mode 100644 index 00000000000000..afb880f2a40bbc --- /dev/null +++ b/api/core/model_runtime/model_providers/stepfun/llm/step-1-flash.yaml @@ -0,0 +1,25 @@ +model: step-1-flash +label: + zh_Hans: step-1-flash + en_US: step-1-flash +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 8000 +pricing: + input: '0.001' + output: '0.004' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/stepfun/llm/step-1v-32k.yaml b/api/core/model_runtime/model_providers/stepfun/llm/step-1v-32k.yaml index f878ee3e56e1f9..08d6ad245d2dd6 100644 --- a/api/core/model_runtime/model_providers/stepfun/llm/step-1v-32k.yaml +++ b/api/core/model_runtime/model_providers/stepfun/llm/step-1v-32k.yaml @@ -5,6 +5,9 @@ label: model_type: llm features: - vision + - tool-call + - multi-tool-call + - stream-tool-call model_properties: mode: chat context_size: 32000 diff --git a/api/core/model_runtime/model_providers/stepfun/llm/step-1v-8k.yaml b/api/core/model_runtime/model_providers/stepfun/llm/step-1v-8k.yaml index 6c3cb61d2c6621..843d14d9c67e80 100644 --- a/api/core/model_runtime/model_providers/stepfun/llm/step-1v-8k.yaml +++ b/api/core/model_runtime/model_providers/stepfun/llm/step-1v-8k.yaml @@ -5,6 +5,9 @@ label: model_type: llm features: - vision + - tool-call + - multi-tool-call + - stream-tool-call model_properties: mode: chat context_size: 8192 diff --git a/api/core/model_runtime/model_providers/stepfun/llm/step-2-16k.yaml b/api/core/model_runtime/model_providers/stepfun/llm/step-2-16k.yaml new file mode 100644 index 00000000000000..6f2dabbfb0e308 --- /dev/null +++ b/api/core/model_runtime/model_providers/stepfun/llm/step-2-16k.yaml @@ -0,0 +1,28 @@ +model: step-2-16k +label: + zh_Hans: step-2-16k + en_US: step-2-16k +model_type: llm +features: + - agent-thought + - tool-call + - multi-tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 16000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 16000 +pricing: + input: '0.038' + output: '0.120' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/farui-plus.yaml b/api/core/model_runtime/model_providers/tongyi/llm/farui-plus.yaml new file mode 100644 index 00000000000000..aad07f56736e52 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/farui-plus.yaml @@ -0,0 +1,81 @@ +model: farui-plus +label: + en_US: farui-plus +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat + context_size: 12288 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 2000 + min: 1 + max: 2000 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. + - name: enable_search + type: boolean + default: false + help: + zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 + en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. + - name: response_format + use_template: response_format +pricing: + input: '0.02' + output: '0.02' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index a75db78d8ce412..4e1bb0a5a4fa2b 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -159,6 +159,8 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr """ if model in ['qwen-turbo-chat', 'qwen-plus-chat']: model = model.replace('-chat', '') + if model == 'farui-plus': + model = 'qwen-farui-plus' if model in self.tokenizers: tokenizer = self.tokenizers[model] diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v1.yaml b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v1.yaml index eed09f95dedea7..f4303c53d38b80 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v1.yaml +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v1.yaml @@ -2,3 +2,8 @@ model: text-embedding-v1 model_type: text-embedding model_properties: context_size: 2048 + max_chunks: 25 +pricing: + input: "0.0007" + unit: "0.001" + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v2.yaml b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v2.yaml index db2fa861e69f90..f6be3544ed8f65 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v2.yaml +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v2.yaml @@ -2,3 +2,8 @@ model: text-embedding-v2 model_type: text-embedding model_properties: context_size: 2048 + max_chunks: 25 +pricing: + input: "0.0007" + unit: "0.001" + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v3.yaml b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v3.yaml new file mode 100644 index 00000000000000..171a379ee23f77 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v3.yaml @@ -0,0 +1,9 @@ +model: text-embedding-v3 +model_type: text-embedding +model_properties: + context_size: 8192 + max_chunks: 25 +pricing: + input: "0.0007" + unit: "0.001" + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py index c207ffc1e34bbb..97dcb72f7c1b3e 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py @@ -2,6 +2,7 @@ from typing import Optional import dashscope +import numpy as np from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import ( @@ -21,11 +22,11 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): """ def _invoke( - self, - model: str, - credentials: dict, - texts: list[str], - user: Optional[str] = None, + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -37,16 +38,44 @@ def _invoke( :return: embeddings result """ credentials_kwargs = self._to_credential_kwargs(credentials) - embeddings, embedding_used_tokens = self.embed_documents( - credentials_kwargs=credentials_kwargs, - model=model, - texts=texts - ) + context_size = self._get_context_size(model, credentials) + max_chunks = self._get_max_chunks(model, credentials) + inputs = [] + indices = [] + used_tokens = 0 + + for i, text in enumerate(texts): + + # Here token count is only an approximation based on the GPT2 tokenizer + num_tokens = self._get_num_tokens_by_gpt2(text) + + if num_tokens >= context_size: + cutoff = int(np.floor(len(text) * (context_size / num_tokens))) + # if num tokens is larger than context length, only use the start + inputs.append(text[0:cutoff]) + else: + inputs.append(text) + indices += [i] + + batched_embeddings = [] + _iter = range(0, len(inputs), max_chunks) + + for i in _iter: + embeddings_batch, embedding_used_tokens = self.embed_documents( + credentials_kwargs=credentials_kwargs, + model=model, + texts=inputs[i : i + max_chunks], + ) + used_tokens += embedding_used_tokens + batched_embeddings += embeddings_batch + + # calc usage + usage = self._calc_response_usage( + model=model, credentials=credentials, tokens=used_tokens + ) return TextEmbeddingResult( - embeddings=embeddings, - usage=self._calc_response_usage(model, credentials_kwargs, embedding_used_tokens), - model=model + embeddings=batched_embeddings, usage=usage, model=model ) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: @@ -79,12 +108,16 @@ def validate_credentials(self, model: str, credentials: dict) -> None: credentials_kwargs = self._to_credential_kwargs(credentials) # call embedding model - self.embed_documents(credentials_kwargs=credentials_kwargs, model=model, texts=["ping"]) + self.embed_documents( + credentials_kwargs=credentials_kwargs, model=model, texts=["ping"] + ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @staticmethod - def embed_documents(credentials_kwargs: dict, model: str, texts: list[str]) -> tuple[list[list[float]], int]: + def embed_documents( + credentials_kwargs: dict, model: str, texts: list[str] + ) -> tuple[list[list[float]], int]: """Call out to Tongyi's embedding endpoint. Args: @@ -102,16 +135,26 @@ def embed_documents(credentials_kwargs: dict, model: str, texts: list[str]) -> t api_key=credentials_kwargs["dashscope_api_key"], model=model, input=text, - text_type="document" + text_type="document", ) - data = response.output["embeddings"][0] - embeddings.append(data["embedding"]) - embedding_used_tokens += response.usage["total_tokens"] + if response.output and "embeddings" in response.output and response.output["embeddings"]: + data = response.output["embeddings"][0] + if "embedding" in data: + embeddings.append(data["embedding"]) + else: + raise ValueError("Embedding data is missing in the response.") + else: + raise ValueError("Response output is missing or does not contain embeddings.") + + if response.usage and "total_tokens" in response.usage: + embedding_used_tokens += response.usage["total_tokens"] + else: + raise ValueError("Response usage is missing or does not contain total tokens.") return [list(map(float, e)) for e in embeddings], embedding_used_tokens def _calc_response_usage( - self, model: str, credentials: dict, tokens: int + self, model: str, credentials: dict, tokens: int ) -> EmbeddingUsage: """ Calculate response usage @@ -125,7 +168,7 @@ def _calc_response_usage( model=model, credentials=credentials, price_type=PriceType.INPUT, - tokens=tokens + tokens=tokens, ) # transform usage @@ -136,7 +179,7 @@ def _calc_response_usage( price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/upstage/llm/_position.yaml b/api/core/model_runtime/model_providers/upstage/llm/_position.yaml index d4f03e1988f8b8..7992843dcb1d1d 100644 --- a/api/core/model_runtime/model_providers/upstage/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/upstage/llm/_position.yaml @@ -1 +1 @@ -- soloar-1-mini-chat +- solar-1-mini-chat diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py index 8901549110ee07..1a7368a2cf843d 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py @@ -1,4 +1,5 @@ import base64 +import io import json import logging from collections.abc import Generator @@ -18,6 +19,7 @@ ) from google.cloud import aiplatform from google.oauth2 import service_account +from PIL import Image from vertexai.generative_models import HarmBlockThreshold, HarmCategory from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage @@ -332,7 +334,8 @@ def _convert_claude_prompt_message_to_dict(self, message: PromptMessage) -> dict # fetch image data from url try: image_content = requests.get(message_content.data).content - mime_type, _ = mimetypes.guess_type(message_content.data) + with Image.open(io.BytesIO(image_content)) as img: + mime_type = f"image/{img.format.lower()}" base64_data = base64.b64encode(image_content).decode('utf-8') except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") diff --git a/api/core/model_runtime/model_providers/volcengine_maas/client.py b/api/core/model_runtime/model_providers/volcengine_maas/client.py index 471cb3c94e01f8..a4d89dabcbc076 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/client.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/client.py @@ -1,6 +1,25 @@ import re -from collections.abc import Callable, Generator -from typing import cast +from collections.abc import Generator +from typing import Optional, cast + +from volcenginesdkarkruntime import Ark +from volcenginesdkarkruntime.types.chat import ( + ChatCompletion, + ChatCompletionAssistantMessageParam, + ChatCompletionChunk, + ChatCompletionContentPartImageParam, + ChatCompletionContentPartTextParam, + ChatCompletionMessageParam, + ChatCompletionMessageToolCallParam, + ChatCompletionSystemMessageParam, + ChatCompletionToolMessageParam, + ChatCompletionToolParam, + ChatCompletionUserMessageParam, +) +from volcenginesdkarkruntime.types.chat.chat_completion_content_part_image_param import ImageURL +from volcenginesdkarkruntime.types.chat.chat_completion_message_tool_call_param import Function +from volcenginesdkarkruntime.types.create_embedding_response import CreateEmbeddingResponse +from volcenginesdkarkruntime.types.shared_params import FunctionDefinition from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -12,123 +31,195 @@ ToolPromptMessage, UserPromptMessage, ) -from core.model_runtime.model_providers.volcengine_maas.errors import wrap_error -from core.model_runtime.model_providers.volcengine_maas.volc_sdk import ChatRole, MaasException, MaasService +DEFAULT_V2_ENDPOINT = "maas-api.ml-platform-cn-beijing.volces.com" +DEFAULT_V3_ENDPOINT = "https://ark.cn-beijing.volces.com/api/v3" + + +class ArkClientV3: + endpoint_id: Optional[str] = None + ark: Optional[Ark] = None -class MaaSClient(MaasService): - def __init__(self, host: str, region: str): + def __init__(self, *args, **kwargs): + self.ark = Ark(*args, **kwargs) self.endpoint_id = None - super().__init__(host, region) - def set_endpoint_id(self, endpoint_id: str): - self.endpoint_id = endpoint_id + @staticmethod + def is_legacy(credentials: dict) -> bool: + # match default v2 endpoint + if ArkClientV3.is_compatible_with_legacy(credentials): + return False + # match default v3 endpoint + if credentials.get("api_endpoint_host") == DEFAULT_V3_ENDPOINT: + return False + # only v3 support api_key + if credentials.get("auth_method") == "api_key": + return False + # these cases are considered as sdk v2 + # - modified default v2 endpoint + # - modified default v3 endpoint and auth without api_key + return True - @classmethod - def from_credential(cls, credentials: dict) -> 'MaaSClient': - host = credentials['api_endpoint_host'] - region = credentials['volc_region'] - ak = credentials['volc_access_key_id'] - sk = credentials['volc_secret_access_key'] - endpoint_id = credentials['endpoint_id'] - - client = cls(host, region) - client.set_endpoint_id(endpoint_id) - client.set_ak(ak) - client.set_sk(sk) - return client + @staticmethod + def is_compatible_with_legacy(credentials: dict) -> bool: + endpoint = credentials.get("api_endpoint_host") + return endpoint == DEFAULT_V2_ENDPOINT - def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict: - req = { - 'parameters': params, - 'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages], - **extra_model_kwargs, + @classmethod + def from_credentials(cls, credentials): + """Initialize the client using the credentials provided.""" + args = { + "base_url": credentials['api_endpoint_host'], + "region": credentials['volc_region'], } - if not stream: - return super().chat( - self.endpoint_id, - req, - ) - return super().stream_chat( - self.endpoint_id, - req, - ) + if credentials.get("auth_method") == "api_key": + args = { + **args, + "api_key": credentials['volc_api_key'], + } + else: + args = { + **args, + "ak": credentials['volc_access_key_id'], + "sk": credentials['volc_secret_access_key'], + } - def embeddings(self, texts: list[str]) -> dict: - req = { - 'input': texts - } - return super().embeddings(self.endpoint_id, req) + if cls.is_compatible_with_legacy(credentials): + args = { + **args, + "base_url": DEFAULT_V3_ENDPOINT + } + + client = ArkClientV3( + **args + ) + client.endpoint_id = credentials['endpoint_id'] + return client @staticmethod - def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict: + def convert_prompt_message(message: PromptMessage) -> ChatCompletionMessageParam: + """Converts a PromptMessage to a ChatCompletionMessageParam""" if isinstance(message, UserPromptMessage): message = cast(UserPromptMessage, message) if isinstance(message.content, str): - message_dict = {"role": ChatRole.USER, - "content": message.content} + content = message.content else: content = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: - raise ValueError( - 'Content object type only support image_url') + content.append(ChatCompletionContentPartTextParam( + text=message_content.text, + type='text', + )) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast( ImagePromptMessageContent, message_content) image_data = re.sub( r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data) - content.append({ - 'type': 'image_url', - 'image_url': { - 'url': '', - 'image_bytes': image_data, - 'detail': message_content.detail, - } - }) - - message_dict = {'role': ChatRole.USER, 'content': content} + content.append(ChatCompletionContentPartImageParam( + image_url=ImageURL( + url=image_data, + detail=message_content.detail.value, + ), + type='image_url', + )) + message_dict = ChatCompletionUserMessageParam( + role='user', + content=content + ) elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) - message_dict = {'role': ChatRole.ASSISTANT, - 'content': message.content} - if message.tool_calls: - message_dict['tool_calls'] = [ - { - 'name': call.function.name, - 'arguments': call.function.arguments - } for call in message.tool_calls + message_dict = ChatCompletionAssistantMessageParam( + content=message.content, + role='assistant', + tool_calls=None if not message.tool_calls else [ + ChatCompletionMessageToolCallParam( + id=call.id, + function=Function( + name=call.function.name, + arguments=call.function.arguments + ), + type='function' + ) for call in message.tool_calls ] + ) elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - message_dict = {'role': ChatRole.SYSTEM, - 'content': message.content} + message_dict = ChatCompletionSystemMessageParam( + content=message.content, + role='system' + ) elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) - message_dict = {'role': ChatRole.FUNCTION, - 'content': message.content, - 'name': message.tool_call_id} + message_dict = ChatCompletionToolMessageParam( + content=message.content, + role='tool', + tool_call_id=message.tool_call_id + ) else: raise ValueError(f"Got unknown PromptMessage type {message}") return message_dict @staticmethod - def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator: - try: - resp = fn() - except MaasException as e: - raise wrap_error(e) + def _convert_tool_prompt(message: PromptMessageTool) -> ChatCompletionToolParam: + return ChatCompletionToolParam( + type='function', + function=FunctionDefinition( + name=message.name, + description=message.description, + parameters=message.parameters, + ) + ) - return resp + def chat(self, messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + frequency_penalty: Optional[float] = None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[float] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + ) -> ChatCompletion: + """Block chat""" + return self.ark.chat.completions.create( + model=self.endpoint_id, + messages=[self.convert_prompt_message(message) for message in messages], + tools=[self._convert_tool_prompt(tool) for tool in tools] if tools else None, + stop=stop, + frequency_penalty=frequency_penalty, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + top_p=top_p, + temperature=temperature, + ) - @staticmethod - def transform_tool_prompt_to_maas_config(tool: PromptMessageTool): - return { - "type": "function", - "function": { - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters, - } - } + def stream_chat(self, messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + frequency_penalty: Optional[float] = None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[float] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + ) -> Generator[ChatCompletionChunk]: + """Stream chat""" + chunks = self.ark.chat.completions.create( + stream=True, + model=self.endpoint_id, + messages=[self.convert_prompt_message(message) for message in messages], + tools=[self._convert_tool_prompt(tool) for tool in tools] if tools else None, + stop=stop, + frequency_penalty=frequency_penalty, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + top_p=top_p, + temperature=temperature, + ) + for chunk in chunks: + if not chunk.choices: + continue + yield chunk + + def embeddings(self, texts: list[str]) -> CreateEmbeddingResponse: + return self.ark.embeddings.create(model=self.endpoint_id, input=texts) diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py new file mode 100644 index 00000000000000..1978c11680f597 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py @@ -0,0 +1,134 @@ +import re +from collections.abc import Callable, Generator +from typing import cast + +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentType, + PromptMessageTool, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.model_providers.volcengine_maas.legacy.errors import wrap_error +from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasException, MaasService + + +class MaaSClient(MaasService): + def __init__(self, host: str, region: str): + self.endpoint_id = None + super().__init__(host, region) + + def set_endpoint_id(self, endpoint_id: str): + self.endpoint_id = endpoint_id + + @classmethod + def from_credential(cls, credentials: dict) -> 'MaaSClient': + host = credentials['api_endpoint_host'] + region = credentials['volc_region'] + ak = credentials['volc_access_key_id'] + sk = credentials['volc_secret_access_key'] + endpoint_id = credentials['endpoint_id'] + + client = cls(host, region) + client.set_endpoint_id(endpoint_id) + client.set_ak(ak) + client.set_sk(sk) + return client + + def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict: + req = { + 'parameters': params, + 'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages], + **extra_model_kwargs, + } + if not stream: + return super().chat( + self.endpoint_id, + req, + ) + return super().stream_chat( + self.endpoint_id, + req, + ) + + def embeddings(self, texts: list[str]) -> dict: + req = { + 'input': texts + } + return super().embeddings(self.endpoint_id, req) + + @staticmethod + def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict: + if isinstance(message, UserPromptMessage): + message = cast(UserPromptMessage, message) + if isinstance(message.content, str): + message_dict = {"role": ChatRole.USER, + "content": message.content} + else: + content = [] + for message_content in message.content: + if message_content.type == PromptMessageContentType.TEXT: + raise ValueError( + 'Content object type only support image_url') + elif message_content.type == PromptMessageContentType.IMAGE: + message_content = cast( + ImagePromptMessageContent, message_content) + image_data = re.sub( + r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data) + content.append({ + 'type': 'image_url', + 'image_url': { + 'url': '', + 'image_bytes': image_data, + 'detail': message_content.detail, + } + }) + + message_dict = {'role': ChatRole.USER, 'content': content} + elif isinstance(message, AssistantPromptMessage): + message = cast(AssistantPromptMessage, message) + message_dict = {'role': ChatRole.ASSISTANT, + 'content': message.content} + if message.tool_calls: + message_dict['tool_calls'] = [ + { + 'name': call.function.name, + 'arguments': call.function.arguments + } for call in message.tool_calls + ] + elif isinstance(message, SystemPromptMessage): + message = cast(SystemPromptMessage, message) + message_dict = {'role': ChatRole.SYSTEM, + 'content': message.content} + elif isinstance(message, ToolPromptMessage): + message = cast(ToolPromptMessage, message) + message_dict = {'role': ChatRole.FUNCTION, + 'content': message.content, + 'name': message.tool_call_id} + else: + raise ValueError(f"Got unknown PromptMessage type {message}") + + return message_dict + + @staticmethod + def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator: + try: + resp = fn() + except MaasException as e: + raise wrap_error(e) + + return resp + + @staticmethod + def transform_tool_prompt_to_maas_config(tool: PromptMessageTool): + return { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters, + } + } diff --git a/api/core/model_runtime/model_providers/volcengine_maas/errors.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py similarity index 97% rename from api/core/model_runtime/model_providers/volcengine_maas/errors.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py index 63397a456ece96..21ffaf1258fefd 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/errors.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py @@ -1,4 +1,4 @@ -from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException +from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasException class ClientSDKRequestError(MaasException): diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py similarity index 100% rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/__init__.py similarity index 100% rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/__init__.py diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py similarity index 100% rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py similarity index 100% rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py similarity index 100% rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py similarity index 100% rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py similarity index 100% rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py index 8bea30324b5ad8..996c66e604fea2 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py @@ -1,8 +1,10 @@ import logging from collections.abc import Generator +from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk + from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -27,16 +29,21 @@ ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient -from core.model_runtime.model_providers.volcengine_maas.errors import ( +from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3 +from core.model_runtime.model_providers.volcengine_maas.legacy.client import MaaSClient +from core.model_runtime.model_providers.volcengine_maas.legacy.errors import ( AuthErrors, BadRequestErrors, ConnectionErrors, + MaasException, RateLimitErrors, ServerUnavailableErrors, ) -from core.model_runtime.model_providers.volcengine_maas.llm.models import ModelConfigs -from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException +from core.model_runtime.model_providers.volcengine_maas.llm.models import ( + get_model_config, + get_v2_req_params, + get_v3_req_params, +) logger = logging.getLogger(__name__) @@ -46,13 +53,20 @@ def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMes model_parameters: dict, tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: - return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + if ArkClientV3.is_legacy(credentials): + return self._generate_v2(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + return self._generate_v3(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate credentials """ - # ping + if ArkClientV3.is_legacy(credentials): + return self._validate_credentials_v2(credentials) + return self._validate_credentials_v3(credentials) + + @staticmethod + def _validate_credentials_v2(credentials: dict) -> None: client = MaaSClient.from_credential(credentials) try: client.chat( @@ -67,21 +81,40 @@ def validate_credentials(self, model: str, credentials: dict) -> None: except MaasException as e: raise CredentialsValidateFailedError(e.message) + @staticmethod + def _validate_credentials_v3(credentials: dict) -> None: + client = ArkClientV3.from_credentials(credentials) + try: + client.chat(max_tokens=16, temperature=0.7, top_p=0.9, + messages=[UserPromptMessage(content='ping\nAnswer: ')], ) + except Exception as e: + raise CredentialsValidateFailedError(e) + def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: list[PromptMessageTool] | None = None) -> int: - if len(prompt_messages) == 0: + if ArkClientV3.is_legacy(credentials): + return self._get_num_tokens_v2(prompt_messages) + return self._get_num_tokens_v3(prompt_messages) + + def _get_num_tokens_v2(self, messages: list[PromptMessage]) -> int: + if len(messages) == 0: return 0 - return self._num_tokens_from_messages(prompt_messages) + num_tokens = 0 + messages_dict = [ + MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages] + for message in messages_dict: + for key, value in message.items(): + num_tokens += self._get_num_tokens_by_gpt2(str(key)) + num_tokens += self._get_num_tokens_by_gpt2(str(value)) - def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int: - """ - Calculate num tokens. + return num_tokens - :param messages: messages - """ + def _get_num_tokens_v3(self, messages: list[PromptMessage]) -> int: + if len(messages) == 0: + return 0 num_tokens = 0 messages_dict = [ - MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages] + ArkClientV3.convert_prompt_message(m) for m in messages] for message in messages_dict: for key, value in message.items(): num_tokens += self._get_num_tokens_by_gpt2(str(key)) @@ -89,118 +122,165 @@ def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int: return num_tokens - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ + def _generate_v2(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: client = MaaSClient.from_credential(credentials) - - req_params = ModelConfigs.get( - credentials['base_model_name'], {}).get('req_params', {}).copy() - if credentials.get('context_size'): - req_params['max_prompt_tokens'] = credentials.get('context_size') - if credentials.get('max_tokens'): - req_params['max_new_tokens'] = credentials.get('max_tokens') - if model_parameters.get('max_tokens'): - req_params['max_new_tokens'] = model_parameters.get('max_tokens') - if model_parameters.get('temperature'): - req_params['temperature'] = model_parameters.get('temperature') - if model_parameters.get('top_p'): - req_params['top_p'] = model_parameters.get('top_p') - if model_parameters.get('top_k'): - req_params['top_k'] = model_parameters.get('top_k') - if model_parameters.get('presence_penalty'): - req_params['presence_penalty'] = model_parameters.get( - 'presence_penalty') - if model_parameters.get('frequency_penalty'): - req_params['frequency_penalty'] = model_parameters.get( - 'frequency_penalty') - if stop: - req_params['stop'] = stop - + req_params = get_v2_req_params(credentials, model_parameters, stop) extra_model_kwargs = {} - if tools: extra_model_kwargs['tools'] = [ MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools ] - resp = MaaSClient.wrap_exception( lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs)) - if not stream: - return self._handle_chat_response(model, credentials, prompt_messages, resp) - return self._handle_stream_chat_response(model, credentials, prompt_messages, resp) - def _handle_stream_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: Generator) -> Generator: - for index, r in enumerate(resp): - choices = r['choices'] + def _handle_stream_chat_response() -> Generator: + for index, r in enumerate(resp): + choices = r['choices'] + if not choices: + continue + choice = choices[0] + message = choice['message'] + usage = None + if r.get('usage'): + usage = self._calc_response_usage(model=model, credentials=credentials, + prompt_tokens=r['usage']['prompt_tokens'], + completion_tokens=r['usage']['completion_tokens'] + ) + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=index, + message=AssistantPromptMessage( + content=message['content'] if message['content'] else '', + tool_calls=[] + ), + usage=usage, + finish_reason=choice.get('finish_reason'), + ), + ) + + def _handle_chat_response() -> LLMResult: + choices = resp['choices'] if not choices: - continue + raise ValueError("No choices found") + choice = choices[0] message = choice['message'] - usage = None - if r.get('usage'): - usage = self._calc_usage(model, credentials, r['usage']) - yield LLMResultChunk( + + # parse tool calls + tool_calls = [] + if message['tool_calls']: + for call in message['tool_calls']: + tool_call = AssistantPromptMessage.ToolCall( + id=call['function']['name'], + type=call['type'], + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=call['function']['name'], + arguments=call['function']['arguments'] + ) + ) + tool_calls.append(tool_call) + + usage = resp['usage'] + return LLMResult( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=AssistantPromptMessage( - content=message['content'] if message['content'] else '', - tool_calls=[] - ), - usage=usage, - finish_reason=choice.get('finish_reason'), + message=AssistantPromptMessage( + content=message['content'] if message['content'] else '', + tool_calls=tool_calls, ), + usage=self._calc_response_usage(model=model, credentials=credentials, + prompt_tokens=usage['prompt_tokens'], + completion_tokens=usage['completion_tokens'] + ), ) - def _handle_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: dict) -> LLMResult: - choices = resp['choices'] - if not choices: - return - choice = choices[0] - message = choice['message'] - - # parse tool calls - tool_calls = [] - if message['tool_calls']: - for call in message['tool_calls']: - tool_call = AssistantPromptMessage.ToolCall( - id=call['function']['name'], - type=call['type'], - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=call['function']['name'], - arguments=call['function']['arguments'] - ) + if not stream: + return _handle_chat_response() + return _handle_stream_chat_response() + + def _generate_v3(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ + -> LLMResult | Generator: + + client = ArkClientV3.from_credentials(credentials) + req_params = get_v3_req_params(credentials, model_parameters, stop) + if tools: + req_params['tools'] = tools + + def _handle_stream_chat_response(chunks: Generator[ChatCompletionChunk]) -> Generator: + for chunk in chunks: + if not chunk.choices: + continue + choice = chunk.choices[0] + + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=choice.index, + message=AssistantPromptMessage( + content=choice.delta.content, + tool_calls=[] + ), + usage=self._calc_response_usage(model=model, credentials=credentials, + prompt_tokens=chunk.usage.prompt_tokens, + completion_tokens=chunk.usage.completion_tokens + ) if chunk.usage else None, + finish_reason=choice.finish_reason, + ), ) - tool_calls.append(tool_call) - return LLMResult( - model=model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=message['content'] if message['content'] else '', - tool_calls=tool_calls, - ), - usage=self._calc_usage(model, credentials, resp['usage']), - ) + def _handle_chat_response(resp: ChatCompletion) -> LLMResult: + choice = resp.choices[0] + message = choice.message + # parse tool calls + tool_calls = [] + if message.tool_calls: + for call in message.tool_calls: + tool_call = AssistantPromptMessage.ToolCall( + id=call.id, + type=call.type, + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=call.function.name, + arguments=call.function.arguments + ) + ) + tool_calls.append(tool_call) - def _calc_usage(self, model: str, credentials: dict, usage: dict) -> LLMUsage: - return self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=usage['prompt_tokens'], - completion_tokens=usage['completion_tokens'] - ) + usage = resp.usage + return LLMResult( + model=model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage( + content=message.content if message.content else "", + tool_calls=tool_calls, + ), + usage=self._calc_response_usage(model=model, credentials=credentials, + prompt_tokens=usage.prompt_tokens, + completion_tokens=usage.completion_tokens + ), + ) + + if not stream: + resp = client.chat(prompt_messages, **req_params) + return _handle_chat_response(resp) + + chunks = client.stream_chat(prompt_messages, **req_params) + return _handle_stream_chat_response(chunks) def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ used to define customizable model schema """ - max_tokens = ModelConfigs.get( - credentials['base_model_name'], {}).get('req_params', {}).get('max_new_tokens') - if credentials.get('max_tokens'): - max_tokens = int(credentials.get('max_tokens')) + model_config = get_model_config(credentials) + rules = [ ParameterRule( name='temperature', @@ -234,10 +314,10 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode name='presence_penalty', type=ParameterType.FLOAT, use_template='presence_penalty', - label={ - 'en_US': 'Presence Penalty', - 'zh_Hans': '存在惩罚', - }, + label=I18nObject( + en_US='Presence Penalty', + zh_Hans='存在惩罚', + ), min=-2.0, max=2.0, ), @@ -245,10 +325,10 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode name='frequency_penalty', type=ParameterType.FLOAT, use_template='frequency_penalty', - label={ - 'en_US': 'Frequency Penalty', - 'zh_Hans': '频率惩罚', - }, + label=I18nObject( + en_US='Frequency Penalty', + zh_Hans='频率惩罚', + ), min=-2.0, max=2.0, ), @@ -257,7 +337,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode type=ParameterType.INT, use_template='max_tokens', min=1, - max=max_tokens, + max=model_config.properties.max_tokens, default=512, label=I18nObject( zh_Hans='最大生成长度', @@ -266,16 +346,9 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode ), ] - model_properties = ModelConfigs.get( - credentials['base_model_name'], {}).get('model_properties', {}).copy() - if credentials.get('mode'): - model_properties[ModelPropertyKey.MODE] = credentials.get('mode') - if credentials.get('context_size'): - model_properties[ModelPropertyKey.CONTEXT_SIZE] = int( - credentials.get('context_size', 4096)) - - model_features = ModelConfigs.get( - credentials['base_model_name'], {}).get('features', []) + model_properties = {} + model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size + model_properties[ModelPropertyKey.MODE] = model_config.properties.mode.value entity = AIModelEntity( model=model, @@ -286,7 +359,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode model_type=ModelType.LLM, model_properties=model_properties, parameter_rules=rules, - features=model_features, + features=model_config.features, ) return entity diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py index 3e5938f3b494af..a882f68a36cc97 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py @@ -1,181 +1,153 @@ +from pydantic import BaseModel + +from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.entities.model_entities import ModelFeature -ModelConfigs = { - 'Doubao-pro-4k': { - 'req_params': { - 'max_prompt_tokens': 4096, - 'max_new_tokens': 4096, - }, - 'model_properties': { - 'context_size': 4096, - 'mode': 'chat', - }, - 'features': [ - ModelFeature.TOOL_CALL - ], - }, - 'Doubao-lite-4k': { - 'req_params': { - 'max_prompt_tokens': 4096, - 'max_new_tokens': 4096, - }, - 'model_properties': { - 'context_size': 4096, - 'mode': 'chat', - }, - 'features': [ - ModelFeature.TOOL_CALL - ], - }, - 'Doubao-pro-32k': { - 'req_params': { - 'max_prompt_tokens': 32768, - 'max_new_tokens': 32768, - }, - 'model_properties': { - 'context_size': 32768, - 'mode': 'chat', - }, - 'features': [ - ModelFeature.TOOL_CALL - ], - }, - 'Doubao-lite-32k': { - 'req_params': { - 'max_prompt_tokens': 32768, - 'max_new_tokens': 32768, - }, - 'model_properties': { - 'context_size': 32768, - 'mode': 'chat', - }, - 'features': [ - ModelFeature.TOOL_CALL - ], - }, - 'Doubao-pro-128k': { - 'req_params': { - 'max_prompt_tokens': 131072, - 'max_new_tokens': 131072, - }, - 'model_properties': { - 'context_size': 131072, - 'mode': 'chat', - }, - 'features': [ - ModelFeature.TOOL_CALL - ], - }, - 'Doubao-lite-128k': { - 'req_params': { - 'max_prompt_tokens': 131072, - 'max_new_tokens': 131072, - }, - 'model_properties': { - 'context_size': 131072, - 'mode': 'chat', - }, - 'features': [ - ModelFeature.TOOL_CALL - ], - }, - 'Skylark2-pro-4k': { - 'req_params': { - 'max_prompt_tokens': 4096, - 'max_new_tokens': 4000, - }, - 'model_properties': { - 'context_size': 4096, - 'mode': 'chat', - }, - 'features': [], - }, - 'Llama3-8B': { - 'req_params': { - 'max_prompt_tokens': 8192, - 'max_new_tokens': 8192, - }, - 'model_properties': { - 'context_size': 8192, - 'mode': 'chat', - }, - 'features': [], - }, - 'Llama3-70B': { - 'req_params': { - 'max_prompt_tokens': 8192, - 'max_new_tokens': 8192, - }, - 'model_properties': { - 'context_size': 8192, - 'mode': 'chat', - }, - 'features': [], - }, - 'Moonshot-v1-8k': { - 'req_params': { - 'max_prompt_tokens': 8192, - 'max_new_tokens': 4096, - }, - 'model_properties': { - 'context_size': 8192, - 'mode': 'chat', - }, - 'features': [], - }, - 'Moonshot-v1-32k': { - 'req_params': { - 'max_prompt_tokens': 32768, - 'max_new_tokens': 16384, - }, - 'model_properties': { - 'context_size': 32768, - 'mode': 'chat', - }, - 'features': [], - }, - 'Moonshot-v1-128k': { - 'req_params': { - 'max_prompt_tokens': 131072, - 'max_new_tokens': 65536, - }, - 'model_properties': { - 'context_size': 131072, - 'mode': 'chat', - }, - 'features': [], - }, - 'GLM3-130B': { - 'req_params': { - 'max_prompt_tokens': 8192, - 'max_new_tokens': 4096, - }, - 'model_properties': { - 'context_size': 8192, - 'mode': 'chat', - }, - 'features': [], - }, - 'GLM3-130B-Fin': { - 'req_params': { - 'max_prompt_tokens': 8192, - 'max_new_tokens': 4096, - }, - 'model_properties': { - 'context_size': 8192, - 'mode': 'chat', - }, - 'features': [], - }, - 'Mistral-7B': { - 'req_params': { - 'max_prompt_tokens': 8192, - 'max_new_tokens': 2048, - }, - 'model_properties': { - 'context_size': 8192, - 'mode': 'chat', - }, - 'features': [], - } + +class ModelProperties(BaseModel): + context_size: int + max_tokens: int + mode: LLMMode + + +class ModelConfig(BaseModel): + properties: ModelProperties + features: list[ModelFeature] + + +configs: dict[str, ModelConfig] = { + 'Doubao-pro-4k': ModelConfig( + properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL] + ), + 'Doubao-lite-4k': ModelConfig( + properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL] + ), + 'Doubao-pro-32k': ModelConfig( + properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL] + ), + 'Doubao-lite-32k': ModelConfig( + properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL] + ), + 'Doubao-pro-128k': ModelConfig( + properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL] + ), + 'Doubao-lite-128k': ModelConfig( + properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), + features=[] + ), + 'Skylark2-pro-4k': ModelConfig( + properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), + features=[] + ), + 'Llama3-8B': ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), + features=[] + ), + 'Llama3-70B': ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), + features=[] + ), + 'Moonshot-v1-8k': ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL] + ), + 'Moonshot-v1-32k': ModelConfig( + properties=ModelProperties(context_size=32768, max_tokens=16384, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL] + ), + 'Moonshot-v1-128k': ModelConfig( + properties=ModelProperties(context_size=131072, max_tokens=65536, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL] + ), + 'GLM3-130B': ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL] + ), + 'GLM3-130B-Fin': ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL] + ), + 'Mistral-7B': ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT), + features=[] + ) } + + +def get_model_config(credentials: dict) -> ModelConfig: + base_model = credentials.get('base_model_name', '') + model_configs = configs.get(base_model) + if not model_configs: + return ModelConfig( + properties=ModelProperties( + context_size=int(credentials.get('context_size', 0)), + max_tokens=int(credentials.get('max_tokens', 0)), + mode=LLMMode.value_of(credentials.get('mode', 'chat')), + ), + features=[] + ) + return model_configs + + +def get_v2_req_params(credentials: dict, model_parameters: dict, + stop: list[str] | None = None): + req_params = {} + # predefined properties + model_configs = get_model_config(credentials) + if model_configs: + req_params['max_prompt_tokens'] = model_configs.properties.context_size + req_params['max_new_tokens'] = model_configs.properties.max_tokens + + # model parameters + if model_parameters.get('max_tokens'): + req_params['max_new_tokens'] = model_parameters.get('max_tokens') + if model_parameters.get('temperature'): + req_params['temperature'] = model_parameters.get('temperature') + if model_parameters.get('top_p'): + req_params['top_p'] = model_parameters.get('top_p') + if model_parameters.get('top_k'): + req_params['top_k'] = model_parameters.get('top_k') + if model_parameters.get('presence_penalty'): + req_params['presence_penalty'] = model_parameters.get( + 'presence_penalty') + if model_parameters.get('frequency_penalty'): + req_params['frequency_penalty'] = model_parameters.get( + 'frequency_penalty') + + if stop: + req_params['stop'] = stop + + return req_params + + +def get_v3_req_params(credentials: dict, model_parameters: dict, + stop: list[str] | None = None): + req_params = {} + # predefined properties + model_configs = get_model_config(credentials) + if model_configs: + req_params['max_tokens'] = model_configs.properties.max_tokens + + # model parameters + if model_parameters.get('max_tokens'): + req_params['max_tokens'] = model_parameters.get('max_tokens') + if model_parameters.get('temperature'): + req_params['temperature'] = model_parameters.get('temperature') + if model_parameters.get('top_p'): + req_params['top_p'] = model_parameters.get('top_p') + if model_parameters.get('presence_penalty'): + req_params['presence_penalty'] = model_parameters.get( + 'presence_penalty') + if model_parameters.get('frequency_penalty'): + req_params['frequency_penalty'] = model_parameters.get( + 'frequency_penalty') + + if stop: + req_params['stop'] = stop + + return req_params diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py index 569f89e9754454..74cf26247cb233 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py @@ -1,9 +1,30 @@ +from pydantic import BaseModel + + +class ModelProperties(BaseModel): + context_size: int + max_chunks: int + + +class ModelConfig(BaseModel): + properties: ModelProperties + + ModelConfigs = { - 'Doubao-embedding': { - 'req_params': {}, - 'model_properties': { - 'context_size': 4096, - 'max_chunks': 1, - } - }, + 'Doubao-embedding': ModelConfig( + properties=ModelProperties(context_size=4096, max_chunks=32) + ), } + + +def get_model_config(credentials: dict) -> ModelConfig: + base_model = credentials.get('base_model_name', '') + model_configs = ModelConfigs.get(base_model) + if not model_configs: + return ModelConfig( + properties=ModelProperties( + context_size=int(credentials.get('context_size', 0)), + max_chunks=int(credentials.get('max_chunks', 0)), + ) + ) + return model_configs diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py index 10b01c0d0d6401..d54aeeb0b15f89 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py @@ -22,16 +22,17 @@ ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient -from core.model_runtime.model_providers.volcengine_maas.errors import ( +from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3 +from core.model_runtime.model_providers.volcengine_maas.legacy.client import MaaSClient +from core.model_runtime.model_providers.volcengine_maas.legacy.errors import ( AuthErrors, BadRequestErrors, ConnectionErrors, + MaasException, RateLimitErrors, ServerUnavailableErrors, ) -from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import ModelConfigs -from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException +from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import get_model_config class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): @@ -51,6 +52,14 @@ def _invoke(self, model: str, credentials: dict, :param user: unique user id :return: embeddings result """ + if ArkClientV3.is_legacy(credentials): + return self._generate_v2(model, credentials, texts, user) + + return self._generate_v3(model, credentials, texts, user) + + def _generate_v2(self, model: str, credentials: dict, + texts: list[str], user: Optional[str] = None) \ + -> TextEmbeddingResult: client = MaaSClient.from_credential(credentials) resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts)) @@ -65,6 +74,23 @@ def _invoke(self, model: str, credentials: dict, return result + def _generate_v3(self, model: str, credentials: dict, + texts: list[str], user: Optional[str] = None) \ + -> TextEmbeddingResult: + client = ArkClientV3.from_credentials(credentials) + resp = client.embeddings(texts) + + usage = self._calc_response_usage( + model=model, credentials=credentials, tokens=resp.usage.total_tokens) + + result = TextEmbeddingResult( + model=model, + embeddings=[v.embedding for v in resp.data], + usage=usage + ) + + return result + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ Get number of tokens for given prompt messages @@ -88,11 +114,22 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :param credentials: model credentials :return: """ + if ArkClientV3.is_legacy(credentials): + return self._validate_credentials_v2(model, credentials) + return self._validate_credentials_v3(model, credentials) + + def _validate_credentials_v2(self, model: str, credentials: dict) -> None: try: self._invoke(model=model, credentials=credentials, texts=['ping']) except MaasException as e: raise CredentialsValidateFailedError(e.message) + def _validate_credentials_v3(self, model: str, credentials: dict) -> None: + try: + self._invoke(model=model, credentials=credentials, texts=['ping']) + except Exception as e: + raise CredentialsValidateFailedError(e) + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -115,14 +152,11 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode """ generate custom model entities from credentials """ - model_properties = ModelConfigs.get( - credentials['base_model_name'], {}).get('model_properties', {}).copy() - if credentials.get('context_size'): - model_properties[ModelPropertyKey.CONTEXT_SIZE] = int( - credentials.get('context_size', 4096)) - if credentials.get('max_chunks'): - model_properties[ModelPropertyKey.MAX_CHUNKS] = int( - credentials.get('max_chunks', 4096)) + model_config = get_model_config(credentials) + model_properties = { + ModelPropertyKey.CONTEXT_SIZE: model_config.properties.context_size, + ModelPropertyKey.MAX_CHUNKS: model_config.properties.max_chunks + } entity = AIModelEntity( model=model, label=I18nObject(en_US=model), diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml index a00c1b79944b5a..13e00da76fb149 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml +++ b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml @@ -30,8 +30,28 @@ model_credential_schema: en_US: Enter your Model Name zh_Hans: 输入模型名称 credential_form_schemas: + - variable: auth_method + required: true + label: + en_US: Authentication Method + zh_Hans: 鉴权方式 + type: select + default: aksk + options: + - label: + en_US: API Key + value: api_key + - label: + en_US: Access Key / Secret Access Key + value: aksk + placeholder: + en_US: Enter your Authentication Method + zh_Hans: 选择鉴权方式 - variable: volc_access_key_id required: true + show_on: + - variable: auth_method + value: aksk label: en_US: Access Key zh_Hans: Access Key @@ -41,6 +61,9 @@ model_credential_schema: zh_Hans: 输入您的 Access Key - variable: volc_secret_access_key required: true + show_on: + - variable: auth_method + value: aksk label: en_US: Secret Access Key zh_Hans: Secret Access Key @@ -48,6 +71,17 @@ model_credential_schema: placeholder: en_US: Enter your Secret Access Key zh_Hans: 输入您的 Secret Access Key + - variable: volc_api_key + required: true + show_on: + - variable: auth_method + value: api_key + label: + en_US: API Key + type: secret-input + placeholder: + en_US: Enter your API Key + zh_Hans: 输入您的 API Key - variable: volc_region required: true label: @@ -64,7 +98,7 @@ model_credential_schema: en_US: API Endpoint Host zh_Hans: API Endpoint Host type: text-input - default: maas-api.ml-platform-cn-beijing.volces.com + default: https://ark.cn-beijing.volces.com/api/v3 placeholder: en_US: Enter your API Endpoint Host zh_Hans: 输入 API Endpoint Host diff --git a/api/core/model_runtime/model_providers/wenxin/_common.py b/api/core/model_runtime/model_providers/wenxin/_common.py new file mode 100644 index 00000000000000..0230c78b757612 --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/_common.py @@ -0,0 +1,198 @@ +from datetime import datetime, timedelta +from threading import Lock + +from requests import post + +from core.model_runtime.model_providers.wenxin.wenxin_errors import ( + BadRequestError, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError, +) + +baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {} +baidu_access_tokens_lock = Lock() + + +class BaiduAccessToken: + api_key: str + access_token: str + expires: datetime + + def __init__(self, api_key: str) -> None: + self.api_key = api_key + self.access_token = '' + self.expires = datetime.now() + timedelta(days=3) + + @staticmethod + def _get_access_token(api_key: str, secret_key: str) -> str: + """ + request access token from Baidu + """ + try: + response = post( + url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}', + headers={ + 'Content-Type': 'application/json', + 'Accept': 'application/json' + }, + ) + except Exception as e: + raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}') + + resp = response.json() + if 'error' in resp: + if resp['error'] == 'invalid_client': + raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}') + elif resp['error'] == 'unknown_error': + raise InternalServerError(f'Internal server error: {resp["error_description"]}') + elif resp['error'] == 'invalid_request': + raise BadRequestError(f'Bad request: {resp["error_description"]}') + elif resp['error'] == 'rate_limit_exceeded': + raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}') + else: + raise Exception(f'Unknown error: {resp["error_description"]}') + + return resp['access_token'] + + @staticmethod + def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken': + """ + LLM from Baidu requires access token to invoke the API. + however, we have api_key and secret_key, and access token is valid for 30 days. + so we can cache the access token for 3 days. (avoid memory leak) + + it may be more efficient to use a ticker to refresh access token, but it will cause + more complexity, so we just refresh access tokens when get_access_token is called. + """ + + # loop up cache, remove expired access token + baidu_access_tokens_lock.acquire() + now = datetime.now() + for key in list(baidu_access_tokens.keys()): + token = baidu_access_tokens[key] + if token.expires < now: + baidu_access_tokens.pop(key) + + if api_key not in baidu_access_tokens: + # if access token not in cache, request it + token = BaiduAccessToken(api_key) + baidu_access_tokens[api_key] = token + # release it to enhance performance + # btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock + baidu_access_tokens_lock.release() + # try to get access token + token_str = BaiduAccessToken._get_access_token(api_key, secret_key) + token.access_token = token_str + token.expires = now + timedelta(days=3) + return token + else: + # if access token in cache, return it + token = baidu_access_tokens[api_key] + baidu_access_tokens_lock.release() + return token + + +class _CommonWenxin: + api_bases = { + 'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', + 'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', + 'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', + 'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', + 'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', + 'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205', + 'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222', + 'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', + 'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k', + 'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', + 'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', + 'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed', + 'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k', + 'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas', + 'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', + 'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k', + 'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', + 'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', + 'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k', + 'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview', + 'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat', + 'embedding-v1': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1', + 'bge-large-en': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en', + 'bge-large-zh': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh', + 'tao-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k', + } + + function_calling_supports = [ + 'ernie-bot', + 'ernie-bot-8k', + 'ernie-3.5-8k', + 'ernie-3.5-8k-0205', + 'ernie-3.5-8k-1222', + 'ernie-3.5-4k-0205', + 'ernie-3.5-128k', + 'ernie-4.0-8k', + 'ernie-4.0-turbo-8k', + 'ernie-4.0-turbo-8k-preview', + 'yi_34b_chat' + ] + + api_key: str = '' + secret_key: str = '' + + def __init__(self, api_key: str, secret_key: str): + self.api_key = api_key + self.secret_key = secret_key + + @staticmethod + def _to_credential_kwargs(credentials: dict) -> dict: + credentials_kwargs = { + "api_key": credentials['api_key'], + "secret_key": credentials['secret_key'] + } + return credentials_kwargs + + def _handle_error(self, code: int, msg: str): + error_map = { + 1: InternalServerError, + 2: InternalServerError, + 3: BadRequestError, + 4: RateLimitReachedError, + 6: InvalidAuthenticationError, + 13: InvalidAPIKeyError, + 14: InvalidAPIKeyError, + 15: InvalidAPIKeyError, + 17: RateLimitReachedError, + 18: RateLimitReachedError, + 19: RateLimitReachedError, + 100: InvalidAPIKeyError, + 111: InvalidAPIKeyError, + 200: InternalServerError, + 336000: InternalServerError, + 336001: BadRequestError, + 336002: BadRequestError, + 336003: BadRequestError, + 336004: InvalidAuthenticationError, + 336005: InvalidAPIKeyError, + 336006: BadRequestError, + 336007: BadRequestError, + 336008: BadRequestError, + 336100: InternalServerError, + 336101: BadRequestError, + 336102: BadRequestError, + 336103: BadRequestError, + 336104: BadRequestError, + 336105: BadRequestError, + 336200: InternalServerError, + 336303: BadRequestError, + 337006: BadRequestError + } + + if code in error_map: + raise error_map[code](msg) + else: + raise InternalServerError(f'Unknown error: {msg}') + + def _get_access_token(self) -> str: + token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key) + return token.access_token diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py index bc7f29cf6ea538..8109949b1d80b7 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py @@ -1,102 +1,17 @@ from collections.abc import Generator -from datetime import datetime, timedelta from enum import Enum from json import dumps, loads -from threading import Lock from typing import Any, Union from requests import Response, post from core.model_runtime.entities.message_entities import PromptMessageTool -from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( +from core.model_runtime.model_providers.wenxin._common import _CommonWenxin +from core.model_runtime.model_providers.wenxin.wenxin_errors import ( BadRequestError, InternalServerError, - InvalidAPIKeyError, - InvalidAuthenticationError, - RateLimitReachedError, ) -# map api_key to access_token -baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {} -baidu_access_tokens_lock = Lock() - -class BaiduAccessToken: - api_key: str - access_token: str - expires: datetime - - def __init__(self, api_key: str) -> None: - self.api_key = api_key - self.access_token = '' - self.expires = datetime.now() + timedelta(days=3) - - def _get_access_token(api_key: str, secret_key: str) -> str: - """ - request access token from Baidu - """ - try: - response = post( - url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}', - headers={ - 'Content-Type': 'application/json', - 'Accept': 'application/json' - }, - ) - except Exception as e: - raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}') - - resp = response.json() - if 'error' in resp: - if resp['error'] == 'invalid_client': - raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}') - elif resp['error'] == 'unknown_error': - raise InternalServerError(f'Internal server error: {resp["error_description"]}') - elif resp['error'] == 'invalid_request': - raise BadRequestError(f'Bad request: {resp["error_description"]}') - elif resp['error'] == 'rate_limit_exceeded': - raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}') - else: - raise Exception(f'Unknown error: {resp["error_description"]}') - - return resp['access_token'] - - @staticmethod - def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken': - """ - LLM from Baidu requires access token to invoke the API. - however, we have api_key and secret_key, and access token is valid for 30 days. - so we can cache the access token for 3 days. (avoid memory leak) - - it may be more efficient to use a ticker to refresh access token, but it will cause - more complexity, so we just refresh access tokens when get_access_token is called. - """ - - # loop up cache, remove expired access token - baidu_access_tokens_lock.acquire() - now = datetime.now() - for key in list(baidu_access_tokens.keys()): - token = baidu_access_tokens[key] - if token.expires < now: - baidu_access_tokens.pop(key) - - if api_key not in baidu_access_tokens: - # if access token not in cache, request it - token = BaiduAccessToken(api_key) - baidu_access_tokens[api_key] = token - # release it to enhance performance - # btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock - baidu_access_tokens_lock.release() - # try to get access token - token_str = BaiduAccessToken._get_access_token(api_key, secret_key) - token.access_token = token_str - token.expires = now + timedelta(days=3) - return token - else: - # if access token in cache, return it - token = baidu_access_tokens[api_key] - baidu_access_tokens_lock.release() - return token - class ErnieMessage: class Role(Enum): @@ -120,49 +35,7 @@ def __init__(self, content: str, role: str = 'user') -> None: self.content = content self.role = role -class ErnieBotModel: - api_bases = { - 'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', - 'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', - 'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', - 'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', - 'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205', - 'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222', - 'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', - 'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k', - 'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed', - 'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k', - 'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas', - 'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', - 'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k', - 'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', - 'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', - 'ernie-4.0-tutbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k', - 'ernie-4.0-tutbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview', - } - - function_calling_supports = [ - 'ernie-bot', - 'ernie-bot-8k', - 'ernie-3.5-8k', - 'ernie-3.5-8k-0205', - 'ernie-3.5-8k-1222', - 'ernie-3.5-4k-0205', - 'ernie-3.5-128k', - 'ernie-4.0-8k', - 'ernie-4.0-turbo-8k', - 'ernie-4.0-turbo-8k-preview' - ] - - api_key: str = '' - secret_key: str = '' - - def __init__(self, api_key: str, secret_key: str): - self.api_key = api_key - self.secret_key = secret_key +class ErnieBotModel(_CommonWenxin): def generate(self, model: str, stream: bool, messages: list[ErnieMessage], parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \ @@ -197,51 +70,6 @@ def generate(self, model: str, stream: bool, messages: list[ErnieMessage], return self._handle_chat_stream_generate_response(resp) return self._handle_chat_generate_response(resp) - def _handle_error(self, code: int, msg: str): - error_map = { - 1: InternalServerError, - 2: InternalServerError, - 3: BadRequestError, - 4: RateLimitReachedError, - 6: InvalidAuthenticationError, - 13: InvalidAPIKeyError, - 14: InvalidAPIKeyError, - 15: InvalidAPIKeyError, - 17: RateLimitReachedError, - 18: RateLimitReachedError, - 19: RateLimitReachedError, - 100: InvalidAPIKeyError, - 111: InvalidAPIKeyError, - 200: InternalServerError, - 336000: InternalServerError, - 336001: BadRequestError, - 336002: BadRequestError, - 336003: BadRequestError, - 336004: InvalidAuthenticationError, - 336005: InvalidAPIKeyError, - 336006: BadRequestError, - 336007: BadRequestError, - 336008: BadRequestError, - 336100: InternalServerError, - 336101: BadRequestError, - 336102: BadRequestError, - 336103: BadRequestError, - 336104: BadRequestError, - 336105: BadRequestError, - 336200: InternalServerError, - 336303: BadRequestError, - 337006: BadRequestError - } - - if code in error_map: - raise error_map[code](msg) - else: - raise InternalServerError(f'Unknown error: {msg}') - - def _get_access_token(self) -> str: - token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key) - return token.access_token - def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]: return [ErnieMessage(message.content, message.role) for message in messages] diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot_errors.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot_errors.py deleted file mode 100644 index 67d76b4a291c06..00000000000000 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot_errors.py +++ /dev/null @@ -1,17 +0,0 @@ -class InvalidAuthenticationError(Exception): - pass - -class InvalidAPIKeyError(Exception): - pass - -class RateLimitReachedError(Exception): - pass - -class InsufficientAccountBalance(Exception): - pass - -class InternalServerError(Exception): - pass - -class BadRequestError(Exception): - pass \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/wenxin/llm/llm.py b/api/core/model_runtime/model_providers/wenxin/llm/llm.py index d39d63deeeae7d..140606298cac37 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/llm.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/llm.py @@ -11,24 +11,13 @@ UserPromptMessage, ) from core.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, InvokeError, - InvokeRateLimitError, - InvokeServerUnavailableError, ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.wenxin.llm.ernie_bot import BaiduAccessToken, ErnieBotModel, ErnieMessage -from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( - BadRequestError, - InsufficientAccountBalance, - InternalServerError, - InvalidAPIKeyError, - InvalidAuthenticationError, - RateLimitReachedError, -) +from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken +from core.model_runtime.model_providers.wenxin.llm.ernie_bot import ErnieBotModel, ErnieMessage +from core.model_runtime.model_providers.wenxin.wenxin_errors import invoke_error_mapping ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure @@ -140,7 +129,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: api_key = credentials['api_key'] secret_key = credentials['secret_key'] try: - BaiduAccessToken._get_access_token(api_key, secret_key) + BaiduAccessToken.get_access_token(api_key, secret_key) except Exception as e: raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') @@ -254,22 +243,4 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ - return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], - InvokeAuthorizationError: [ - InvalidAuthenticationError, - InsufficientAccountBalance, - InvalidAPIKeyError, - ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] - } + return invoke_error_mapping() diff --git a/api/core/model_runtime/model_providers/wenxin/llm/yi_34b_chat.yaml b/api/core/model_runtime/model_providers/wenxin/llm/yi_34b_chat.yaml new file mode 100644 index 00000000000000..0b247fbd223dac --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/llm/yi_34b_chat.yaml @@ -0,0 +1,30 @@ +model: yi_34b_chat +label: + en_US: yi_34b_chat +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + min: 0.1 + max: 1.0 + default: 0.95 + - name: top_p + use_template: top_p + min: 0 + max: 1.0 + default: 0.7 + - name: max_tokens + use_template: max_tokens + default: 4096 + min: 2 + max: 4096 + - name: presence_penalty + use_template: presence_penalty + default: 1.0 + min: 1.0 + max: 2.0 diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/__init__.py b/api/core/model_runtime/model_providers/wenxin/text_embedding/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/bge-large-en.yaml b/api/core/model_runtime/model_providers/wenxin/text_embedding/bge-large-en.yaml new file mode 100644 index 00000000000000..74fadb7f9de60f --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/bge-large-en.yaml @@ -0,0 +1,9 @@ +model: bge-large-en +model_type: text-embedding +model_properties: + context_size: 512 + max_chunks: 16 +pricing: + input: '0.0005' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/bge-large-zh.yaml b/api/core/model_runtime/model_providers/wenxin/text_embedding/bge-large-zh.yaml new file mode 100644 index 00000000000000..d4af27ec389a66 --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/bge-large-zh.yaml @@ -0,0 +1,9 @@ +model: bge-large-zh +model_type: text-embedding +model_properties: + context_size: 512 + max_chunks: 16 +pricing: + input: '0.0005' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/embedding-v1.yaml b/api/core/model_runtime/model_providers/wenxin/text_embedding/embedding-v1.yaml new file mode 100644 index 00000000000000..eda48d965533e5 --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/embedding-v1.yaml @@ -0,0 +1,9 @@ +model: embedding-v1 +model_type: text-embedding +model_properties: + context_size: 384 + max_chunks: 16 +pricing: + input: '0.0005' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/tao-8k.yaml b/api/core/model_runtime/model_providers/wenxin/text_embedding/tao-8k.yaml new file mode 100644 index 00000000000000..e28f253eb6b861 --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/tao-8k.yaml @@ -0,0 +1,9 @@ +model: tao-8k +model_type: text-embedding +model_properties: + context_size: 8192 + max_chunks: 1 +pricing: + input: '0.0005' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py new file mode 100644 index 00000000000000..10ac1a1861e29f --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py @@ -0,0 +1,184 @@ +import time +from abc import abstractmethod +from collections.abc import Mapping +from json import dumps +from typing import Any, Optional + +import numpy as np +from requests import Response, post + +from core.model_runtime.entities.model_entities import PriceType +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import InvokeError +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken, _CommonWenxin +from core.model_runtime.model_providers.wenxin.wenxin_errors import ( + BadRequestError, + InternalServerError, + invoke_error_mapping, +) + + +class TextEmbedding: + @abstractmethod + def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): + raise NotImplementedError + + +class WenxinTextEmbedding(_CommonWenxin, TextEmbedding): + def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): + access_token = self._get_access_token() + url = f'{self.api_bases[model]}?access_token={access_token}' + body = self._build_embed_request_body(model, texts, user) + headers = { + 'Content-Type': 'application/json', + } + + resp = post(url, data=dumps(body), headers=headers) + if resp.status_code != 200: + raise InternalServerError(f'Failed to invoke ernie bot: {resp.text}') + return self._handle_embed_response(model, resp) + + def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> dict[str, Any]: + if len(texts) == 0: + raise BadRequestError('The number of texts should not be zero.') + body = { + 'input': texts, + 'user_id': user, + } + return body + + def _handle_embed_response(self, model: str, response: Response) -> (list[list[float]], int, int): + data = response.json() + if 'error_code' in data: + code = data['error_code'] + msg = data['error_msg'] + # raise error + self._handle_error(code, msg) + + embeddings = [v['embedding'] for v in data['data']] + _usage = data['usage'] + tokens = _usage['prompt_tokens'] + total_tokens = _usage['total_tokens'] + + return embeddings, tokens, total_tokens + + +class WenxinTextEmbeddingModel(TextEmbeddingModel): + def _create_text_embedding(self, api_key: str, secret_key: str) -> TextEmbedding: + return WenxinTextEmbedding(api_key, secret_key) + + def _invoke(self, model: str, credentials: dict, texts: list[str], + user: Optional[str] = None) -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :return: embeddings result + """ + + api_key = credentials['api_key'] + secret_key = credentials['secret_key'] + embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key) + user = user if user else 'ErnieBotDefault' + + context_size = self._get_context_size(model, credentials) + max_chunks = self._get_max_chunks(model, credentials) + inputs = [] + indices = [] + used_tokens = 0 + used_total_tokens = 0 + + for i, text in enumerate(texts): + + # Here token count is only an approximation based on the GPT2 tokenizer + num_tokens = self._get_num_tokens_by_gpt2(text) + + if num_tokens >= context_size: + cutoff = int(np.floor(len(text) * (context_size / num_tokens))) + # if num tokens is larger than context length, only use the start + inputs.append(text[0:cutoff]) + else: + inputs.append(text) + indices += [i] + + batched_embeddings = [] + _iter = range(0, len(inputs), max_chunks) + for i in _iter: + embeddings_batch, _used_tokens, _total_used_tokens = embedding.embed_documents( + model, + inputs[i: i + max_chunks], + user) + used_tokens += _used_tokens + used_total_tokens += _total_used_tokens + batched_embeddings += embeddings_batch + + usage = self._calc_response_usage(model, credentials, used_tokens, used_total_tokens) + return TextEmbeddingResult( + model=model, + embeddings=batched_embeddings, + usage=usage, + ) + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + if len(texts) == 0: + return 0 + total_num_tokens = 0 + for text in texts: + total_num_tokens += self._get_num_tokens_by_gpt2(text) + + return total_num_tokens + + def validate_credentials(self, model: str, credentials: Mapping) -> None: + api_key = credentials['api_key'] + secret_key = credentials['secret_key'] + try: + BaiduAccessToken.get_access_token(api_key, secret_key) + except Exception as e: + raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + return invoke_error_mapping() + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int, total_tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: input tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, + credentials=credentials, + price_type=PriceType.INPUT, + tokens=tokens + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=total_tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at + ) + + return usage diff --git a/api/core/model_runtime/model_providers/wenxin/wenxin.yaml b/api/core/model_runtime/model_providers/wenxin/wenxin.yaml index b3a1f608249bbf..6a6b38e6a1b7ab 100644 --- a/api/core/model_runtime/model_providers/wenxin/wenxin.yaml +++ b/api/core/model_runtime/model_providers/wenxin/wenxin.yaml @@ -17,6 +17,7 @@ help: en_US: https://cloud.baidu.com/wenxin.html supported_model_types: - llm + - text-embedding configurate_methods: - predefined-model provider_credential_schema: diff --git a/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py b/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py new file mode 100644 index 00000000000000..0fbd0f55ec9d90 --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py @@ -0,0 +1,57 @@ +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) + + +def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [ + ], + InvokeServerUnavailableError: [ + InternalServerError + ], + InvokeRateLimitError: [ + RateLimitReachedError + ], + InvokeAuthorizationError: [ + InvalidAuthenticationError, + InsufficientAccountBalance, + InvalidAPIKeyError, + ], + InvokeBadRequestError: [ + BadRequestError, + KeyError + ] + } + + +class InvalidAuthenticationError(Exception): + pass + +class InvalidAPIKeyError(Exception): + pass + +class RateLimitReachedError(Exception): + pass + +class InsufficientAccountBalance(Exception): + pass + +class InternalServerError(Exception): + pass + +class BadRequestError(Exception): + pass \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index 988bb0ce4432df..4760e8f1185e8d 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -85,7 +85,8 @@ def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMes tools=tools, stop=stop, stream=stream, user=user, extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter( server_url=credentials['server_url'], - model_uid=credentials['model_uid'] + model_uid=credentials['model_uid'], + api_key=credentials.get('api_key'), ) ) @@ -106,7 +107,8 @@ def validate_credentials(self, model: str, credentials: dict) -> None: extra_param = XinferenceHelper.get_xinference_extra_parameter( server_url=credentials['server_url'], - model_uid=credentials['model_uid'] + model_uid=credentials['model_uid'], + api_key=credentials.get('api_key') ) if 'completion_type' not in credentials: if 'chat' in extra_param.model_ability: @@ -396,7 +398,8 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode else: extra_args = XinferenceHelper.get_xinference_extra_parameter( server_url=credentials['server_url'], - model_uid=credentials['model_uid'] + model_uid=credentials['model_uid'], + api_key=credentials.get('api_key') ) if 'chat' in extra_args.model_ability: @@ -464,6 +467,7 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM xinference_client = Client( base_url=credentials['server_url'], + api_key=credentials.get('api_key'), ) xinference_model = xinference_client.get_model(credentials['model_uid']) diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index 4e7543fd996fd7..d809537479f40d 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -108,7 +108,8 @@ def validate_credentials(self, model: str, credentials: dict) -> None: # initialize client client = Client( - base_url=credentials['server_url'] + base_url=credentials['server_url'], + api_key=credentials.get('api_key'), ) xinference_client = client.get_model(model_uid=credentials['model_uid']) diff --git a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py index 9ee36213176ef7..62b77f22e59c87 100644 --- a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py @@ -52,7 +52,8 @@ def validate_credentials(self, model: str, credentials: dict) -> None: # initialize client client = Client( - base_url=credentials['server_url'] + base_url=credentials['server_url'], + api_key=credentials.get('api_key'), ) xinference_client = client.get_model(model_uid=credentials['model_uid']) diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index 11f1e29cb39f81..3a8d704c25838c 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -110,14 +110,22 @@ def validate_credentials(self, model: str, credentials: dict) -> None: server_url = credentials['server_url'] model_uid = credentials['model_uid'] - extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid) + api_key = credentials.get('api_key') + extra_args = XinferenceHelper.get_xinference_extra_parameter( + server_url=server_url, + model_uid=model_uid, + api_key=api_key, + ) if extra_args.max_tokens: credentials['max_tokens'] = extra_args.max_tokens if server_url.endswith('/'): server_url = server_url[:-1] - client = Client(base_url=server_url) + client = Client( + base_url=server_url, + api_key=api_key, + ) try: handle = client.get_model(model_uid=model_uid) diff --git a/api/core/model_runtime/model_providers/xinference/tts/tts.py b/api/core/model_runtime/model_providers/xinference/tts/tts.py index c106e38781cc79..bfa752df8cdb31 100644 --- a/api/core/model_runtime/model_providers/xinference/tts/tts.py +++ b/api/core/model_runtime/model_providers/xinference/tts/tts.py @@ -1,11 +1,7 @@ import concurrent.futures -from functools import reduce -from io import BytesIO from typing import Optional -from flask import Response -from pydub import AudioSegment -from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle +from xinference_client.client.restful.restful_client import RESTfulAudioModelHandle from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType @@ -19,6 +15,7 @@ ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.tts_model import TTSModel +from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper class XinferenceText2SpeechModel(TTSModel): @@ -26,7 +23,12 @@ class XinferenceText2SpeechModel(TTSModel): def __init__(self): # preset voices, need support custom voice self.model_voices = { - 'chattts': { + '__default': { + 'all': [ + {'name': 'Default', 'value': 'default'}, + ] + }, + 'ChatTTS': { 'all': [ {'name': 'Alloy', 'value': 'alloy'}, {'name': 'Echo', 'value': 'echo'}, @@ -36,7 +38,7 @@ def __init__(self): {'name': 'Shimmer', 'value': 'shimmer'}, ] }, - 'cosyvoice': { + 'CosyVoice': { 'zh-Hans': [ {'name': '中文男', 'value': '中文男'}, {'name': '中文女', 'value': '中文女'}, @@ -77,18 +79,22 @@ def validate_credentials(self, model: str, credentials: dict) -> None: if credentials['server_url'].endswith('/'): credentials['server_url'] = credentials['server_url'][:-1] - # initialize client - client = Client( - base_url=credentials['server_url'] + extra_param = XinferenceHelper.get_xinference_extra_parameter( + server_url=credentials['server_url'], + model_uid=credentials['model_uid'], + api_key=credentials.get('api_key'), ) - xinference_client = client.get_model(model_uid=credentials['model_uid']) - - if not isinstance(xinference_client, RESTfulAudioModelHandle): + if 'text-to-audio' not in extra_param.model_ability: raise InvokeBadRequestError( - 'please check model type, the model you want to invoke is not a audio model') + 'please check model type, the model you want to invoke is not a text-to-audio model') + + if extra_param.model_family and extra_param.model_family in self.model_voices: + credentials['audio_model_name'] = extra_param.model_family + else: + credentials['audio_model_name'] = '__default' - self._tts_invoke( + self._tts_invoke_streaming( model=model, credentials=credentials, content_text='Hello Dify!', @@ -110,7 +116,7 @@ def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: s :param user: unique user id :return: text translated to audio file """ - return self._tts_invoke(model, credentials, content_text, voice) + return self._tts_invoke_streaming(model, credentials, content_text, voice) def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ @@ -161,13 +167,15 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] } def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list: + audio_model_name = credentials.get('audio_model_name', '__default') for key, voices in self.model_voices.items(): - if key in model.lower(): - if language in voices: + if key in audio_model_name: + if language and language in voices: return voices[language] elif 'all' in voices: return voices['all'] - return [] + + return self.model_voices['__default']['all'] def _get_model_default_voice(self, model: str, credentials: dict) -> any: return "" @@ -181,60 +189,59 @@ def _get_model_audio_type(self, model: str, credentials: dict) -> str: def _get_model_workers_limit(self, model: str, credentials: dict) -> int: return 5 - def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, + voice: str) -> any: """ - _tts_invoke text2speech model + _tts_invoke_streaming text2speech model :param model: model name :param credentials: model credentials - :param voice: model timbre :param content_text: text content to be translated + :param voice: model timbre :return: text translated to audio file """ if credentials['server_url'].endswith('/'): credentials['server_url'] = credentials['server_url'][:-1] - word_limit = self._get_model_word_limit(model, credentials) - audio_type = self._get_model_audio_type(model, credentials) - handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={}) - try: - sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit)) - audio_bytes_list = [] + api_key = credentials.get('api_key') + auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + handle = RESTfulAudioModelHandle( + credentials['model_uid'], credentials['server_url'], auth_headers=auth_headers + ) - with concurrent.futures.ThreadPoolExecutor(max_workers=min((3, len(sentences)))) as executor: + model_support_voice = [x.get("value") for x in + self.get_tts_model_voices(model=model, credentials=credentials)] + if not voice or voice not in model_support_voice: + voice = self._get_model_default_voice(model, credentials) + word_limit = self._get_model_word_limit(model, credentials) + if len(content_text) > word_limit: + sentences = self._split_text_into_sentences(content_text, max_length=word_limit) + executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) futures = [executor.submit( - handle.speech, input=sentence, voice=voice, response_format="mp3", speed=1.0, stream=False) - for sentence in sentences] - for future in futures: - try: - if future.result(): - audio_bytes_list.append(future.result()) - except Exception as ex: - raise InvokeBadRequestError(str(ex)) - - if len(audio_bytes_list) > 0: - audio_segments = [AudioSegment.from_file( - BytesIO(audio_bytes), format=audio_type) for audio_bytes in - audio_bytes_list if audio_bytes] - combined_segment = reduce(lambda x, y: x + y, audio_segments) - buffer: BytesIO = BytesIO() - combined_segment.export(buffer, format=audio_type) - buffer.seek(0) - return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}") + handle.speech, + input=sentences[i], + voice=voice, + response_format="mp3", + speed=1.0, + stream=False + ) + for i in range(len(sentences))] + + for index, future in enumerate(futures): + response = future.result() + for i in range(0, len(response), 1024): + yield response[i:i + 1024] + else: + response = handle.speech( + input=content_text.strip(), + voice=voice, + response_format="mp3", + speed=1.0, + stream=False + ) + + for i in range(0, len(response), 1024): + yield response[i:i + 1024] except Exception as ex: raise InvokeBadRequestError(str(ex)) - - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: - """ - _tts_invoke_streaming text2speech model - - Attention: stream api may return error [Parallel generation is not supported by ggml] - - :param model: model name - :param credentials: model credentials - :param voice: model timbre - :param content_text: text content to be translated - :return: text translated to audio file - """ - pass diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index 9a3fc9b193c475..75161ad376c419 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -1,5 +1,6 @@ from threading import Lock from time import time +from typing import Optional from requests.adapters import HTTPAdapter from requests.exceptions import ConnectionError, MissingSchema, Timeout @@ -15,9 +16,11 @@ class XinferenceModelExtraParameter: context_length: int = 2048 support_function_call: bool = False support_vision: bool = False + model_family: Optional[str] def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str], - support_function_call: bool, support_vision: bool, max_tokens: int, context_length: int) -> None: + support_function_call: bool, support_vision: bool, max_tokens: int, context_length: int, + model_family: Optional[str]) -> None: self.model_format = model_format self.model_handle_type = model_handle_type self.model_ability = model_ability @@ -25,19 +28,20 @@ def __init__(self, model_format: str, model_handle_type: str, model_ability: lis self.support_vision = support_vision self.max_tokens = max_tokens self.context_length = context_length + self.model_family = model_family cache = {} cache_lock = Lock() class XinferenceHelper: @staticmethod - def get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter: + def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: XinferenceHelper._clean_cache() with cache_lock: if model_uid not in cache: cache[model_uid] = { 'expires': time() + 300, - 'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid) + 'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid, api_key) } return cache[model_uid]['value'] @@ -52,7 +56,7 @@ def _clean_cache() -> None: pass @staticmethod - def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter: + def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: """ get xinference model extra parameter like model_format and model_handle_type """ @@ -66,9 +70,10 @@ def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> Xinferen session = Session() session.mount('http://', HTTPAdapter(max_retries=3)) session.mount('https://', HTTPAdapter(max_retries=3)) + headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} try: - response = session.get(url, timeout=10) + response = session.get(url, headers=headers, timeout=10) except (MissingSchema, ConnectionError, Timeout) as e: raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}') if response.status_code != 200: @@ -78,9 +83,16 @@ def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> Xinferen model_format = response_json.get('model_format', 'ggmlv3') model_ability = response_json.get('model_ability', []) + model_family = response_json.get('model_family', None) if response_json.get('model_type') == 'embedding': model_handle_type = 'embedding' + elif response_json.get('model_type') == 'audio': + model_handle_type = 'audio' + if model_family and model_family in ['ChatTTS', 'CosyVoice']: + model_ability.append('text-to-audio') + else: + model_ability.append('audio-to-text') elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']: model_handle_type = 'chatglm' elif 'generate' in model_ability: @@ -88,7 +100,7 @@ def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> Xinferen elif 'chat' in model_ability: model_handle_type = 'chat' else: - raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported') + raise NotImplementedError('xinference model handle type is not supported') support_function_call = 'tools' in model_ability support_vision = 'vision' in model_ability @@ -103,5 +115,6 @@ def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> Xinferen support_function_call=support_function_call, support_vision=support_vision, max_tokens=max_tokens, - context_length=context_length - ) \ No newline at end of file + context_length=context_length, + model_family=model_family + ) diff --git a/api/core/model_runtime/model_providers/zhinao/__init__.py b/api/core/model_runtime/model_providers/zhinao/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/zhinao/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/zhinao/_assets/icon_l_en.svg new file mode 100644 index 00000000000000..b22b8694419bc7 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhinao/_assets/icon_l_en.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/api/core/model_runtime/model_providers/zhinao/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/zhinao/_assets/icon_s_en.svg new file mode 100644 index 00000000000000..8fe72b7d0928e6 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhinao/_assets/icon_s_en.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/api/core/model_runtime/model_providers/zhinao/llm/360gpt-turbo-responsibility-8k.yaml b/api/core/model_runtime/model_providers/zhinao/llm/360gpt-turbo-responsibility-8k.yaml new file mode 100644 index 00000000000000..f420df0001b3a2 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhinao/llm/360gpt-turbo-responsibility-8k.yaml @@ -0,0 +1,36 @@ +model: 360gpt-turbo-responsibility-8k +label: + zh_Hans: 360gpt-turbo-responsibility-8k + en_US: 360gpt-turbo-responsibility-8k +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + min: 0 + max: 1 + default: 0.5 + - name: top_p + use_template: top_p + min: 0 + max: 1 + default: 1 + - name: max_tokens + use_template: max_tokens + min: 1 + max: 8192 + default: 1024 + - name: frequency_penalty + use_template: frequency_penalty + min: -2 + max: 2 + default: 0 + - name: presence_penalty + use_template: presence_penalty + min: -2 + max: 2 + default: 0 diff --git a/api/core/model_runtime/model_providers/zhinao/llm/360gpt-turbo.yaml b/api/core/model_runtime/model_providers/zhinao/llm/360gpt-turbo.yaml new file mode 100644 index 00000000000000..a2658fbe4f5c0e --- /dev/null +++ b/api/core/model_runtime/model_providers/zhinao/llm/360gpt-turbo.yaml @@ -0,0 +1,36 @@ +model: 360gpt-turbo +label: + zh_Hans: 360gpt-turbo + en_US: 360gpt-turbo +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 2048 +parameter_rules: + - name: temperature + use_template: temperature + min: 0 + max: 1 + default: 0.5 + - name: top_p + use_template: top_p + min: 0 + max: 1 + default: 1 + - name: max_tokens + use_template: max_tokens + min: 1 + max: 2048 + default: 1024 + - name: frequency_penalty + use_template: frequency_penalty + min: -2 + max: 2 + default: 0 + - name: presence_penalty + use_template: presence_penalty + min: -2 + max: 2 + default: 0 diff --git a/api/core/model_runtime/model_providers/zhinao/llm/360gpt2-pro.yaml b/api/core/model_runtime/model_providers/zhinao/llm/360gpt2-pro.yaml new file mode 100644 index 00000000000000..00c81eb1daaffb --- /dev/null +++ b/api/core/model_runtime/model_providers/zhinao/llm/360gpt2-pro.yaml @@ -0,0 +1,36 @@ +model: 360gpt2-pro +label: + zh_Hans: 360gpt2-pro + en_US: 360gpt2-pro +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 2048 +parameter_rules: + - name: temperature + use_template: temperature + min: 0 + max: 1 + default: 0.5 + - name: top_p + use_template: top_p + min: 0 + max: 1 + default: 1 + - name: max_tokens + use_template: max_tokens + min: 1 + max: 2048 + default: 1024 + - name: frequency_penalty + use_template: frequency_penalty + min: -2 + max: 2 + default: 0 + - name: presence_penalty + use_template: presence_penalty + min: -2 + max: 2 + default: 0 diff --git a/api/core/model_runtime/model_providers/zhinao/llm/__init__.py b/api/core/model_runtime/model_providers/zhinao/llm/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/zhinao/llm/_position.yaml b/api/core/model_runtime/model_providers/zhinao/llm/_position.yaml new file mode 100644 index 00000000000000..ab8dbf51821c0a --- /dev/null +++ b/api/core/model_runtime/model_providers/zhinao/llm/_position.yaml @@ -0,0 +1,3 @@ +- 360gpt2-pro +- 360gpt-turbo +- 360gpt-turbo-responsibility-8k diff --git a/api/core/model_runtime/model_providers/zhinao/llm/llm.py b/api/core/model_runtime/model_providers/zhinao/llm/llm.py new file mode 100644 index 00000000000000..6930a5ed0134b0 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhinao/llm/llm.py @@ -0,0 +1,25 @@ +from collections.abc import Generator +from typing import Optional, Union + +from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel + + +class ZhinaoLargeLanguageModel(OAIAPICompatLargeLanguageModel): + def _invoke(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, + stream: bool = True, user: Optional[str] = None) \ + -> Union[LLMResult, Generator]: + self._add_custom_parameters(credentials) + return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials) + super().validate_credentials(model, credentials) + + @classmethod + def _add_custom_parameters(cls, credentials: dict) -> None: + credentials['mode'] = 'chat' + credentials['endpoint_url'] = 'https://api.360.cn/v1' diff --git a/api/core/model_runtime/model_providers/zhinao/zhinao.py b/api/core/model_runtime/model_providers/zhinao/zhinao.py new file mode 100644 index 00000000000000..44b36c9f51edd7 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhinao/zhinao.py @@ -0,0 +1,32 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class ZhinaoProvider(ModelProvider): + + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + model_instance = self.get_model_instance(ModelType.LLM) + + # Use `360gpt-turbo` model for validate, + # no matter what model you pass in, text completion model or chat model + model_instance.validate_credentials( + model='360gpt-turbo', + credentials=credentials + ) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + raise ex diff --git a/api/core/model_runtime/model_providers/zhinao/zhinao.yaml b/api/core/model_runtime/model_providers/zhinao/zhinao.yaml new file mode 100644 index 00000000000000..c5cb142c47d4c3 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhinao/zhinao.yaml @@ -0,0 +1,32 @@ +provider: zhinao +label: + en_US: 360 AI + zh_Hans: 360 智脑 +description: + en_US: Models provided by 360 AI. + zh_Hans: 360 智脑提供的模型。 +icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.svg +background: "#e3f0ff" +help: + title: + en_US: Get your API Key from 360 AI. + zh_Hans: 从360 智脑获取 API Key + url: + en_US: https://ai.360.com/platform/keys +supported_model_types: + - llm +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-0520.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-0520.yaml index 3968e8f26822ab..8391278e4f1ea3 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-0520.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-0520.yaml @@ -37,3 +37,8 @@ parameter_rules: default: 1024 min: 1 max: 8192 +pricing: + input: '0.1' + output: '0.1' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-air.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-air.yaml index ae2d5e5d533d1a..7caebd3e4b6aa8 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-air.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-air.yaml @@ -37,3 +37,8 @@ parameter_rules: default: 1024 min: 1 max: 8192 +pricing: + input: '0.001' + output: '0.001' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-airx.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-airx.yaml index c0038a1ab223df..dc123913deb8b5 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-airx.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-airx.yaml @@ -37,3 +37,8 @@ parameter_rules: default: 1024 min: 1 max: 8192 +pricing: + input: '0.01' + output: '0.01' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flash.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flash.yaml index 650f9faee6ea0d..0e3c001f06b0f4 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flash.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flash.yaml @@ -37,3 +37,8 @@ parameter_rules: default: 1024 min: 1 max: 8192 +pricing: + input: '0' + output: '0' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml index 5bdb4428403908..b0f95c0a68e555 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml @@ -37,3 +37,8 @@ parameter_rules: default: 1024 min: 1 max: 8192 +pricing: + input: '0.001' + output: '0.001' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml index 6b5bcc5bcf4468..271eecf199c476 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml @@ -37,3 +37,8 @@ parameter_rules: default: 1024 min: 1 max: 8192 +pricing: + input: '0.1' + output: '0.1' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_long.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_long.yaml new file mode 100644 index 00000000000000..150e07b60af979 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_long.yaml @@ -0,0 +1,38 @@ +model: glm-4-long +label: + en_US: glm-4-long +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat + context_size: 10240 +parameter_rules: + - name: temperature + use_template: temperature + default: 0.95 + min: 0.0 + max: 1.0 + help: + zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 + en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. + - name: top_p + use_template: top_p + default: 0.7 + min: 0.0 + max: 1.0 + help: + zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 + en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 8192 +pricing: + input: '0.001' + output: '0.001' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_plus.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_plus.yaml new file mode 100644 index 00000000000000..237a951cd5a14e --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_plus.yaml @@ -0,0 +1,44 @@ +model: glm-4-plus +label: + en_US: glm-4-plus +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat +parameter_rules: + - name: temperature + use_template: temperature + default: 0.95 + min: 0.0 + max: 1.0 + help: + zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 + en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. + - name: top_p + use_template: top_p + default: 0.7 + help: + zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 + en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. + - name: incremental + label: + zh_Hans: 增量返回 + en_US: Incremental + type: boolean + help: + zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 + en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. + required: false + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 8192 +pricing: + input: '0.05' + output: '0.05' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml index ddea331c8e46f3..c7a4093d7aa7a2 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml @@ -34,4 +34,9 @@ parameter_rules: use_template: max_tokens default: 1024 min: 1 - max: 8192 + max: 1024 +pricing: + input: '0.05' + output: '0.05' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_plus.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_plus.yaml new file mode 100644 index 00000000000000..a7aee5b4ca2363 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_plus.yaml @@ -0,0 +1,42 @@ +model: glm-4v-plus +label: + en_US: glm-4v-plus +model_type: llm +model_properties: + mode: chat +features: + - vision +parameter_rules: + - name: temperature + use_template: temperature + default: 0.95 + min: 0.0 + max: 1.0 + help: + zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 + en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. + - name: top_p + use_template: top_p + default: 0.7 + help: + zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 + en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. + - name: incremental + label: + zh_Hans: 增量返回 + en_US: Incremental + type: boolean + help: + zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 + en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. + required: false + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 1024 +pricing: + input: '0.01' + output: '0.01' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index ff971964a8603e..b2cdc7ad7a4644 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -153,7 +153,8 @@ def _generate(self, model: str, credentials_kwargs: dict, :return: full response or stream response chunk generator result """ extra_model_kwargs = {} - if stop: + # request to glm-4v-plus with stop words will always response "finish_reason":"network_error" + if stop and model!= 'glm-4v-plus': extra_model_kwargs['stop'] = stop client = ZhipuAI( @@ -174,7 +175,7 @@ def _generate(self, model: str, credentials_kwargs: dict, if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]: if isinstance(copy_prompt_message.content, list): # check if model is 'glm-4v' - if model != 'glm-4v': + if model not in ('glm-4v', 'glm-4v-plus'): # not support list message continue # get image and @@ -207,7 +208,7 @@ def _generate(self, model: str, credentials_kwargs: dict, else: new_prompt_messages.append(copy_prompt_message) - if model == 'glm-4v': + if model == 'glm-4v' or model == 'glm-4v-plus': params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters) else: params = { @@ -304,7 +305,7 @@ def _construct_glm_4v_parameter(self, model: str, prompt_messages: list[PromptMe return params - def _construct_glm_4v_messages(self, prompt_message: Union[str | list[PromptMessageContent]]) -> list[dict]: + def _construct_glm_4v_messages(self, prompt_message: Union[str, list[PromptMessageContent]]) -> list[dict]: if isinstance(prompt_message, str): return [{'type': 'text', 'text': prompt_message}] @@ -444,6 +445,7 @@ def _handle_generate_stream_response(self, model: str, delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, + finish_reason=delta.finish_reason ) ) diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/embedding-2.yaml b/api/core/model_runtime/model_providers/zhipuai/text_embedding/embedding-2.yaml index faf0f818c4a782..f1b8b356028e25 100644 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/embedding-2.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/embedding-2.yaml @@ -1,4 +1,8 @@ model: embedding-2 model_type: text-embedding model_properties: - context_size: 512 + context_size: 8192 +pricing: + input: '0.0005' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/embedding-3.yaml b/api/core/model_runtime/model_providers/zhipuai/text_embedding/embedding-3.yaml new file mode 100644 index 00000000000000..5c55c911c4bdd1 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/embedding-3.yaml @@ -0,0 +1,8 @@ +model: embedding-3 +model_type: text-embedding +model_properties: + context_size: 8192 +pricing: + input: '0.0005' + unit: '0.001' + currency: RMB diff --git a/api/core/moderation/input_moderation.py b/api/core/moderation/input_moderation.py index c5dd88fb2458b1..8157b300b1f6c7 100644 --- a/api/core/moderation/input_moderation.py +++ b/api/core/moderation/input_moderation.py @@ -4,7 +4,8 @@ from core.app.app_config.entities import AppConfig from core.moderation.base import ModerationAction, ModerationException from core.moderation.factory import ModerationFactory -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName +from core.ops.entities.trace_entity import TraceTaskName +from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time logger = logging.getLogger(__name__) diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index db7e0806ee8d74..a1443f0691233b 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -1,4 +1,5 @@ from datetime import datetime +from enum import Enum from typing import Any, Optional, Union from pydantic import BaseModel, ConfigDict, field_validator @@ -105,4 +106,15 @@ class GenerateNameTraceInfo(BaseTraceInfo): 'DatasetRetrievalTraceInfo': DatasetRetrievalTraceInfo, 'ToolTraceInfo': ToolTraceInfo, 'GenerateNameTraceInfo': GenerateNameTraceInfo, -} \ No newline at end of file +} + + +class TraceTaskName(str, Enum): + CONVERSATION_TRACE = 'conversation' + WORKFLOW_TRACE = 'workflow' + MESSAGE_TRACE = 'message' + MODERATION_TRACE = 'moderation' + SUGGESTED_QUESTION_TRACE = 'suggested_question' + DATASET_RETRIEVAL_TRACE = 'dataset_retrieval' + TOOL_TRACE = 'tool' + GENERATE_NAME_TRACE = 'generate_conversation_name' diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py index b90c05f4cbc605..af7661f0afc9c8 100644 --- a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py +++ b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py @@ -50,10 +50,11 @@ class LangfuseTrace(BaseModel): """ Langfuse trace model """ + id: Optional[str] = Field( default=None, description="The id of the trace can be set, defaults to a random id. Used to link traces to external systems " - "or when creating a distributed trace. Traces are upserted on id.", + "or when creating a distributed trace. Traces are upserted on id.", ) name: Optional[str] = Field( default=None, @@ -68,7 +69,7 @@ class LangfuseTrace(BaseModel): metadata: Optional[dict[str, Any]] = Field( default=None, description="Additional metadata of the trace. Can be any JSON object. Metadata is merged when being updated " - "via the API.", + "via the API.", ) user_id: Optional[str] = Field( default=None, @@ -81,22 +82,22 @@ class LangfuseTrace(BaseModel): version: Optional[str] = Field( default=None, description="The version of the trace type. Used to understand how changes to the trace type affect metrics. " - "Useful in debugging.", + "Useful in debugging.", ) release: Optional[str] = Field( default=None, description="The release identifier of the current deployment. Used to understand how changes of different " - "deployments affect metrics. Useful in debugging.", + "deployments affect metrics. Useful in debugging.", ) tags: Optional[list[str]] = Field( default=None, description="Tags are used to categorize or label traces. Traces can be filtered by tags in the UI and GET " - "API. Tags can also be changed in the UI. Tags are merged and never deleted via the API.", + "API. Tags can also be changed in the UI. Tags are merged and never deleted via the API.", ) public: Optional[bool] = Field( default=None, description="You can make a trace public to share it via a public link. This allows others to view the trace " - "without needing to log in or be members of your Langfuse project.", + "without needing to log in or be members of your Langfuse project.", ) @field_validator("input", "output") @@ -109,6 +110,7 @@ class LangfuseSpan(BaseModel): """ Langfuse span model """ + id: Optional[str] = Field( default=None, description="The id of the span can be set, otherwise a random id is generated. Spans are upserted on id.", @@ -140,17 +142,17 @@ class LangfuseSpan(BaseModel): metadata: Optional[dict[str, Any]] = Field( default=None, description="Additional metadata of the span. Can be any JSON object. Metadata is merged when being updated " - "via the API.", + "via the API.", ) level: Optional[str] = Field( default=None, description="The level of the span. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering of " - "traces with elevated error levels and for highlighting in the UI.", + "traces with elevated error levels and for highlighting in the UI.", ) status_message: Optional[str] = Field( default=None, description="The status message of the span. Additional field for context of the event. E.g. the error " - "message of an error event.", + "message of an error event.", ) input: Optional[Union[str, dict[str, Any], list, None]] = Field( default=None, description="The input of the span. Can be any JSON object." @@ -161,7 +163,7 @@ class LangfuseSpan(BaseModel): version: Optional[str] = Field( default=None, description="The version of the span type. Used to understand how changes to the span type affect metrics. " - "Useful in debugging.", + "Useful in debugging.", ) parent_observation_id: Optional[str] = Field( default=None, @@ -185,10 +187,9 @@ class UnitEnum(str, Enum): class GenerationUsage(BaseModel): promptTokens: Optional[int] = None completionTokens: Optional[int] = None - totalTokens: Optional[int] = None + total: Optional[int] = None input: Optional[int] = None output: Optional[int] = None - total: Optional[int] = None unit: Optional[UnitEnum] = None inputCost: Optional[float] = None outputCost: Optional[float] = None @@ -224,15 +225,13 @@ class LangfuseGeneration(BaseModel): completion_start_time: Optional[datetime | str] = Field( default=None, description="The time at which the completion started (streaming). Set it to get latency analytics broken " - "down into time until completion started and completion duration.", + "down into time until completion started and completion duration.", ) end_time: Optional[datetime | str] = Field( default=None, description="The time at which the generation ended. Automatically set by generation.end().", ) - model: Optional[str] = Field( - default=None, description="The name of the model used for the generation." - ) + model: Optional[str] = Field(default=None, description="The name of the model used for the generation.") model_parameters: Optional[dict[str, Any]] = Field( default=None, description="The parameters of the model used for the generation; can be any key-value pairs.", @@ -248,27 +247,27 @@ class LangfuseGeneration(BaseModel): usage: Optional[GenerationUsage] = Field( default=None, description="The usage object supports the OpenAi structure with tokens and a more generic version with " - "detailed costs and units.", + "detailed costs and units.", ) metadata: Optional[dict[str, Any]] = Field( default=None, description="Additional metadata of the generation. Can be any JSON object. Metadata is merged when being " - "updated via the API.", + "updated via the API.", ) level: Optional[LevelEnum] = Field( default=None, description="The level of the generation. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering " - "of traces with elevated error levels and for highlighting in the UI.", + "of traces with elevated error levels and for highlighting in the UI.", ) status_message: Optional[str] = Field( default=None, description="The status message of the generation. Additional field for context of the event. E.g. the error " - "message of an error event.", + "message of an error event.", ) version: Optional[str] = Field( default=None, description="The version of the generation type. Used to understand how changes to the span type affect " - "metrics. Useful in debugging.", + "metrics. Useful in debugging.", ) model_config = ConfigDict(protected_namespaces=()) @@ -277,4 +276,3 @@ class LangfuseGeneration(BaseModel): def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name return validate_input_output(v, field_name) - diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index c520fe2aa9c089..a21c67ed503da6 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -16,6 +16,7 @@ ModerationTraceInfo, SuggestedQuestionTraceInfo, ToolTraceInfo, + TraceTaskName, WorkflowTraceInfo, ) from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( @@ -68,9 +69,9 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): user_id = trace_info.metadata.get("user_id") if trace_info.message_id: trace_id = trace_info.message_id - name = f"message_{trace_info.message_id}" + name = TraceTaskName.MESSAGE_TRACE.value trace_data = LangfuseTrace( - id=trace_info.message_id, + id=trace_id, user_id=user_id, name=name, input=trace_info.workflow_run_inputs, @@ -78,11 +79,13 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): metadata=trace_info.metadata, session_id=trace_info.conversation_id, tags=["message", "workflow"], + created_at=trace_info.start_time, + updated_at=trace_info.end_time, ) self.add_trace(langfuse_trace_data=trace_data) workflow_span_data = LangfuseSpan( - id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id, - name=f"workflow_{trace_info.workflow_app_log_id}" if trace_info.workflow_app_log_id else f"workflow_{trace_info.workflow_run_id}", + id=(trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id), + name=TraceTaskName.WORKFLOW_TRACE.value, input=trace_info.workflow_run_inputs, output=trace_info.workflow_run_outputs, trace_id=trace_id, @@ -97,7 +100,7 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): trace_data = LangfuseTrace( id=trace_id, user_id=user_id, - name=f"workflow_{trace_info.workflow_app_log_id}" if trace_info.workflow_app_log_id else f"workflow_{trace_info.workflow_run_id}", + name=TraceTaskName.WORKFLOW_TRACE.value, input=trace_info.workflow_run_inputs, output=trace_info.workflow_run_outputs, metadata=trace_info.metadata, @@ -134,14 +137,12 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): node_type = node_execution.node_type status = node_execution.status if node_type == "llm": - inputs = json.loads(node_execution.process_data).get( - "prompts", {} - ) if node_execution.process_data else {} + inputs = ( + json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {} + ) else: inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} - outputs = ( - json.loads(node_execution.outputs) if node_execution.outputs else {} - ) + outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} created_at = node_execution.created_at if node_execution.created_at else datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) @@ -163,28 +164,30 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): if trace_info.message_id: span_data = LangfuseSpan( id=node_execution_id, - name=f"{node_name}_{node_execution_id}", + name=node_type, input=inputs, output=outputs, trace_id=trace_id, start_time=created_at, end_time=finished_at, metadata=metadata, - level=LevelEnum.DEFAULT if status == 'succeeded' else LevelEnum.ERROR, + level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR), status_message=trace_info.error if trace_info.error else "", - parent_observation_id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id, + parent_observation_id=( + trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id + ), ) else: span_data = LangfuseSpan( id=node_execution_id, - name=f"{node_name}_{node_execution_id}", + name=node_type, input=inputs, output=outputs, trace_id=trace_id, start_time=created_at, end_time=finished_at, metadata=metadata, - level=LevelEnum.DEFAULT if status == 'succeeded' else LevelEnum.ERROR, + level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR), status_message=trace_info.error if trace_info.error else "", ) @@ -195,11 +198,11 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): total_token = metadata.get("total_tokens", 0) # add generation generation_usage = GenerationUsage( - totalTokens=total_token, + total=total_token, ) node_generation_data = LangfuseGeneration( - name=f"generation_{node_execution_id}", + name="llm", trace_id=trace_id, parent_observation_id=node_execution_id, start_time=created_at, @@ -207,16 +210,14 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): input=inputs, output=outputs, metadata=metadata, - level=LevelEnum.DEFAULT if status == 'succeeded' else LevelEnum.ERROR, + level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR), status_message=trace_info.error if trace_info.error else "", usage=generation_usage, ) self.add_generation(langfuse_generation_data=node_generation_data) - def message_trace( - self, trace_info: MessageTraceInfo, **kwargs - ): + def message_trace(self, trace_info: MessageTraceInfo, **kwargs): # get message file data file_list = trace_info.file_list metadata = trace_info.metadata @@ -225,9 +226,9 @@ def message_trace( user_id = message_data.from_account_id if message_data.from_end_user_id: - end_user_data: EndUser = db.session.query(EndUser).filter( - EndUser.id == message_data.from_end_user_id - ).first() + end_user_data: EndUser = ( + db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() + ) if end_user_data is not None: user_id = end_user_data.session_id metadata["user_id"] = user_id @@ -235,7 +236,7 @@ def message_trace( trace_data = LangfuseTrace( id=message_id, user_id=user_id, - name=f"message_{message_id}", + name=TraceTaskName.MESSAGE_TRACE.value, input={ "message": trace_info.inputs, "files": file_list, @@ -258,7 +259,6 @@ def message_trace( # start add span generation_usage = GenerationUsage( - totalTokens=trace_info.total_tokens, input=trace_info.message_tokens, output=trace_info.answer_tokens, total=trace_info.total_tokens, @@ -267,7 +267,7 @@ def message_trace( ) langfuse_generation_data = LangfuseGeneration( - name=f"generation_{message_id}", + name="llm", trace_id=message_id, start_time=trace_info.start_time, end_time=trace_info.end_time, @@ -275,7 +275,7 @@ def message_trace( input=trace_info.inputs, output=message_data.answer, metadata=metadata, - level=LevelEnum.DEFAULT if message_data.status != 'error' else LevelEnum.ERROR, + level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR), status_message=message_data.error if message_data.error else "", usage=generation_usage, ) @@ -284,7 +284,7 @@ def message_trace( def moderation_trace(self, trace_info: ModerationTraceInfo): span_data = LangfuseSpan( - name="moderation", + name=TraceTaskName.MODERATION_TRACE.value, input=trace_info.inputs, output={ "action": trace_info.action, @@ -303,22 +303,21 @@ def moderation_trace(self, trace_info: ModerationTraceInfo): def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo): message_data = trace_info.message_data generation_usage = GenerationUsage( - totalTokens=len(str(trace_info.suggested_question)), + total=len(str(trace_info.suggested_question)), input=len(trace_info.inputs), output=len(trace_info.suggested_question), - total=len(trace_info.suggested_question), unit=UnitEnum.CHARACTERS, ) generation_data = LangfuseGeneration( - name="suggested_question", + name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value, input=trace_info.inputs, output=str(trace_info.suggested_question), trace_id=trace_info.message_id, start_time=trace_info.start_time, end_time=trace_info.end_time, metadata=trace_info.metadata, - level=LevelEnum.DEFAULT if message_data.status != 'error' else LevelEnum.ERROR, + level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR), status_message=message_data.error if message_data.error else "", usage=generation_usage, ) @@ -327,7 +326,7 @@ def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo): def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo): dataset_retrieval_span_data = LangfuseSpan( - name="dataset_retrieval", + name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value, input=trace_info.inputs, output={"documents": trace_info.documents}, trace_id=trace_info.message_id, @@ -347,7 +346,7 @@ def tool_trace(self, trace_info: ToolTraceInfo): start_time=trace_info.start_time, end_time=trace_info.end_time, metadata=trace_info.metadata, - level=LevelEnum.DEFAULT if trace_info.error == "" or trace_info.error is None else LevelEnum.ERROR, + level=(LevelEnum.DEFAULT if trace_info.error == "" or trace_info.error is None else LevelEnum.ERROR), status_message=trace_info.error, ) @@ -355,7 +354,7 @@ def tool_trace(self, trace_info: ToolTraceInfo): def generate_name_trace(self, trace_info: GenerateNameTraceInfo): name_generation_trace_data = LangfuseTrace( - name="generate_name", + name=TraceTaskName.GENERATE_NAME_TRACE.value, input=trace_info.inputs, output=trace_info.outputs, user_id=trace_info.tenant_id, @@ -366,7 +365,7 @@ def generate_name_trace(self, trace_info: GenerateNameTraceInfo): self.add_trace(langfuse_trace_data=name_generation_trace_data) name_generation_span_data = LangfuseSpan( - name="generate_name", + name=TraceTaskName.GENERATE_NAME_TRACE.value, input=trace_info.inputs, output=trace_info.outputs, trace_id=trace_info.conversation_id, @@ -377,9 +376,7 @@ def generate_name_trace(self, trace_info: GenerateNameTraceInfo): self.add_span(langfuse_span_data=name_generation_span_data) def add_trace(self, langfuse_trace_data: Optional[LangfuseTrace] = None): - format_trace_data = ( - filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {} - ) + format_trace_data = filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {} try: self.langfuse_client.trace(**format_trace_data) logger.debug("LangFuse Trace created successfully") @@ -387,9 +384,7 @@ def add_trace(self, langfuse_trace_data: Optional[LangfuseTrace] = None): raise ValueError(f"LangFuse Failed to create trace: {str(e)}") def add_span(self, langfuse_span_data: Optional[LangfuseSpan] = None): - format_span_data = ( - filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {} - ) + format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {} try: self.langfuse_client.span(**format_span_data) logger.debug("LangFuse Span created successfully") @@ -397,19 +392,13 @@ def add_span(self, langfuse_span_data: Optional[LangfuseSpan] = None): raise ValueError(f"LangFuse Failed to create span: {str(e)}") def update_span(self, span, langfuse_span_data: Optional[LangfuseSpan] = None): - format_span_data = ( - filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {} - ) + format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {} span.end(**format_span_data) - def add_generation( - self, langfuse_generation_data: Optional[LangfuseGeneration] = None - ): + def add_generation(self, langfuse_generation_data: Optional[LangfuseGeneration] = None): format_generation_data = ( - filter_none_values(langfuse_generation_data.model_dump()) - if langfuse_generation_data - else {} + filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {} ) try: self.langfuse_client.generation(**format_generation_data) @@ -417,13 +406,9 @@ def add_generation( except Exception as e: raise ValueError(f"LangFuse Failed to create generation: {str(e)}") - def update_generation( - self, generation, langfuse_generation_data: Optional[LangfuseGeneration] = None - ): + def update_generation(self, generation, langfuse_generation_data: Optional[LangfuseGeneration] = None): format_generation_data = ( - filter_none_values(langfuse_generation_data.model_dump()) - if langfuse_generation_data - else {} + filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {} ) generation.end(**format_generation_data) @@ -434,3 +419,11 @@ def api_check(self): except Exception as e: logger.debug(f"LangFuse API check failed: {str(e)}") raise ValueError(f"LangFuse API check failed: {str(e)}") + + def get_project_key(self): + try: + projects = self.langfuse_client.client.projects.get() + return projects.data[0].id + except Exception as e: + logger.debug(f"LangFuse get project key failed: {str(e)}") + raise ValueError(f"LangFuse get project key failed: {str(e)}") diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 0ce91db335cd94..fde8a06c612dd9 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -15,6 +15,7 @@ ModerationTraceInfo, SuggestedQuestionTraceInfo, ToolTraceInfo, + TraceTaskName, WorkflowTraceInfo, ) from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( @@ -39,9 +40,7 @@ def __init__( self.langsmith_key = langsmith_config.api_key self.project_name = langsmith_config.project self.project_id = None - self.langsmith_client = Client( - api_key=langsmith_config.api_key, api_url=langsmith_config.endpoint - ) + self.langsmith_client = Client(api_key=langsmith_config.api_key, api_url=langsmith_config.endpoint) self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") def trace(self, trace_info: BaseTraceInfo): @@ -64,7 +63,7 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): if trace_info.message_id: message_run = LangSmithRunModel( id=trace_info.message_id, - name=f"message_{trace_info.message_id}", + name=TraceTaskName.MESSAGE_TRACE.value, inputs=trace_info.workflow_run_inputs, outputs=trace_info.workflow_run_outputs, run_type=LangSmithRunType.chain, @@ -73,8 +72,8 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): extra={ "metadata": trace_info.metadata, }, - tags=["message"], - error=trace_info.error + tags=["message", "workflow"], + error=trace_info.error, ) self.add_run(message_run) @@ -82,7 +81,7 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): file_list=trace_info.file_list, total_tokens=trace_info.total_tokens, id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id, - name=f"workflow_{trace_info.workflow_app_log_id}" if trace_info.workflow_app_log_id else f"workflow_{trace_info.workflow_run_id}", + name=TraceTaskName.WORKFLOW_TRACE.value, inputs=trace_info.workflow_run_inputs, run_type=LangSmithRunType.tool, start_time=trace_info.workflow_data.created_at, @@ -126,22 +125,18 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): node_type = node_execution.node_type status = node_execution.status if node_type == "llm": - inputs = json.loads(node_execution.process_data).get( - "prompts", {} - ) if node_execution.process_data else {} + inputs = ( + json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {} + ) else: inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} - outputs = ( - json.loads(node_execution.outputs) if node_execution.outputs else {} - ) + outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} created_at = node_execution.created_at if node_execution.created_at else datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) execution_metadata = ( - json.loads(node_execution.execution_metadata) - if node_execution.execution_metadata - else {} + json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} ) node_total_tokens = execution_metadata.get("total_tokens", 0) @@ -168,7 +163,7 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): langsmith_run = LangSmithRunModel( total_tokens=node_total_tokens, - name=f"{node_name}_{node_execution_id}", + name=node_type, inputs=inputs, run_type=run_type, start_time=created_at, @@ -178,7 +173,9 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): extra={ "metadata": metadata, }, - parent_run_id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id, + parent_run_id=trace_info.workflow_app_log_id + if trace_info.workflow_app_log_id + else trace_info.workflow_run_id, tags=["node_execution"], ) @@ -198,9 +195,9 @@ def message_trace(self, trace_info: MessageTraceInfo): metadata["user_id"] = user_id if message_data.from_end_user_id: - end_user_data: EndUser = db.session.query(EndUser).filter( - EndUser.id == message_data.from_end_user_id - ).first() + end_user_data: EndUser = ( + db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() + ) if end_user_data is not None: end_user_id = end_user_data.session_id metadata["end_user_id"] = end_user_id @@ -210,7 +207,7 @@ def message_trace(self, trace_info: MessageTraceInfo): output_tokens=trace_info.answer_tokens, total_tokens=trace_info.total_tokens, id=message_id, - name=f"message_{message_id}", + name=TraceTaskName.MESSAGE_TRACE.value, inputs=trace_info.inputs, run_type=LangSmithRunType.chain, start_time=trace_info.start_time, @@ -230,7 +227,7 @@ def message_trace(self, trace_info: MessageTraceInfo): input_tokens=trace_info.message_tokens, output_tokens=trace_info.answer_tokens, total_tokens=trace_info.total_tokens, - name=f"llm_{message_id}", + name="llm", inputs=trace_info.inputs, run_type=LangSmithRunType.llm, start_time=trace_info.start_time, @@ -248,7 +245,7 @@ def message_trace(self, trace_info: MessageTraceInfo): def moderation_trace(self, trace_info: ModerationTraceInfo): langsmith_run = LangSmithRunModel( - name="moderation", + name=TraceTaskName.MODERATION_TRACE.value, inputs=trace_info.inputs, outputs={ "action": trace_info.action, @@ -271,7 +268,7 @@ def moderation_trace(self, trace_info: ModerationTraceInfo): def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo): message_data = trace_info.message_data suggested_question_run = LangSmithRunModel( - name="suggested_question", + name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value, inputs=trace_info.inputs, outputs=trace_info.suggested_question, run_type=LangSmithRunType.tool, @@ -288,7 +285,7 @@ def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo): def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo): dataset_retrieval_run = LangSmithRunModel( - name="dataset_retrieval", + name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value, inputs=trace_info.inputs, outputs={"documents": trace_info.documents}, run_type=LangSmithRunType.retriever, @@ -323,7 +320,7 @@ def tool_trace(self, trace_info: ToolTraceInfo): def generate_name_trace(self, trace_info: GenerateNameTraceInfo): name_run = LangSmithRunModel( - name="generate_name", + name=TraceTaskName.GENERATE_NAME_TRACE.value, inputs=trace_info.inputs, outputs=trace_info.outputs, run_type=LangSmithRunType.tool, diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 61279e3f5f29c5..1416d6bd2d6685 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -5,7 +5,6 @@ import threading import time from datetime import timedelta -from enum import Enum from typing import Any, Optional, Union from uuid import UUID @@ -24,6 +23,7 @@ ModerationTraceInfo, SuggestedQuestionTraceInfo, ToolTraceInfo, + TraceTaskName, WorkflowTraceInfo, ) from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace @@ -38,7 +38,7 @@ TracingProviderEnum.LANGFUSE.value: { 'config_class': LangfuseConfig, 'secret_keys': ['public_key', 'secret_key'], - 'other_keys': ['host'], + 'other_keys': ['host', 'project_key'], 'trace_instance': LangFuseDataTrace }, TracingProviderEnum.LANGSMITH.value: { @@ -123,7 +123,6 @@ def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: for key in other_keys: new_config[key] = decrypt_tracing_config.get(key, "") - return config_class(**new_config).model_dump() @classmethod @@ -252,16 +251,18 @@ def check_trace_config_is_effective(tracing_config: dict, tracing_provider: str) tracing_config = config_type(**tracing_config) return trace_instance(tracing_config).api_check() - -class TraceTaskName(str, Enum): - CONVERSATION_TRACE = 'conversation_trace' - WORKFLOW_TRACE = 'workflow_trace' - MESSAGE_TRACE = 'message_trace' - MODERATION_TRACE = 'moderation_trace' - SUGGESTED_QUESTION_TRACE = 'suggested_question_trace' - DATASET_RETRIEVAL_TRACE = 'dataset_retrieval_trace' - TOOL_TRACE = 'tool_trace' - GENERATE_NAME_TRACE = 'generate_name_trace' + @staticmethod + def get_trace_config_project_key(tracing_config: dict, tracing_provider: str): + """ + get trace config is project key + :param tracing_config: tracing config + :param tracing_provider: tracing provider + :return: + """ + config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \ + provider_config_map[tracing_provider]['trace_instance'] + tracing_config = config_type(**tracing_config) + return trace_instance(tracing_config).get_project_key() class TraceTask: diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 452b270348b2ff..fd7ed0181be2f2 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -1,11 +1,10 @@ import enum import json import os -from typing import Optional +from typing import TYPE_CHECKING, Optional from core.app.app_config.entities import PromptTemplateEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.file.file_obj import FileVar from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import ( PromptMessage, @@ -18,6 +17,9 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import AppMode +if TYPE_CHECKING: + from core.file.file_obj import FileVar + class ModelMode(enum.Enum): COMPLETION = 'completion' @@ -50,7 +52,7 @@ def get_prompt(self, prompt_template_entity: PromptTemplateEntity, inputs: dict, query: str, - files: list[FileVar], + files: list["FileVar"], context: Optional[str], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) -> \ @@ -163,7 +165,7 @@ def _get_chat_model_prompt_messages(self, app_mode: AppMode, inputs: dict, query: str, context: Optional[str], - files: list[FileVar], + files: list["FileVar"], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) \ -> tuple[list[PromptMessage], Optional[list[str]]]: @@ -206,7 +208,7 @@ def _get_completion_model_prompt_messages(self, app_mode: AppMode, inputs: dict, query: str, context: Optional[str], - files: list[FileVar], + files: list["FileVar"], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) \ -> tuple[list[PromptMessage], Optional[list[str]]]: @@ -255,7 +257,7 @@ def _get_completion_model_prompt_messages(self, app_mode: AppMode, return [self.get_last_user_message(prompt, files)], stops - def get_last_user_message(self, prompt: str, files: list[FileVar]) -> UserPromptMessage: + def get_last_user_message(self, prompt: str, files: list["FileVar"]) -> UserPromptMessage: if files: prompt_message_contents = [TextPromptMessageContent(data=prompt)] for file in files: diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index c0b3746e188f5e..67eee2c2944e7a 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -5,6 +5,7 @@ from sqlalchemy.exc import IntegrityError +from configs import dify_config from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle from core.entities.provider_entities import ( @@ -18,12 +19,9 @@ ) from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType +from core.helper.position_helper import is_filtered from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.provider_entities import ( - CredentialFormSchema, - FormType, - ProviderEntity, -) +from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity from core.model_runtime.model_providers import model_provider_factory from extensions import ext_hosting_provider from extensions.ext_database import db @@ -45,6 +43,7 @@ class ProviderManager: """ ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers. """ + def __init__(self) -> None: self.decoding_rsa_key = None self.decoding_cipher_rsa = None @@ -117,6 +116,16 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations: # Construct ProviderConfiguration objects for each provider for provider_entity in provider_entities: + + # handle include, exclude + if is_filtered( + include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, + exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, + data=provider_entity, + name_func=lambda x: x.provider, + ): + continue + provider_name = provider_entity.provider provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, []) provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, []) @@ -271,6 +280,24 @@ def get_default_model(self, tenant_id: str, model_type: ModelType) -> Optional[D ) ) + def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]: + """ + Get names of first model and its provider + + :param tenant_id: workspace id + :param model_type: model type + :return: provider name, model name + """ + provider_configurations = self.get_configurations(tenant_id) + + # get available models from provider_configurations + all_models = provider_configurations.get_models( + model_type=model_type, + only_active=False + ) + + return all_models[0].provider.provider, all_models[0].model + def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \ -> TenantDefaultModel: """ @@ -323,7 +350,8 @@ def update_default_model_record(self, tenant_id: str, model_type: ModelType, pro return default_model - def _get_all_providers(self, tenant_id: str) -> dict[str, list[Provider]]: + @staticmethod + def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]: """ Get all provider records of the workspace. @@ -342,7 +370,8 @@ def _get_all_providers(self, tenant_id: str) -> dict[str, list[Provider]]: return provider_name_to_provider_records_dict - def _get_all_provider_models(self, tenant_id: str) -> dict[str, list[ProviderModel]]: + @staticmethod + def _get_all_provider_models(tenant_id: str) -> dict[str, list[ProviderModel]]: """ Get all provider model records of the workspace. @@ -362,7 +391,8 @@ def _get_all_provider_models(self, tenant_id: str) -> dict[str, list[ProviderMod return provider_name_to_provider_model_records_dict - def _get_all_preferred_model_providers(self, tenant_id: str) -> dict[str, TenantPreferredModelProvider]: + @staticmethod + def _get_all_preferred_model_providers(tenant_id: str) -> dict[str, TenantPreferredModelProvider]: """ Get All preferred provider types of the workspace. @@ -381,7 +411,8 @@ def _get_all_preferred_model_providers(self, tenant_id: str) -> dict[str, Tenant return provider_name_to_preferred_provider_type_records_dict - def _get_all_provider_model_settings(self, tenant_id: str) -> dict[str, list[ProviderModelSetting]]: + @staticmethod + def _get_all_provider_model_settings(tenant_id: str) -> dict[str, list[ProviderModelSetting]]: """ Get All provider model settings of the workspace. @@ -400,7 +431,8 @@ def _get_all_provider_model_settings(self, tenant_id: str) -> dict[str, list[Pro return provider_name_to_provider_model_settings_dict - def _get_all_provider_load_balancing_configs(self, tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]: + @staticmethod + def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]: """ Get All provider load balancing configs of the workspace. @@ -431,7 +463,8 @@ def _get_all_provider_load_balancing_configs(self, tenant_id: str) -> dict[str, return provider_name_to_provider_load_balancing_model_configs_dict - def _init_trial_provider_records(self, tenant_id: str, + @staticmethod + def _init_trial_provider_records(tenant_id: str, provider_name_to_provider_records_dict: dict[str, list]) -> dict[str, list]: """ Initialize trial provider records if not exists. @@ -764,7 +797,8 @@ def _to_system_configuration(self, credentials=current_using_credentials ) - def _choice_current_using_quota_type(self, quota_configurations: list[QuotaConfiguration]) -> ProviderQuotaType: + @staticmethod + def _choice_current_using_quota_type(quota_configurations: list[QuotaConfiguration]) -> ProviderQuotaType: """ Choice current using quota type. paid quotas > provider free quotas > hosting trial quotas @@ -791,7 +825,8 @@ def _choice_current_using_quota_type(self, quota_configurations: list[QuotaConfi raise ValueError('No quota type available') - def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]: + @staticmethod + def _extract_secret_variables(credential_form_schemas: list[CredentialFormSchema]) -> list[str]: """ Extract secret input form variables. @@ -811,7 +846,7 @@ def _to_model_settings(self, provider_entity: ProviderEntity, -> list[ModelSettings]: """ Convert to model settings. - + :param provider_entity: provider entity :param provider_model_settings: provider model settings include enabled, load balancing enabled :param load_balancing_model_configs: load balancing model configs :return: diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index abbf4a35a4628a..3932e90042c59c 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -28,7 +28,7 @@ class RetrievalService: @classmethod def retrieve(cls, retrival_method: str, dataset_id: str, query: str, top_k: int, score_threshold: Optional[float] = .0, - reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = None, + reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = 'reranking_model', weights: Optional[dict] = None): dataset = db.session.query(Dataset).filter( Dataset.id == dataset_id @@ -36,10 +36,6 @@ def retrieve(cls, retrival_method: str, dataset_id: str, query: str, if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0: return [] all_documents = [] - keyword_search_documents = [] - embedding_search_documents = [] - full_text_search_documents = [] - hybrid_search_documents = [] threads = [] exceptions = [] # retrieval_model source with keyword diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index 442d71293f632d..b78e2a59b1eb6f 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -65,8 +65,15 @@ def __init__(self, collection_name: str, config: AnalyticdbConfig): AnalyticdbVector._init = True def _initialize(self) -> None: - self._initialize_vector_database() - self._create_namespace_if_not_exists() + cache_key = f"vector_indexing_{self.config.instance_id}" + lock_name = f"{cache_key}_lock" + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = f"vector_indexing_{self.config.instance_id}" + if redis_client.get(collection_exist_cache_key): + return + self._initialize_vector_database() + self._create_namespace_if_not_exists() + redis_client.set(collection_exist_cache_key, 1, ex=3600) def _initialize_vector_database(self) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models @@ -285,9 +292,11 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: documents = [] for match in response.body.matches.match: if match.score > score_threshold: + metadata = json.loads(match.metadata.get("metadata_")) doc = Document( page_content=match.metadata.get("page_content"), - metadata=json.loads(match.metadata.get("metadata_")), + vector=match.metadata.get("vector"), + metadata=metadata, ) documents.append(doc) return documents @@ -320,7 +329,23 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name) ) - # TODO handle optional params + # handle optional params + if dify_config.ANALYTICDB_KEY_ID is None: + raise ValueError("ANALYTICDB_KEY_ID should not be None") + if dify_config.ANALYTICDB_KEY_SECRET is None: + raise ValueError("ANALYTICDB_KEY_SECRET should not be None") + if dify_config.ANALYTICDB_REGION_ID is None: + raise ValueError("ANALYTICDB_REGION_ID should not be None") + if dify_config.ANALYTICDB_INSTANCE_ID is None: + raise ValueError("ANALYTICDB_INSTANCE_ID should not be None") + if dify_config.ANALYTICDB_ACCOUNT is None: + raise ValueError("ANALYTICDB_ACCOUNT should not be None") + if dify_config.ANALYTICDB_PASSWORD is None: + raise ValueError("ANALYTICDB_PASSWORD should not be None") + if dify_config.ANALYTICDB_NAMESPACE is None: + raise ValueError("ANALYTICDB_NAMESPACE should not be None") + if dify_config.ANALYTICDB_NAMESPACE_PASSWORD is None: + raise ValueError("ANALYTICDB_NAMESPACE_PASSWORD should not be None") return AnalyticdbVector( collection_name, AnalyticdbConfig( diff --git a/api/core/rag/datasource/vdb/elasticsearch/__init__.py b/api/core/rag/datasource/vdb/elasticsearch/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py new file mode 100644 index 00000000000000..233539756fa84f --- /dev/null +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -0,0 +1,218 @@ +import json +import logging +from typing import Any, Optional +from urllib.parse import urlparse + +import requests +from elasticsearch import Elasticsearch +from flask import current_app +from pydantic import BaseModel, model_validator + +from core.rag.datasource.entity.embedding import Embeddings +from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.models.document import Document +from extensions.ext_redis import redis_client +from models.dataset import Dataset + +logger = logging.getLogger(__name__) + + +class ElasticSearchConfig(BaseModel): + host: str + port: int + username: str + password: str + + @model_validator(mode='before') + def validate_config(cls, values: dict) -> dict: + if not values['host']: + raise ValueError("config HOST is required") + if not values['port']: + raise ValueError("config PORT is required") + if not values['username']: + raise ValueError("config USERNAME is required") + if not values['password']: + raise ValueError("config PASSWORD is required") + return values + + +class ElasticSearchVector(BaseVector): + def __init__(self, index_name: str, config: ElasticSearchConfig, attributes: list): + super().__init__(index_name.lower()) + self._client = self._init_client(config) + self._version = self._get_version() + self._check_version() + self._attributes = attributes + + def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: + try: + parsed_url = urlparse(config.host) + if parsed_url.scheme in ['http', 'https']: + hosts = f'{config.host}:{config.port}' + else: + hosts = f'http://{config.host}:{config.port}' + client = Elasticsearch( + hosts=hosts, + basic_auth=(config.username, config.password), + request_timeout=100000, + retry_on_timeout=True, + max_retries=10000, + ) + except requests.exceptions.ConnectionError: + raise ConnectionError("Vector database connection error") + + return client + + def _get_version(self) -> str: + info = self._client.info() + return info['version']['number'] + + def _check_version(self): + if self._version < '8.0.0': + raise ValueError("Elasticsearch vector database version must be greater than 8.0.0") + + def get_type(self) -> str: + return 'elasticsearch' + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + uuids = self._get_uuids(documents) + for i in range(len(documents)): + self._client.index(index=self._collection_name, + id=uuids[i], + document={ + Field.CONTENT_KEY.value: documents[i].page_content, + Field.VECTOR.value: embeddings[i] if embeddings[i] else None, + Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {} + }) + self._client.indices.refresh(index=self._collection_name) + return uuids + + def text_exists(self, id: str) -> bool: + return self._client.exists(index=self._collection_name, id=id).__bool__() + + def delete_by_ids(self, ids: list[str]) -> None: + for id in ids: + self._client.delete(index=self._collection_name, id=id) + + def delete_by_metadata_field(self, key: str, value: str) -> None: + query_str = { + 'query': { + 'match': { + f'metadata.{key}': f'{value}' + } + } + } + results = self._client.search(index=self._collection_name, body=query_str) + ids = [hit['_id'] for hit in results['hits']['hits']] + if ids: + self.delete_by_ids(ids) + + def delete(self) -> None: + self._client.indices.delete(index=self._collection_name) + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + top_k = kwargs.get("top_k", 10) + knn = { + "field": Field.VECTOR.value, + "query_vector": query_vector, + "k": top_k + } + + results = self._client.search(index=self._collection_name, knn=knn, size=top_k) + + docs_and_scores = [] + for hit in results['hits']['hits']: + docs_and_scores.append( + (Document(page_content=hit['_source'][Field.CONTENT_KEY.value], + vector=hit['_source'][Field.VECTOR.value], + metadata=hit['_source'][Field.METADATA_KEY.value]), hit['_score'])) + + docs = [] + for doc, score in docs_and_scores: + score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + if score > score_threshold: + doc.metadata['score'] = score + docs.append(doc) + + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + query_str = { + "match": { + Field.CONTENT_KEY.value: query + } + } + results = self._client.search(index=self._collection_name, query=query_str) + docs = [] + for hit in results['hits']['hits']: + docs.append(Document( + page_content=hit['_source'][Field.CONTENT_KEY.value], + vector=hit['_source'][Field.VECTOR.value], + metadata=hit['_source'][Field.METADATA_KEY.value], + )) + + return docs + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + metadatas = [d.metadata for d in texts] + self.create_collection(embeddings, metadatas) + self.add_texts(texts, embeddings, **kwargs) + + def create_collection( + self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + ): + lock_name = f'vector_indexing_lock_{self._collection_name}' + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = f'vector_indexing_{self._collection_name}' + if redis_client.get(collection_exist_cache_key): + logger.info(f"Collection {self._collection_name} already exists.") + return + + if not self._client.indices.exists(index=self._collection_name): + dim = len(embeddings[0]) + mappings = { + "properties": { + Field.CONTENT_KEY.value: {"type": "text"}, + Field.VECTOR.value: { # Make sure the dimension is correct here + "type": "dense_vector", + "dims": dim, + "similarity": "cosine" + }, + Field.METADATA_KEY.value: { + "type": "object", + "properties": { + "doc_id": {"type": "keyword"} # Map doc_id to keyword type + } + } + } + } + self._client.indices.create(index=self._collection_name, mappings=mappings) + + redis_client.set(collection_exist_cache_key, 1, ex=3600) + + +class ElasticSearchVectorFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + dataset.index_struct = json.dumps( + self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name)) + + config = current_app.config + return ElasticSearchVector( + index_name=collection_name, + config=ElasticSearchConfig( + host=config.get('ELASTICSEARCH_HOST'), + port=config.get('ELASTICSEARCH_PORT'), + username=config.get('ELASTICSEARCH_USERNAME'), + password=config.get('ELASTICSEARCH_PASSWORD'), + ), + attributes=[] + ) diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index 241b5a8414c476..05e75effefb23a 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -93,7 +93,7 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** @staticmethod def escape_str(value: Any) -> str: - return "".join(f"\\{c}" if c in ("\\", "'") else c for c in str(value)) + return "".join(" " if c in ("\\", "'") else c for c in str(value)) def text_exists(self, id: str) -> bool: results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'") @@ -118,21 +118,22 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc return self._search(f"distance(vector, {str(query_vector)})", self._vec_order, **kwargs) def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - return self._search(f"TextSearch(text, '{query}')", SortOrder.DESC, **kwargs) + return self._search(f"TextSearch('enable_nlq=false')(text, '{query}')", SortOrder.DESC, **kwargs) def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) - score_threshold = kwargs.get("score_threshold", 0.0) + score_threshold = kwargs.get('score_threshold') or 0.0 where_str = f"WHERE dist < {1 - score_threshold}" if \ self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else "" sql = f""" - SELECT text, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name} + SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name} {where_str} ORDER BY dist {order.value} LIMIT {top_k} """ try: return [ Document( page_content=r["text"], + vector=r['vector'], metadata=r["metadata"], ) for r in self._client.query(sql).named_results() diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index d834e8ce14c51d..c95d202173b84d 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -192,7 +192,9 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: docs = [] for hit in response['hits']['hits']: metadata = hit['_source'].get(Field.METADATA_KEY.value) - doc = Document(page_content=hit['_source'].get(Field.CONTENT_KEY.value), metadata=metadata) + vector = hit['_source'].get(Field.VECTOR.value) + page_content = hit['_source'].get(Field.CONTENT_KEY.value) + doc = Document(page_content=page_content, vector=vector, metadata=metadata) docs.append(doc) return docs diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index 4bd09b331d2839..aa2c6171c33367 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -234,16 +234,16 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: entities.append(token) with self._get_cursor() as cur: cur.execute( - f"select meta, text FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only", + f"select meta, text, embedding FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only", [" ACCUM ".join(entities)] ) docs = [] for record in cur: - metadata, text = record - docs.append(Document(page_content=text, metadata=metadata)) + metadata, text, embedding = record + docs.append(Document(page_content=text, vector=embedding, metadata=metadata)) return docs else: - return [Document(page_content="", metadata="")] + return [Document(page_content="", metadata={})] return [] def delete(self) -> None: diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 33ca5bc028b052..c9f2f35af03d75 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -152,8 +152,27 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - # do not support bm25 search - return [] + top_k = kwargs.get("top_k", 5) + + with self._get_cursor() as cur: + cur.execute( + f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), to_tsquery(%s)) AS score + FROM {self.table_name} + WHERE to_tsvector(text) @@ plainto_tsquery(%s) + ORDER BY score DESC + LIMIT {top_k}""", + # f"'{query}'" is required in order to account for whitespace in query + (f"'{query}'", f"'{query}'"), + ) + + docs = [] + + for record in cur: + metadata, text, score = record + metadata["score"] = score + docs.append(Document(page_content=text, metadata=metadata)) + + return docs def delete(self) -> None: with self._get_cursor() as cur: diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 77c3f6a27122f4..297bff928e8ae8 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -399,7 +399,6 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: document = self._document_from_scored_point( result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value ) - document.metadata['vector'] = result.vector documents.append(document) return documents @@ -418,6 +417,7 @@ def _document_from_scored_point( ) -> Document: return Document( page_content=scored_point.payload.get(content_payload_key), + vector=scored_point.vector, metadata=scored_point.payload.get(metadata_payload_key) or {}, ) diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index fad60ecf45c151..3e9ca8e1fe7f4a 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -71,6 +71,9 @@ def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]: case VectorType.RELYT: from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory return RelytVectorFactory + case VectorType.ELASTICSEARCH: + from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory + return ElasticSearchVectorFactory case VectorType.TIDB_VECTOR: from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory return TiDBVectorFactory diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index 77495044df562c..317ca6abc8c89d 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -15,3 +15,4 @@ class VectorType(str, Enum): OPENSEARCH = 'opensearch' TENCENT = 'tencent' ORACLE = 'oracle' + ELASTICSEARCH = 'elasticsearch' diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 87fc5ff158e2d3..205fe850c35838 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -239,8 +239,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: query_obj = self._client.query.get(collection_name, properties) if kwargs.get("where_filter"): query_obj = query_obj.with_where(kwargs.get("where_filter")) - if kwargs.get("additional"): - query_obj = query_obj.with_additional(kwargs.get("additional")) + query_obj = query_obj.with_additional(["vector"]) properties = ['text'] result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do() if "errors" in result: @@ -248,7 +247,8 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: docs = [] for res in result["data"]["Get"][collection_name]: text = res.pop(Field.TEXT_KEY.value) - docs.append(Document(page_content=text, metadata=res)) + additional = res.pop('_additional') + docs.append(Document(page_content=text, vector=additional['vector'], metadata=res)) return docs def _default_schema(self, index_name: str) -> dict: diff --git a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py index 34a4e85e9720ce..0323b14a4a34fd 100644 --- a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py @@ -25,7 +25,7 @@ def extract(self) -> list[Document]: from unstructured.file_utils.filetype import FileType, detect_filetype unstructured_version = tuple( - [int(x) for x in __unstructured_version__.split(".")] + int(x) for x in __unstructured_version__.split(".") ) # check the file extension try: diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index ac4a56319b5142..c3f0b75cfba5f1 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -1,9 +1,12 @@ """Abstract interface for document loader implementations.""" import datetime +import logging import mimetypes import os +import re import tempfile import uuid +import xml.etree.ElementTree as ET from urllib.parse import urlparse import requests @@ -16,6 +19,7 @@ from extensions.ext_storage import storage from models.model import UploadFile +logger = logging.getLogger(__name__) class WordExtractor(BaseExtractor): """Load docx files. @@ -117,19 +121,63 @@ def _extract_images_from_docx(self, doc, image_folder): return image_map - def _table_to_markdown(self, table): - markdown = "" - # deal with table headers + def _table_to_markdown(self, table, image_map): + markdown = [] + # calculate the total number of columns + total_cols = max(len(row.cells) for row in table.rows) + header_row = table.rows[0] - headers = [cell.text for cell in header_row.cells] - markdown += "| " + " | ".join(headers) + " |\n" - markdown += "| " + " | ".join(["---"] * len(headers)) + " |\n" - # deal with table rows + headers = self._parse_row(header_row, image_map, total_cols) + markdown.append("| " + " | ".join(headers) + " |") + markdown.append("| " + " | ".join(["---"] * total_cols) + " |") + for row in table.rows[1:]: - row_cells = [cell.text for cell in row.cells] - markdown += "| " + " | ".join(row_cells) + " |\n" + row_cells = self._parse_row(row, image_map, total_cols) + markdown.append("| " + " | ".join(row_cells) + " |") + return "\n".join(markdown) - return markdown + def _parse_row(self, row, image_map, total_cols): + # Initialize a row, all of which are empty by default + row_cells = [""] * total_cols + col_index = 0 + for cell in row.cells: + # make sure the col_index is not out of range + while col_index < total_cols and row_cells[col_index] != "": + col_index += 1 + # if col_index is out of range the loop is jumped + if col_index >= total_cols: + break + cell_content = self._parse_cell(cell, image_map).strip() + cell_colspan = cell.grid_span if cell.grid_span else 1 + for i in range(cell_colspan): + if col_index + i < total_cols: + row_cells[col_index + i] = cell_content if i == 0 else "" + col_index += cell_colspan + return row_cells + + def _parse_cell(self, cell, image_map): + cell_content = [] + for paragraph in cell.paragraphs: + parsed_paragraph = self._parse_cell_paragraph(paragraph, image_map) + if parsed_paragraph: + cell_content.append(parsed_paragraph) + unique_content = list(dict.fromkeys(cell_content)) + return " ".join(unique_content) + + def _parse_cell_paragraph(self, paragraph, image_map): + paragraph_content = [] + for run in paragraph.runs: + if run.element.xpath('.//a:blip'): + for blip in run.element.xpath('.//a:blip'): + image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") + image_part = paragraph.part.rels[image_id].target_part + + if image_part in image_map: + image_link = image_map[image_part] + paragraph_content.append(image_link) + else: + paragraph_content.append(run.text) + return "".join(paragraph_content).strip() def _parse_paragraph(self, paragraph, image_map): paragraph_content = [] @@ -153,10 +201,34 @@ def parse_docx(self, docx_path, image_folder): image_map = self._extract_images_from_docx(doc, image_folder) + hyperlinks_url = None + url_pattern = re.compile(r'http://[^\s+]+//|https://[^\s+]+') + for para in doc.paragraphs: + for run in para.runs: + if run.text and hyperlinks_url: + result = f' [{run.text}]({hyperlinks_url}) ' + run.text = result + hyperlinks_url = None + if 'HYPERLINK' in run.element.xml: + try: + xml = ET.XML(run.element.xml) + x_child = [c for c in xml.iter() if c is not None] + for x in x_child: + if x_child is None: + continue + if x.tag.endswith('instrText'): + for i in url_pattern.findall(x.text): + hyperlinks_url = str(i) + except Exception as e: + logger.error(e) + + + + def parse_paragraph(paragraph): paragraph_content = [] for run in paragraph.runs: - if run.element.tag.endswith('r'): + if hasattr(run.element, 'tag') and isinstance(element.tag, str) and run.element.tag.endswith('r'): drawing_elements = run.element.findall( './/{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing') for drawing in drawing_elements: @@ -176,13 +248,14 @@ def parse_paragraph(paragraph): paragraphs = doc.paragraphs.copy() tables = doc.tables.copy() for element in doc.element.body: - if element.tag.endswith('p'): # paragraph - para = paragraphs.pop(0) - parsed_paragraph = parse_paragraph(para) - if parsed_paragraph: - content.append(parsed_paragraph) - elif element.tag.endswith('tbl'): # table - table = tables.pop(0) - content.append(self._table_to_markdown(table)) + if hasattr(element, 'tag'): + if isinstance(element.tag, str) and element.tag.endswith('p'): # paragraph + para = paragraphs.pop(0) + parsed_paragraph = parse_paragraph(para) + if parsed_paragraph: + content.append(parsed_paragraph) + elif isinstance(element.tag, str) and element.tag.endswith('tbl'): # table + table = tables.pop(0) + content.append(self._table_to_markdown(table,image_map)) return '\n'.join(content) diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 33e78ce8c5ccb0..176d0c1ed65cba 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -57,7 +57,7 @@ def _get_splitter(self, processing_rule: dict, character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( chunk_size=segmentation["max_tokens"], - chunk_overlap=segmentation.get('chunk_overlap', 0), + chunk_overlap=segmentation.get('chunk_overlap', 0) or 0, fixed_separator=separator, separators=["\n\n", "。", ". ", " ", ""], embedding_model_instance=embedding_model_instance diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 7bb675b149c0e6..6f3c1c5d343977 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -10,6 +10,8 @@ class Document(BaseModel): page_content: str + vector: Optional[list[float]] = None + """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). """ diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index d07f94adb77250..d8a78739826a31 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -159,10 +159,9 @@ def _calculate_cosine(self, tenant_id: str, query: str, documents: list[Document if 'score' in document.metadata: query_vector_scores.append(document.metadata['score']) else: - content_vector = document.metadata['vector'] # transform to NumPy vec1 = np.array(query_vector) - vec2 = np.array(document.metadata['vector']) + vec2 = np.array(document.vector) # calculate dot product dot_product = np.dot(vec1, vec2) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 2c5d920a9ad9d2..c970e3dafaa856 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -14,7 +14,8 @@ from core.model_runtime.entities.message_entities import PromptMessageTool from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName +from core.ops.entities.trace_entity import TraceTaskName +from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler @@ -278,6 +279,7 @@ def single_retrieve( query=query, top_k=top_k, score_threshold=score_threshold, reranking_model=reranking_model, + reranking_mode=retrieval_model_config.get('reranking_mode', 'reranking_model'), weights=retrieval_model_config.get('weights', None), ) self._on_query(query, [dataset_id], app_id, user_from, user_id) @@ -431,10 +433,12 @@ def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, dataset_id=dataset.id, query=query, top_k=top_k, - score_threshold=retrieval_model['score_threshold'] + score_threshold=retrieval_model.get('score_threshold', .0) if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model['reranking_model'] + reranking_model=retrieval_model.get('reranking_model', None) if retrieval_model['reranking_enable'] else None, + reranking_mode=retrieval_model.get('reranking_mode') + if retrieval_model.get('reranking_mode') else 'reranking_model', weights=retrieval_model.get('weights', None), ) @@ -610,8 +614,9 @@ def calculate_vector_score(self, all_documents: list[Document], top_k: int, score_threshold: float) -> list[Document]: filter_documents = [] for document in all_documents: - if score_threshold and document.metadata['score'] >= score_threshold: + if score_threshold is None or document.metadata['score'] >= score_threshold: filter_documents.append(document) + if not filter_documents: return [] filter_documents = sorted(filter_documents, key=lambda x: x.metadata['score'], reverse=True) diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index fd714edf5e435b..6a0804f890db39 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -63,7 +63,7 @@ def split_text(self, text: str) -> list[str]: if self._fixed_separator: chunks = text.split(self._fixed_separator) else: - chunks = list(text) + chunks = [text] final_chunks = [] for chunk in chunks: diff --git a/api/core/tools/docs/zh_Hans/advanced_scale_out.md b/api/core/tools/docs/zh_Hans/advanced_scale_out.md index 93f81b033d60f2..0385dfe4e7bce0 100644 --- a/api/core/tools/docs/zh_Hans/advanced_scale_out.md +++ b/api/core/tools/docs/zh_Hans/advanced_scale_out.md @@ -21,6 +21,7 @@ Dify支持`文本` `链接` `图片` `文件BLOB` `JSON` 等多种消息类型 create an image message :param image: the url of the image + :param save_as: save as :return: the image message """ ``` @@ -34,6 +35,7 @@ Dify支持`文本` `链接` `图片` `文件BLOB` `JSON` 等多种消息类型 create a link message :param link: the url of the link + :param save_as: save as :return: the link message """ ``` @@ -47,6 +49,7 @@ Dify支持`文本` `链接` `图片` `文件BLOB` `JSON` 等多种消息类型 create a text message :param text: the text of the message + :param save_as: save as :return: the text message """ ``` @@ -63,6 +66,8 @@ Dify支持`文本` `链接` `图片` `文件BLOB` `JSON` 等多种消息类型 create a blob message :param blob: the blob + :param meta: meta + :param save_as: save as :return: the blob message """ ``` diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 569a1d3238f87f..e31dec55d21898 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -46,7 +46,7 @@ def value_of(cls, value: str) -> 'ToolProviderType': if mode.value == value: return mode raise ValueError(f'invalid mode value {value}') - + class ApiProviderSchemaType(Enum): """ Enum class for api provider schema type. @@ -68,7 +68,7 @@ def value_of(cls, value: str) -> 'ApiProviderSchemaType': if mode.value == value: return mode raise ValueError(f'invalid mode value {value}') - + class ApiProviderAuthType(Enum): """ Enum class for api provider auth type. @@ -103,8 +103,8 @@ class MessageType(Enum): """ plain text, image url or link url """ - message: Union[str, bytes, dict] = None - meta: dict[str, Any] = None + message: str | bytes | dict | None = None + meta: dict[str, Any] | None = None save_as: str = '' class ToolInvokeMessageBinary(BaseModel): @@ -148,14 +148,14 @@ class ToolParameterForm(Enum): form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm") llm_description: Optional[str] = None required: Optional[bool] = False - default: Optional[Union[int, str]] = None + default: Optional[Union[float, int, str]] = None min: Optional[Union[float, int]] = None max: Optional[Union[float, int]] = None options: Optional[list[ToolParameterOption]] = None @classmethod - def get_simple_instance(cls, - name: str, llm_description: str, type: ToolParameterType, + def get_simple_instance(cls, + name: str, llm_description: str, type: ToolParameterType, required: bool, options: Optional[list[str]] = None) -> 'ToolParameter': """ get a simple tool parameter @@ -222,7 +222,7 @@ def value_of(cls, value: str) -> "ToolProviderCredentials.CredentialsType": if mode.value == value: return mode raise ValueError(f'invalid mode value {value}') - + @staticmethod def default(value: str) -> str: return "" @@ -290,7 +290,7 @@ def dict(self) -> dict: 'tenant_id': self.tenant_id, 'pool': [variable.model_dump() for variable in self.pool], } - + def set_text(self, tool_name: str, name: str, value: str) -> None: """ set a text variable @@ -301,7 +301,7 @@ def set_text(self, tool_name: str, name: str, value: str) -> None: variable = cast(ToolRuntimeTextVariable, variable) variable.value = value return - + variable = ToolRuntimeTextVariable( type=ToolRuntimeVariableType.TEXT, name=name, @@ -334,7 +334,7 @@ def set_file(self, tool_name: str, value: str, name: str = None) -> None: variable = cast(ToolRuntimeImageVariable, variable) variable.value = value return - + variable = ToolRuntimeImageVariable( type=ToolRuntimeVariableType.IMAGE, name=name, @@ -388,21 +388,21 @@ def empty(cls) -> 'ToolInvokeMeta': Get an empty instance of ToolInvokeMeta """ return cls(time_cost=0.0, error=None, tool_config={}) - + @classmethod def error_instance(cls, error: str) -> 'ToolInvokeMeta': """ Get an instance of ToolInvokeMeta with error """ return cls(time_cost=0.0, error=error, tool_config={}) - + def to_dict(self) -> dict: return { 'time_cost': self.time_cost, 'error': self.error, 'tool_config': self.tool_config, } - + class ToolLabel(BaseModel): """ Tool label @@ -416,4 +416,4 @@ class ToolInvokeFrom(Enum): Enum class for tool invoke """ WORKFLOW = "workflow" - AGENT = "agent" \ No newline at end of file + AGENT = "agent" diff --git a/api/core/tools/provider/_position.yaml b/api/core/tools/provider/_position.yaml index 25d9f403a0fbbe..40c3356116770b 100644 --- a/api/core/tools/provider/_position.yaml +++ b/api/core/tools/provider/_position.yaml @@ -1,5 +1,6 @@ - google - bing +- perplexity - duckduckgo - searchapi - serper @@ -10,6 +11,7 @@ - wikipedia - nominatim - yahoo +- alphavantage - arxiv - pubmed - stablediffusion @@ -30,5 +32,7 @@ - dingtalk - feishu - feishu_base +- feishu_document +- feishu_message - slack - tianditu diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py index ae806eaff4a032..062668fc5bf8bf 100644 --- a/api/core/tools/provider/builtin/_positions.py +++ b/api/core/tools/provider/builtin/_positions.py @@ -1,6 +1,6 @@ import os.path -from core.helper.position_helper import get_position_map, sort_by_position_map +from core.helper.position_helper import get_tool_position_map, sort_by_position_map from core.tools.entities.api_entities import UserToolProvider @@ -10,11 +10,11 @@ class BuiltinToolProviderSort: @classmethod def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: if not cls._position: - cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..')) + cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), '..')) def name_func(provider: UserToolProvider) -> str: return provider.name sorted_providers = sort_by_position_map(cls._position, providers, name_func) - return sorted_providers \ No newline at end of file + return sorted_providers diff --git a/api/core/tools/provider/builtin/alphavantage/_assets/icon.svg b/api/core/tools/provider/builtin/alphavantage/_assets/icon.svg new file mode 100644 index 00000000000000..785432943bc148 --- /dev/null +++ b/api/core/tools/provider/builtin/alphavantage/_assets/icon.svg @@ -0,0 +1,7 @@ + + + 形状结合 + + + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/alphavantage/alphavantage.py b/api/core/tools/provider/builtin/alphavantage/alphavantage.py new file mode 100644 index 00000000000000..01f2acfb5b6310 --- /dev/null +++ b/api/core/tools/provider/builtin/alphavantage/alphavantage.py @@ -0,0 +1,22 @@ +from typing import Any + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.alphavantage.tools.query_stock import QueryStockTool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class AlphaVantageProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + try: + QueryStockTool().fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ).invoke( + user_id='', + tool_parameters={ + "code": "AAPL", # Apple Inc. + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/alphavantage/alphavantage.yaml b/api/core/tools/provider/builtin/alphavantage/alphavantage.yaml new file mode 100644 index 00000000000000..710510cfd8ed4a --- /dev/null +++ b/api/core/tools/provider/builtin/alphavantage/alphavantage.yaml @@ -0,0 +1,31 @@ +identity: + author: zhuhao + name: alphavantage + label: + en_US: AlphaVantage + zh_Hans: AlphaVantage + pt_BR: AlphaVantage + description: + en_US: AlphaVantage is an online platform that provides financial market data and APIs, making it convenient for individual investors and developers to access stock quotes, technical indicators, and stock analysis. + zh_Hans: AlphaVantage是一个在线平台,它提供金融市场数据和API,便于个人投资者和开发者获取股票报价、技术指标和股票分析。 + pt_BR: AlphaVantage is an online platform that provides financial market data and APIs, making it convenient for individual investors and developers to access stock quotes, technical indicators, and stock analysis. + icon: icon.svg + tags: + - finance +credentials_for_provider: + api_key: + type: secret-input + required: true + label: + en_US: AlphaVantage API key + zh_Hans: AlphaVantage API key + pt_BR: AlphaVantage API key + placeholder: + en_US: Please input your AlphaVantage API key + zh_Hans: 请输入你的 AlphaVantage API key + pt_BR: Please input your AlphaVantage API key + help: + en_US: Get your AlphaVantage API key from AlphaVantage + zh_Hans: 从 AlphaVantage 获取您的 AlphaVantage API key + pt_BR: Get your AlphaVantage API key from AlphaVantage + url: https://www.alphavantage.co/support/#api-key diff --git a/api/core/tools/provider/builtin/alphavantage/tools/query_stock.py b/api/core/tools/provider/builtin/alphavantage/tools/query_stock.py new file mode 100644 index 00000000000000..5c379b746d5db9 --- /dev/null +++ b/api/core/tools/provider/builtin/alphavantage/tools/query_stock.py @@ -0,0 +1,49 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +ALPHAVANTAGE_API_URL = "https://www.alphavantage.co/query" + + +class QueryStockTool(BuiltinTool): + + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + stock_code = tool_parameters.get('code', '') + if not stock_code: + return self.create_text_message('Please tell me your stock code') + + if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): + return self.create_text_message("Alpha Vantage API key is required.") + + params = { + "function": "TIME_SERIES_DAILY", + "symbol": stock_code, + "outputsize": "compact", + "datatype": "json", + "apikey": self.runtime.credentials['api_key'] + } + response = requests.get(url=ALPHAVANTAGE_API_URL, params=params) + response.raise_for_status() + result = self._handle_response(response.json()) + return self.create_json_message(result) + + def _handle_response(self, response: dict[str, Any]) -> dict[str, Any]: + result = response.get('Time Series (Daily)', {}) + if not result: + return {} + stock_result = {} + for k, v in result.items(): + stock_result[k] = {} + stock_result[k]['open'] = v.get('1. open') + stock_result[k]['high'] = v.get('2. high') + stock_result[k]['low'] = v.get('3. low') + stock_result[k]['close'] = v.get('4. close') + stock_result[k]['volume'] = v.get('5. volume') + return stock_result diff --git a/api/core/tools/provider/builtin/alphavantage/tools/query_stock.yaml b/api/core/tools/provider/builtin/alphavantage/tools/query_stock.yaml new file mode 100644 index 00000000000000..d89f34e373f9fa --- /dev/null +++ b/api/core/tools/provider/builtin/alphavantage/tools/query_stock.yaml @@ -0,0 +1,27 @@ +identity: + name: query_stock + author: zhuhao + label: + en_US: query_stock + zh_Hans: query_stock + pt_BR: query_stock +description: + human: + en_US: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol. + zh_Hans: 获取指定股票代码的每日开盘价、每日最高价、每日最低价、每日收盘价和每日交易量等信息。 + pt_BR: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol + llm: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol +parameters: + - name: code + type: string + required: true + label: + en_US: stock code + zh_Hans: 股票代码 + pt_BR: stock code + human_description: + en_US: stock code + zh_Hans: 股票代码 + pt_BR: stock code + llm_description: stock code for query from alphavantage + form: llm diff --git a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.yaml b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.yaml index 63a8c99d97f90c..e256748e8f7188 100644 --- a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.yaml +++ b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.yaml @@ -25,7 +25,7 @@ parameters: pt_BR: Prompt human_description: en_US: Image prompt, you can check the official documentation of DallE 3 - zh_Hans: 图像提示词,您可以查看DallE 3 的官方文档 + zh_Hans: 图像提示词,您可以查看 DallE 3 的官方文档 pt_BR: Imagem prompt, você pode verificar a documentação oficial do DallE 3 llm_description: Image prompt of DallE 3, you should describe the image you want to generate as a list of words as possible as detailed form: llm diff --git a/api/core/tools/provider/builtin/cogview/tools/cogview3.yaml b/api/core/tools/provider/builtin/cogview/tools/cogview3.yaml index ba0b271a1c716c..1de3f599b6ac02 100644 --- a/api/core/tools/provider/builtin/cogview/tools/cogview3.yaml +++ b/api/core/tools/provider/builtin/cogview/tools/cogview3.yaml @@ -25,7 +25,7 @@ parameters: pt_BR: Prompt human_description: en_US: Image prompt, you can check the official documentation of CogView 3 - zh_Hans: 图像提示词,您可以查看CogView 3 的官方文档 + zh_Hans: 图像提示词,您可以查看 CogView 3 的官方文档 pt_BR: Image prompt, you can check the official documentation of CogView 3 llm_description: Image prompt of CogView 3, you should describe the image you want to generate as a list of words as possible as detailed form: llm diff --git a/api/core/tools/provider/builtin/crossref/_assets/icon.svg b/api/core/tools/provider/builtin/crossref/_assets/icon.svg new file mode 100644 index 00000000000000..aa629de7cb1660 --- /dev/null +++ b/api/core/tools/provider/builtin/crossref/_assets/icon.svg @@ -0,0 +1,49 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/api/core/tools/provider/builtin/crossref/crossref.py b/api/core/tools/provider/builtin/crossref/crossref.py new file mode 100644 index 00000000000000..404e483e0d23ab --- /dev/null +++ b/api/core/tools/provider/builtin/crossref/crossref.py @@ -0,0 +1,20 @@ +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.crossref.tools.query_doi import CrossRefQueryDOITool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class CrossRefProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + try: + CrossRefQueryDOITool().fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ).invoke( + user_id='', + tool_parameters={ + "doi": '10.1007/s00894-022-05373-8', + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/crossref/crossref.yaml b/api/core/tools/provider/builtin/crossref/crossref.yaml new file mode 100644 index 00000000000000..da67fbec3a480b --- /dev/null +++ b/api/core/tools/provider/builtin/crossref/crossref.yaml @@ -0,0 +1,29 @@ +identity: + author: Sakura4036 + name: crossref + label: + en_US: CrossRef + zh_Hans: CrossRef + description: + en_US: Crossref is a cross-publisher reference linking registration query system using DOI technology created in 2000. Crossref establishes cross-database links between the reference list and citation full text of papers, making it very convenient for readers to access the full text of papers. + zh_Hans: Crossref是于2000年创建的使用DOI技术的跨出版商参考文献链接注册查询系统。Crossref建立了在论文的参考文献列表和引文全文之间的跨数据库链接,使得读者能够非常便捷地获取文献全文。 + icon: icon.svg + tags: + - search +credentials_for_provider: + mailto: + type: text-input + required: true + label: + en_US: email address + zh_Hans: email地址 + pt_BR: email address + placeholder: + en_US: Please input your email address + zh_Hans: 请输入你的email地址 + pt_BR: Please input your email address + help: + en_US: According to the requirements of Crossref, an email address is required + zh_Hans: 根据Crossref的要求,需要提供一个邮箱地址 + pt_BR: According to the requirements of Crossref, an email address is required + url: https://api.crossref.org/swagger-ui/index.html diff --git a/api/core/tools/provider/builtin/crossref/tools/query_doi.py b/api/core/tools/provider/builtin/crossref/tools/query_doi.py new file mode 100644 index 00000000000000..a43c0989e40613 --- /dev/null +++ b/api/core/tools/provider/builtin/crossref/tools/query_doi.py @@ -0,0 +1,25 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.errors import ToolParameterValidationError +from core.tools.tool.builtin_tool import BuiltinTool + + +class CrossRefQueryDOITool(BuiltinTool): + """ + Tool for querying the metadata of a publication using its DOI. + """ + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + doi = tool_parameters.get('doi') + if not doi: + raise ToolParameterValidationError('doi is required.') + # doc: https://github.com/CrossRef/rest-api-doc + url = f"https://api.crossref.org/works/{doi}" + response = requests.get(url) + response.raise_for_status() + response = response.json() + message = response.get('message', {}) + + return self.create_json_message(message) diff --git a/api/core/tools/provider/builtin/crossref/tools/query_doi.yaml b/api/core/tools/provider/builtin/crossref/tools/query_doi.yaml new file mode 100644 index 00000000000000..9c16da25edf2b3 --- /dev/null +++ b/api/core/tools/provider/builtin/crossref/tools/query_doi.yaml @@ -0,0 +1,23 @@ +identity: + name: crossref_query_doi + author: Sakura4036 + label: + en_US: CrossRef Query DOI + zh_Hans: CrossRef DOI 查询 + pt_BR: CrossRef Query DOI +description: + human: + en_US: A tool for searching literature information using CrossRef by DOI. + zh_Hans: 一个使用CrossRef通过DOI获取文献信息的工具。 + pt_BR: A tool for searching literature information using CrossRef by DOI. + llm: A tool for searching literature information using CrossRef by DOI. +parameters: + - name: doi + type: string + required: true + label: + en_US: DOI + zh_Hans: DOI + pt_BR: DOI + llm_description: DOI for searching in CrossRef + form: llm diff --git a/api/core/tools/provider/builtin/crossref/tools/query_title.py b/api/core/tools/provider/builtin/crossref/tools/query_title.py new file mode 100644 index 00000000000000..946aa6dc947a20 --- /dev/null +++ b/api/core/tools/provider/builtin/crossref/tools/query_title.py @@ -0,0 +1,120 @@ +import time +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +def convert_time_str_to_seconds(time_str: str) -> int: + """ + Convert a time string to seconds. + example: 1s -> 1, 1m30s -> 90, 1h30m -> 5400, 1h30m30s -> 5430 + """ + time_str = time_str.lower().strip().replace(' ', '') + seconds = 0 + if 'h' in time_str: + hours, time_str = time_str.split('h') + seconds += int(hours) * 3600 + if 'm' in time_str: + minutes, time_str = time_str.split('m') + seconds += int(minutes) * 60 + if 's' in time_str: + seconds += int(time_str.replace('s', '')) + return seconds + + +class CrossRefQueryTitleAPI: + """ + Tool for querying the metadata of a publication using its title. + Crossref API doc: https://github.com/CrossRef/rest-api-doc + """ + query_url_template: str = "https://api.crossref.org/works?query.bibliographic={query}&rows={rows}&offset={offset}&sort={sort}&order={order}&mailto={mailto}" + rate_limit: int = 50 + rate_interval: float = 1 + max_limit: int = 1000 + + def __init__(self, mailto: str): + self.mailto = mailto + + def _query(self, query: str, rows: int = 5, offset: int = 0, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]: + """ + Query the metadata of a publication using its title. + :param query: the title of the publication + :param rows: the number of results to return + :param sort: the sort field + :param order: the sort order + :param fuzzy_query: whether to return all items that match the query + """ + url = self.query_url_template.format(query=query, rows=rows, offset=offset, sort=sort, order=order, mailto=self.mailto) + response = requests.get(url) + response.raise_for_status() + rate_limit = int(response.headers['x-ratelimit-limit']) + # convert time string to seconds + rate_interval = convert_time_str_to_seconds(response.headers['x-ratelimit-interval']) + + self.rate_limit = rate_limit + self.rate_interval = rate_interval + + response = response.json() + if response['status'] != 'ok': + return [] + + message = response['message'] + if fuzzy_query: + # fuzzy query return all items + return message['items'] + else: + for paper in message['items']: + title = paper['title'][0] + if title.lower() != query.lower(): + continue + return [paper] + return [] + + def query(self, query: str, rows: int = 5, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]: + """ + Query the metadata of a publication using its title. + :param query: the title of the publication + :param rows: the number of results to return + :param sort: the sort field + :param order: the sort order + :param fuzzy_query: whether to return all items that match the query + """ + rows = min(rows, self.max_limit) + if rows > self.rate_limit: + # query multiple times + query_times = rows // self.rate_limit + 1 + results = [] + + for i in range(query_times): + result = self._query(query, rows=self.rate_limit, offset=i * self.rate_limit, sort=sort, order=order, fuzzy_query=fuzzy_query) + if fuzzy_query: + results.extend(result) + else: + # fuzzy_query=False, only one result + if result: + return result + time.sleep(self.rate_interval) + return results + else: + # query once + return self._query(query, rows, sort=sort, order=order, fuzzy_query=fuzzy_query) + + +class CrossRefQueryTitleTool(BuiltinTool): + """ + Tool for querying the metadata of a publication using its title. + """ + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + query = tool_parameters.get('query') + fuzzy_query = tool_parameters.get('fuzzy_query', False) + rows = tool_parameters.get('rows', 3) + sort = tool_parameters.get('sort', 'relevance') + order = tool_parameters.get('order', 'desc') + mailto = self.runtime.credentials['mailto'] + + result = CrossRefQueryTitleAPI(mailto).query(query, rows, sort, order, fuzzy_query) + + return [self.create_json_message(r) for r in result] diff --git a/api/core/tools/provider/builtin/crossref/tools/query_title.yaml b/api/core/tools/provider/builtin/crossref/tools/query_title.yaml new file mode 100644 index 00000000000000..5579c77f5293d3 --- /dev/null +++ b/api/core/tools/provider/builtin/crossref/tools/query_title.yaml @@ -0,0 +1,105 @@ +identity: + name: crossref_query_title + author: Sakura4036 + label: + en_US: CrossRef Title Query + zh_Hans: CrossRef 标题查询 + pt_BR: CrossRef Title Query +description: + human: + en_US: A tool for querying literature information using CrossRef by title. + zh_Hans: 一个使用CrossRef通过标题搜索文献信息的工具。 + pt_BR: A tool for querying literature information using CrossRef by title. + llm: A tool for querying literature information using CrossRef by title. +parameters: + - name: query + type: string + required: true + label: + en_US: 标题 + zh_Hans: 查询语句 + pt_BR: 标题 + human_description: + en_US: Query bibliographic information, useful for citation look up. Includes titles, authors, ISSNs and publication years + zh_Hans: 用于搜索文献信息,有助于查找引用。包括标题,作者,ISSN和出版年份 + pt_BR: Query bibliographic information, useful for citation look up. Includes titles, authors, ISSNs and publication years + llm_description: key words for querying in Web of Science + form: llm + - name: fuzzy_query + type: boolean + default: false + label: + en_US: Whether to fuzzy search + zh_Hans: 是否模糊搜索 + pt_BR: Whether to fuzzy search + human_description: + en_US: used for selecting the query type, fuzzy query returns more results, precise query returns 1 or none + zh_Hans: 用于选择搜索类型,模糊搜索返回更多结果,精确搜索返回1条结果或无 + pt_BR: used for selecting the query type, fuzzy query returns more results, precise query returns 1 or none + form: form + - name: limit + type: number + required: false + label: + en_US: max query number + zh_Hans: 最大搜索数 + pt_BR: max query number + human_description: + en_US: max query number(fuzzy search returns the maximum number of results or precise search the maximum number of matches) + zh_Hans: 最大搜索数(模糊搜索返回的最大结果数或精确搜索最大匹配数) + pt_BR: max query number(fuzzy search returns the maximum number of results or precise search the maximum number of matches) + form: llm + default: 50 + - name: sort + type: select + required: true + options: + - value: relevance + label: + en_US: relevance + zh_Hans: 相关性 + pt_BR: relevance + - value: published + label: + en_US: publication date + zh_Hans: 出版日期 + pt_BR: publication date + - value: references-count + label: + en_US: references-count + zh_Hans: 引用次数 + pt_BR: references-count + default: relevance + label: + en_US: sorting field + zh_Hans: 排序字段 + pt_BR: sorting field + human_description: + en_US: Sorting of query results + zh_Hans: 检索结果的排序字段 + pt_BR: Sorting of query results + form: form + - name: order + type: select + required: true + options: + - value: desc + label: + en_US: descending + zh_Hans: 降序 + pt_BR: descending + - value: asc + label: + en_US: ascending + zh_Hans: 升序 + pt_BR: ascending + default: desc + label: + en_US: Order + zh_Hans: 排序 + pt_BR: Order + human_description: + en_US: Order of query results + zh_Hans: 检索结果的排序方式 + pt_BR: Order of query results + form: form diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle2.yaml b/api/core/tools/provider/builtin/dalle/tools/dalle2.yaml index 90c73ecc57bb7b..e43e5df8cddd9b 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle2.yaml +++ b/api/core/tools/provider/builtin/dalle/tools/dalle2.yaml @@ -24,7 +24,7 @@ parameters: pt_BR: Prompt human_description: en_US: Image prompt, you can check the official documentation of DallE 2 - zh_Hans: 图像提示词,您可以查看DallE 2 的官方文档 + zh_Hans: 图像提示词,您可以查看 DallE 2 的官方文档 pt_BR: Image prompt, you can check the official documentation of DallE 2 llm_description: Image prompt of DallE 2, you should describe the image you want to generate as a list of words as possible as detailed form: llm diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.yaml b/api/core/tools/provider/builtin/dalle/tools/dalle3.yaml index 7ba5c56889c7e4..0cea8af761e1e5 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle3.yaml +++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.yaml @@ -25,7 +25,7 @@ parameters: pt_BR: Prompt human_description: en_US: Image prompt, you can check the official documentation of DallE 3 - zh_Hans: 图像提示词,您可以查看DallE 3 的官方文档 + zh_Hans: 图像提示词,您可以查看 DallE 3 的官方文档 pt_BR: Image prompt, you can check the official documentation of DallE 3 llm_description: Image prompt of DallE 3, you should describe the image you want to generate as a list of words as possible as detailed form: llm diff --git a/api/core/tools/provider/builtin/did/_assets/icon.svg b/api/core/tools/provider/builtin/did/_assets/icon.svg new file mode 100644 index 00000000000000..c477d7cb71dea2 --- /dev/null +++ b/api/core/tools/provider/builtin/did/_assets/icon.svg @@ -0,0 +1,14 @@ + + + + + + + + + + + + + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/did/did.py b/api/core/tools/provider/builtin/did/did.py new file mode 100644 index 00000000000000..b4bf172131448d --- /dev/null +++ b/api/core/tools/provider/builtin/did/did.py @@ -0,0 +1,21 @@ +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.did.tools.talks import TalksTool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class DIDProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + try: + # Example validation using the D-ID talks tool + TalksTool().fork_tool_runtime( + runtime={"credentials": credentials} + ).invoke( + user_id='', + tool_parameters={ + "source_url": "https://www.d-id.com/wp-content/uploads/2023/11/Hero-image-1.png", + "text_input": "Hello, welcome to use D-ID tool in Dify", + } + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/did/did.yaml b/api/core/tools/provider/builtin/did/did.yaml new file mode 100644 index 00000000000000..a70b71812e4648 --- /dev/null +++ b/api/core/tools/provider/builtin/did/did.yaml @@ -0,0 +1,28 @@ +identity: + author: Matri Qi + name: did + label: + en_US: D-ID + description: + en_US: D-ID is a tool enabling the creation of high-quality, custom videos of Digital Humans from a single image. + icon: icon.svg + tags: + - videos +credentials_for_provider: + did_api_key: + type: secret-input + required: true + label: + en_US: D-ID API Key + placeholder: + en_US: Please input your D-ID API key + help: + en_US: Get your D-ID API key from your D-ID account settings. + url: https://studio.d-id.com/account-settings + base_url: + type: text-input + required: false + label: + en_US: D-ID server's Base URL + placeholder: + en_US: https://api.d-id.com diff --git a/api/core/tools/provider/builtin/did/did_appx.py b/api/core/tools/provider/builtin/did/did_appx.py new file mode 100644 index 00000000000000..964e82b729319e --- /dev/null +++ b/api/core/tools/provider/builtin/did/did_appx.py @@ -0,0 +1,87 @@ +import logging +import time +from collections.abc import Mapping +from typing import Any + +import requests +from requests.exceptions import HTTPError + +logger = logging.getLogger(__name__) + + +class DIDApp: + def __init__(self, api_key: str | None = None, base_url: str | None = None): + self.api_key = api_key + self.base_url = base_url or 'https://api.d-id.com' + if not self.api_key: + raise ValueError('API key is required') + + def _prepare_headers(self, idempotency_key: str | None = None): + headers = {'Content-Type': 'application/json', 'Authorization': f'Basic {self.api_key}'} + if idempotency_key: + headers['Idempotency-Key'] = idempotency_key + return headers + + def _request( + self, + method: str, + url: str, + data: Mapping[str, Any] | None = None, + headers: Mapping[str, str] | None = None, + retries: int = 3, + backoff_factor: float = 0.3, + ) -> Mapping[str, Any] | None: + for i in range(retries): + try: + response = requests.request(method, url, json=data, headers=headers) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + if i < retries - 1 and isinstance(e, HTTPError) and e.response.status_code >= 500: + time.sleep(backoff_factor * (2**i)) + else: + raise + return None + + def talks(self, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs): + endpoint = f'{self.base_url}/talks' + headers = self._prepare_headers(idempotency_key) + data = kwargs['params'] + logger.debug(f'Send request to {endpoint=} body={data}') + response = self._request('POST', endpoint, data, headers) + if response is None: + raise HTTPError('Failed to initiate D-ID talks after multiple retries') + id: str = response['id'] + if wait: + return self._monitor_job_status(id=id, target='talks', poll_interval=poll_interval) + return id + + def animations(self, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs): + endpoint = f'{self.base_url}/animations' + headers = self._prepare_headers(idempotency_key) + data = kwargs['params'] + logger.debug(f'Send request to {endpoint=} body={data}') + response = self._request('POST', endpoint, data, headers) + if response is None: + raise HTTPError('Failed to initiate D-ID talks after multiple retries') + id: str = response['id'] + if wait: + return self._monitor_job_status(target='animations', id=id, poll_interval=poll_interval) + return id + + def check_did_status(self, target: str, id: str): + endpoint = f'{self.base_url}/{target}/{id}' + headers = self._prepare_headers() + response = self._request('GET', endpoint, headers=headers) + if response is None: + raise HTTPError(f'Failed to check status for talks {id} after multiple retries') + return response + + def _monitor_job_status(self, target: str, id: str, poll_interval: int): + while True: + status = self.check_did_status(target=target, id=id) + if status['status'] == 'done': + return status + elif status['status'] == 'error' or status['status'] == 'rejected': + raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error",{}).get("description")}') + time.sleep(poll_interval) diff --git a/api/core/tools/provider/builtin/did/tools/animations.py b/api/core/tools/provider/builtin/did/tools/animations.py new file mode 100644 index 00000000000000..e1d9de603fbb7a --- /dev/null +++ b/api/core/tools/provider/builtin/did/tools/animations.py @@ -0,0 +1,49 @@ +import json +from typing import Any, Union + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.did.did_appx import DIDApp +from core.tools.tool.builtin_tool import BuiltinTool + + +class AnimationsTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + app = DIDApp(api_key=self.runtime.credentials['did_api_key'], base_url=self.runtime.credentials['base_url']) + + driver_expressions_str = tool_parameters.get('driver_expressions') + driver_expressions = json.loads(driver_expressions_str) if driver_expressions_str else None + + config = { + 'stitch': tool_parameters.get('stitch', True), + 'mute': tool_parameters.get('mute'), + 'result_format': tool_parameters.get('result_format') or 'mp4', + } + config = {k: v for k, v in config.items() if v is not None and v != ''} + + options = { + 'source_url': tool_parameters['source_url'], + 'driver_url': tool_parameters.get('driver_url'), + 'config': config, + } + options = {k: v for k, v in options.items() if v is not None and v != ''} + + if not options.get('source_url'): + raise ValueError('Source URL is required') + + if config.get('logo_url'): + if not config.get('logo_x'): + raise ValueError('Logo X position is required when logo URL is provided') + if not config.get('logo_y'): + raise ValueError('Logo Y position is required when logo URL is provided') + + animations_result = app.animations(params=options, wait=True) + + if not isinstance(animations_result, str): + animations_result = json.dumps(animations_result, ensure_ascii=False, indent=4) + + if not animations_result: + return self.create_text_message('D-ID animations request failed.') + + return self.create_text_message(animations_result) diff --git a/api/core/tools/provider/builtin/did/tools/animations.yaml b/api/core/tools/provider/builtin/did/tools/animations.yaml new file mode 100644 index 00000000000000..2a2036c7b2a88f --- /dev/null +++ b/api/core/tools/provider/builtin/did/tools/animations.yaml @@ -0,0 +1,86 @@ +identity: + name: animations + author: Matri Qi + label: + en_US: Animations +description: + human: + en_US: Animations enables to create videos matching head movements, expressions, emotions, and voice from a driver video and image. + llm: Animations enables to create videos matching head movements, expressions, emotions, and voice from a driver video and image. +parameters: + - name: source_url + type: string + required: true + label: + en_US: source url + human_description: + en_US: The URL of the source image to be animated by the driver video, or a selection from the list of provided studio actors. + llm_description: The URL of the source image to be animated by the driver video, or a selection from the list of provided studio actors. + form: llm + - name: driver_url + type: string + required: false + label: + en_US: driver url + human_description: + en_US: The URL of the driver video to drive the animation, or a provided driver name from D-ID. + form: form + - name: mute + type: boolean + required: false + label: + en_US: mute + human_description: + en_US: Mutes the driver sound in the animated video result, defaults to true + form: form + - name: stitch + type: boolean + required: false + label: + en_US: stitch + human_description: + en_US: If enabled, the driver video will be stitched with the animationing head video. + form: form + - name: logo_url + type: string + required: false + label: + en_US: logo url + human_description: + en_US: The URL of the logo image to be added to the animation video. + form: form + - name: logo_x + type: number + required: false + label: + en_US: logo position x + human_description: + en_US: The x position of the logo image in the animation video. It's required when logo url is provided. + form: form + - name: logo_y + type: number + required: false + label: + en_US: logo position y + human_description: + en_US: The y position of the logo image in the animation video. It's required when logo url is provided. + form: form + - name: result_format + type: string + default: mp4 + required: false + label: + en_US: result format + human_description: + en_US: The format of the result video. + form: form + options: + - value: mp4 + label: + en_US: mp4 + - value: gif + label: + en_US: gif + - value: mov + label: + en_US: mov diff --git a/api/core/tools/provider/builtin/did/tools/talks.py b/api/core/tools/provider/builtin/did/tools/talks.py new file mode 100644 index 00000000000000..06b2c4cb2f6049 --- /dev/null +++ b/api/core/tools/provider/builtin/did/tools/talks.py @@ -0,0 +1,65 @@ +import json +from typing import Any, Union + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.did.did_appx import DIDApp +from core.tools.tool.builtin_tool import BuiltinTool + + +class TalksTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + app = DIDApp(api_key=self.runtime.credentials['did_api_key'], base_url=self.runtime.credentials['base_url']) + + driver_expressions_str = tool_parameters.get('driver_expressions') + driver_expressions = json.loads(driver_expressions_str) if driver_expressions_str else None + + script = { + 'type': tool_parameters.get('script_type') or 'text', + 'input': tool_parameters.get('text_input'), + 'audio_url': tool_parameters.get('audio_url'), + 'reduce_noise': tool_parameters.get('audio_reduce_noise', False), + } + script = {k: v for k, v in script.items() if v is not None and v != ''} + config = { + 'stitch': tool_parameters.get('stitch', True), + 'sharpen': tool_parameters.get('sharpen'), + 'fluent': tool_parameters.get('fluent'), + 'result_format': tool_parameters.get('result_format') or 'mp4', + 'pad_audio': tool_parameters.get('pad_audio'), + 'driver_expressions': driver_expressions, + } + config = {k: v for k, v in config.items() if v is not None and v != ''} + + options = { + 'source_url': tool_parameters['source_url'], + 'driver_url': tool_parameters.get('driver_url'), + 'script': script, + 'config': config, + } + options = {k: v for k, v in options.items() if v is not None and v != ''} + + if not options.get('source_url'): + raise ValueError('Source URL is required') + + if script.get('type') == 'audio': + script.pop('input', None) + if not script.get('audio_url'): + raise ValueError('Audio URL is required for audio script type') + + if script.get('type') == 'text': + script.pop('audio_url', None) + script.pop('reduce_noise', None) + if not script.get('input'): + raise ValueError('Text input is required for text script type') + + talks_result = app.talks(params=options, wait=True) + + if not isinstance(talks_result, str): + talks_result = json.dumps(talks_result, ensure_ascii=False, indent=4) + + if not talks_result: + return self.create_text_message('D-ID talks request failed.') + + return self.create_text_message(talks_result) diff --git a/api/core/tools/provider/builtin/did/tools/talks.yaml b/api/core/tools/provider/builtin/did/tools/talks.yaml new file mode 100644 index 00000000000000..88d430512923e4 --- /dev/null +++ b/api/core/tools/provider/builtin/did/tools/talks.yaml @@ -0,0 +1,126 @@ +identity: + name: talks + author: Matri Qi + label: + en_US: Talks +description: + human: + en_US: Talks enables the creation of realistic talking head videos from text or audio inputs. + llm: Talks enables the creation of realistic talking head videos from text or audio inputs. +parameters: + - name: source_url + type: string + required: true + label: + en_US: source url + human_description: + en_US: The URL of the source image to be animated by the driver video, or a selection from the list of provided studio actors. + llm_description: The URL of the source image to be animated by the driver video, or a selection from the list of provided studio actors. + form: llm + - name: driver_url + type: string + required: false + label: + en_US: driver url + human_description: + en_US: The URL of the driver video to drive the talk, or a provided driver name from D-ID. + form: form + - name: script_type + type: string + required: false + label: + en_US: script type + human_description: + en_US: The type of the script. + form: form + options: + - value: text + label: + en_US: text + - value: audio + label: + en_US: audio + - name: text_input + type: string + required: false + label: + en_US: text input + human_description: + en_US: The text input to be spoken by the talking head. Required when script type is text. + form: form + - name: audio_url + type: string + required: false + label: + en_US: audio url + human_description: + en_US: The URL of the audio file to be spoken by the talking head. Required when script type is audio. + form: form + - name: audio_reduce_noise + type: boolean + required: false + label: + en_US: audio reduce noise + human_description: + en_US: If enabled, the audio will be processed to reduce noise before being spoken by the talking head. It only works when script type is audio. + form: form + - name: stitch + type: boolean + required: false + label: + en_US: stitch + human_description: + en_US: If enabled, the driver video will be stitched with the talking head video. + form: form + - name: sharpen + type: boolean + required: false + label: + en_US: sharpen + human_description: + en_US: If enabled, the talking head video will be sharpened. + form: form + - name: result_format + type: string + required: false + label: + en_US: result format + human_description: + en_US: The format of the result video. + form: form + options: + - value: mp4 + label: + en_US: mp4 + - value: gif + label: + en_US: gif + - value: mov + label: + en_US: mov + - name: fluent + type: boolean + required: false + label: + en_US: fluent + human_description: + en_US: Interpolate between the last & first frames of the driver video When used together with pad_audio can create a seamless transition between videos of the same driver + form: form + - name: pad_audio + type: number + required: false + label: + en_US: pad audio + human_description: + en_US: Pad the audio with silence at the end (given in seconds) Will increase the video duration & the credits it consumes + form: form + min: 1 + max: 60 + - name: driver_expressions + type: string + required: false + label: + en_US: driver expressions + human_description: + en_US: timed expressions for animation. It should be an JSON array style string. Take D-ID documentation(https://docs.d-id.com/reference/createtalk) for more information. + form: form diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.yaml b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.yaml index 1913eed1d1fa3f..21cbae6bd3e002 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.yaml +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.yaml @@ -25,9 +25,9 @@ parameters: type: select required: true options: - - value: gpt-3.5 + - value: gpt-4o-mini label: - en_US: GPT-3.5 + en_US: GPT-4o-mini - value: claude-3-haiku label: en_US: Claude 3 diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py index 442f29f33dfcef..dfaeb734d8f667 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py @@ -21,23 +21,16 @@ class DuckDuckGoSearchTool(BuiltinTool): """ Tool for performing a search using DuckDuckGo search engine. """ - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - query = tool_parameters.get('query', '') - result_type = tool_parameters.get('result_type', 'text') - max_results = tool_parameters.get('max_results', 10) + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: + query = tool_parameters.get('query') + max_results = tool_parameters.get('max_results', 5) require_summary = tool_parameters.get('require_summary', False) response = DDGS().text(query, max_results=max_results) - - if result_type == 'link': - results = [f"[{res.get('title')}]({res.get('href')})" for res in response] - results = "\n".join(results) - return self.create_link_message(link=results) - results = [res.get("body") for res in response] - results = "\n".join(results) if require_summary: + results = "\n".join([res.get("body") for res in response]) results = self.summary_results(user_id=user_id, content=results, query=query) - return self.create_text_message(text=results) + return self.create_text_message(text=results) + return [self.create_json_message(res) for res in response] def summary_results(self, user_id: str, content: str, query: str) -> str: prompt = SUMMARY_PROMPT.format(query=query, content=content) diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.yaml b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.yaml index c427a37fe6a893..333c0cb093dbd2 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.yaml +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.yaml @@ -28,29 +28,6 @@ parameters: label: en_US: Max results zh_Hans: 最大结果数量 - human_description: - en_US: The max results. - zh_Hans: 最大结果数量 - form: form - - name: result_type - type: select - required: true - options: - - value: text - label: - en_US: text - zh_Hans: 文本 - - value: link - label: - en_US: link - zh_Hans: 链接 - default: text - label: - en_US: Result type - zh_Hans: 结果类型 - human_description: - en_US: used for selecting the result type, text or link - zh_Hans: 用于选择结果类型,使用文本还是链接进行展示 form: form - name: require_summary type: boolean diff --git a/api/core/tools/provider/builtin/feishu_document/_assets/icon.svg b/api/core/tools/provider/builtin/feishu_document/_assets/icon.svg new file mode 100644 index 00000000000000..5a0a6416b3db32 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/_assets/icon.svg @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/api/core/tools/provider/builtin/feishu_document/feishu_document.py b/api/core/tools/provider/builtin/feishu_document/feishu_document.py new file mode 100644 index 00000000000000..c4f8f26e2c9a9b --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/feishu_document.py @@ -0,0 +1,15 @@ +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class FeishuDocumentProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + app_id = credentials.get('app_id') + app_secret = credentials.get('app_secret') + if not app_id or not app_secret: + raise ToolProviderCredentialValidationError("app_id and app_secret is required") + try: + assert FeishuRequest(app_id, app_secret).tenant_access_token is not None + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/feishu_document/feishu_document.yaml b/api/core/tools/provider/builtin/feishu_document/feishu_document.yaml new file mode 100644 index 00000000000000..8eaa6b27049c6b --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/feishu_document.yaml @@ -0,0 +1,34 @@ +identity: + author: Doug Lea + name: feishu_document + label: + en_US: Lark Cloud Document + zh_Hans: 飞书云文档 + description: + en_US: Lark Cloud Document + zh_Hans: 飞书云文档 + icon: icon.svg + tags: + - social + - productivity +credentials_for_provider: + app_id: + type: text-input + required: true + label: + en_US: APP ID + placeholder: + en_US: Please input your feishu app id + zh_Hans: 请输入你的飞书 app id + help: + en_US: Get your app_id and app_secret from Feishu + zh_Hans: 从飞书获取您的 app_id 和 app_secret + url: https://open.feishu.cn + app_secret: + type: secret-input + required: true + label: + en_US: APP Secret + placeholder: + en_US: Please input your app secret + zh_Hans: 请输入你的飞书 app secret diff --git a/api/core/tools/provider/builtin/feishu_document/tools/create_document.py b/api/core/tools/provider/builtin/feishu_document/tools/create_document.py new file mode 100644 index 00000000000000..0ff82e621b5722 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/create_document.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class CreateDocumentTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get('app_id') + app_secret = self.runtime.credentials.get('app_secret') + client = FeishuRequest(app_id, app_secret) + + title = tool_parameters.get('title') + content = tool_parameters.get('content') + folder_token = tool_parameters.get('folder_token') + + res = client.create_document(title, content, folder_token) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/create_document.yaml b/api/core/tools/provider/builtin/feishu_document/tools/create_document.yaml new file mode 100644 index 00000000000000..ddf2729f0e4b5c --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/create_document.yaml @@ -0,0 +1,47 @@ +identity: + name: create_document + author: Doug Lea + label: + en_US: Create Lark document + zh_Hans: 创建飞书文档 +description: + human: + en_US: Create Lark document + zh_Hans: 创建飞书文档,支持创建空文档和带内容的文档,支持 markdown 语法创建。 + llm: A tool for creating Feishu documents. +parameters: + - name: title + type: string + required: false + label: + en_US: Document title + zh_Hans: 文档标题 + human_description: + en_US: Document title, only supports plain text content. + zh_Hans: 文档标题,只支持纯文本内容。 + llm_description: 文档标题,只支持纯文本内容,可以为空。 + form: llm + + - name: content + type: string + required: false + label: + en_US: Document content + zh_Hans: 文档内容 + human_description: + en_US: Document content, supports markdown syntax, can be empty. + zh_Hans: 文档内容,支持 markdown 语法,可以为空。 + llm_description: 文档内容,支持 markdown 语法,可以为空。 + form: llm + + - name: folder_token + type: string + required: false + label: + en_US: folder_token + zh_Hans: 文档所在文件夹的 Token + human_description: + en_US: The token of the folder where the document is located. If it is not passed or is empty, it means the root directory. + zh_Hans: 文档所在文件夹的 Token,不传或传空表示根目录。 + llm_description: 文档所在文件夹的 Token,不传或传空表示根目录。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py b/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py new file mode 100644 index 00000000000000..16ef90908bda57 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py @@ -0,0 +1,17 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class GetDocumentRawContentTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get('app_id') + app_secret = self.runtime.credentials.get('app_secret') + client = FeishuRequest(app_id, app_secret) + + document_id = tool_parameters.get('document_id') + + res = client.get_document_raw_content(document_id) + return self.create_json_message(res) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.yaml b/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.yaml new file mode 100644 index 00000000000000..e5b0937e03de86 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.yaml @@ -0,0 +1,23 @@ +identity: + name: get_document_raw_content + author: Doug Lea + label: + en_US: Get Document Raw Content + zh_Hans: 获取文档纯文本内容 +description: + human: + en_US: Get document raw content + zh_Hans: 获取文档纯文本内容 + llm: A tool for getting the plain text content of Feishu documents +parameters: + - name: document_id + type: string + required: true + label: + en_US: document_id + zh_Hans: 飞书文档的唯一标识 + human_description: + en_US: Unique ID of Feishu document document_id + zh_Hans: 飞书文档的唯一标识 document_id + llm_description: 飞书文档的唯一标识 document_id + form: llm diff --git a/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py b/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py new file mode 100644 index 00000000000000..97d17bdb04767a --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class ListDocumentBlockTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get('app_id') + app_secret = self.runtime.credentials.get('app_secret') + client = FeishuRequest(app_id, app_secret) + + document_id = tool_parameters.get('document_id') + page_size = tool_parameters.get('page_size', 500) + page_token = tool_parameters.get('page_token', '') + + res = client.list_document_block(document_id, page_token, page_size) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.yaml b/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.yaml new file mode 100644 index 00000000000000..d51e5a837cc287 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.yaml @@ -0,0 +1,48 @@ +identity: + name: list_document_block + author: Doug Lea + label: + en_US: List Document Block + zh_Hans: 获取飞书文档所有块 +description: + human: + en_US: List document block + zh_Hans: 获取飞书文档所有块的富文本内容并分页返回。 + llm: A tool to get all blocks of Feishu documents +parameters: + - name: document_id + type: string + required: true + label: + en_US: document_id + zh_Hans: 飞书文档的唯一标识 + human_description: + en_US: Unique ID of Feishu document document_id + zh_Hans: 飞书文档的唯一标识 document_id + llm_description: 飞书文档的唯一标识 document_id + form: llm + + - name: page_size + type: number + required: false + default: 500 + label: + en_US: page_size + zh_Hans: 分页大小 + human_description: + en_US: Paging size, the default and maximum value is 500. + zh_Hans: 分页大小, 默认值和最大值为 500。 + llm_description: 分页大小, 表示一次请求最多返回多少条数据,默认值和最大值为 500。 + form: llm + + - name: page_token + type: string + required: false + label: + en_US: page_token + zh_Hans: 分页标记 + human_description: + en_US: Pagination tag, used to paginate query results so that more items can be obtained in the next traversal. + zh_Hans: 分页标记,用于分页查询结果,以便下次遍历时获取更多项。 + llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_document/tools/write_document.py b/api/core/tools/provider/builtin/feishu_document/tools/write_document.py new file mode 100644 index 00000000000000..914a44dce62483 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/write_document.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class CreateDocumentTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get('app_id') + app_secret = self.runtime.credentials.get('app_secret') + client = FeishuRequest(app_id, app_secret) + + document_id = tool_parameters.get('document_id') + content = tool_parameters.get('content') + position = tool_parameters.get('position') + + res = client.write_document(document_id, content, position) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/write_document.yaml b/api/core/tools/provider/builtin/feishu_document/tools/write_document.yaml new file mode 100644 index 00000000000000..8ee219d4a7ab84 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/write_document.yaml @@ -0,0 +1,56 @@ +identity: + name: write_document + author: Doug Lea + label: + en_US: Write Document + zh_Hans: 在飞书文档中新增内容 +description: + human: + en_US: Adding new content to Lark documents + zh_Hans: 在飞书文档中新增内容 + llm: A tool for adding new content to Lark documents. +parameters: + - name: document_id + type: string + required: true + label: + en_US: document_id + zh_Hans: 飞书文档的唯一标识 + human_description: + en_US: Unique ID of Feishu document document_id + zh_Hans: 飞书文档的唯一标识 document_id + llm_description: 飞书文档的唯一标识 document_id + form: llm + + - name: content + type: string + required: true + label: + en_US: document content + zh_Hans: 文档内容 + human_description: + en_US: Document content, supports markdown syntax, can be empty. + zh_Hans: 文档内容,支持 markdown 语法,可以为空。 + llm_description: + form: llm + + - name: position + type: select + required: true + default: start + label: + en_US: Choose where to add content + zh_Hans: 选择添加内容的位置 + human_description: + en_US: Please fill in start or end to add content at the beginning or end of the document respectively. + zh_Hans: 请填入 start 或 end, 分别表示在文档开头(start)或结尾(end)添加内容。 + form: llm + options: + - value: start + label: + en_US: start + zh_Hans: 在文档开头添加内容 + - value: end + label: + en_US: end + zh_Hans: 在文档结尾添加内容 diff --git a/api/core/tools/provider/builtin/feishu_message/_assets/icon.svg b/api/core/tools/provider/builtin/feishu_message/_assets/icon.svg new file mode 100644 index 00000000000000..222a1571f9bbbb --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/_assets/icon.svg @@ -0,0 +1,19 @@ + + + + diff --git a/api/core/tools/provider/builtin/feishu_message/feishu_message.py b/api/core/tools/provider/builtin/feishu_message/feishu_message.py new file mode 100644 index 00000000000000..6d7fed330c545e --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/feishu_message.py @@ -0,0 +1,15 @@ +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class FeishuMessageProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + app_id = credentials.get('app_id') + app_secret = credentials.get('app_secret') + if not app_id or not app_secret: + raise ToolProviderCredentialValidationError("app_id and app_secret is required") + try: + assert FeishuRequest(app_id, app_secret).tenant_access_token is not None + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/feishu_message/feishu_message.yaml b/api/core/tools/provider/builtin/feishu_message/feishu_message.yaml new file mode 100644 index 00000000000000..1bd8953dddcb24 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/feishu_message.yaml @@ -0,0 +1,34 @@ +identity: + author: Doug Lea + name: feishu_message + label: + en_US: Lark Message + zh_Hans: 飞书消息 + description: + en_US: Lark Message + zh_Hans: 飞书消息 + icon: icon.svg + tags: + - social + - productivity +credentials_for_provider: + app_id: + type: text-input + required: true + label: + en_US: APP ID + placeholder: + en_US: Please input your feishu app id + zh_Hans: 请输入你的飞书 app id + help: + en_US: Get your app_id and app_secret from Feishu + zh_Hans: 从飞书获取您的 app_id 和 app_secret + url: https://open.feishu.cn + app_secret: + type: secret-input + required: true + label: + en_US: APP Secret + placeholder: + en_US: Please input your app secret + zh_Hans: 请输入你的飞书 app secret diff --git a/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py b/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py new file mode 100644 index 00000000000000..74f6866ba3c3a0 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py @@ -0,0 +1,20 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class SendBotMessageTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get('app_id') + app_secret = self.runtime.credentials.get('app_secret') + client = FeishuRequest(app_id, app_secret) + + receive_id_type = tool_parameters.get('receive_id_type') + receive_id = tool_parameters.get('receive_id') + msg_type = tool_parameters.get('msg_type') + content = tool_parameters.get('content') + + res = client.send_bot_message(receive_id_type, receive_id, msg_type, content) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.yaml b/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.yaml new file mode 100644 index 00000000000000..6e398b18ab3aee --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.yaml @@ -0,0 +1,91 @@ +identity: + name: send_bot_message + author: Doug Lea + label: + en_US: Send Bot Message + zh_Hans: 发送飞书应用消息 +description: + human: + en_US: Send bot message + zh_Hans: 发送飞书应用消息 + llm: A tool for sending Feishu application messages. +parameters: + - name: receive_id_type + type: select + required: true + options: + - value: open_id + label: + en_US: open id + zh_Hans: open id + - value: union_id + label: + en_US: union id + zh_Hans: union id + - value: user_id + label: + en_US: user id + zh_Hans: user id + - value: email + label: + en_US: email + zh_Hans: email + - value: chat_id + label: + en_US: chat id + zh_Hans: chat id + label: + en_US: User ID Type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID Type + zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id、email、chat_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id、email、chat_id。 + form: llm + + - name: receive_id + type: string + required: true + label: + en_US: Receive Id + zh_Hans: 消息接收者的 ID + human_description: + en_US: The ID of the message receiver. The ID type should correspond to the query parameter receive_id_type. + zh_Hans: 消息接收者的 ID,ID 类型应与查询参数 receive_id_type 对应。 + llm_description: 消息接收者的 ID,ID 类型应与查询参数 receive_id_type 对应。 + form: llm + + - name: msg_type + type: string + required: true + options: + - value: text + label: + en_US: text + zh_Hans: 文本 + - value: interactive + label: + en_US: message card + zh_Hans: 消息卡片 + label: + en_US: Message type + zh_Hans: 消息类型 + human_description: + en_US: Message type, optional values are, text (text), interactive (message card). + zh_Hans: 消息类型,可选值有:text(文本)、interactive(消息卡片)。 + llm_description: 消息类型,可选值有:text(文本)、interactive(消息卡片)。 + form: llm + + - name: content + type: string + required: true + label: + en_US: Message content + zh_Hans: 消息内容 + human_description: + en_US: Message content + zh_Hans: | + 消息内容,JSON 结构序列化后的字符串。不同 msg_type 对应不同内容, + 具体格式说明参考:https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json + llm_description: 消息内容,JSON 结构序列化后的字符串。不同 msg_type 对应不同内容。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py b/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py new file mode 100644 index 00000000000000..7159f59ffa5c67 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class SendWebhookMessageTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) ->ToolInvokeMessage: + app_id = self.runtime.credentials.get('app_id') + app_secret = self.runtime.credentials.get('app_secret') + client = FeishuRequest(app_id, app_secret) + + webhook = tool_parameters.get('webhook') + msg_type = tool_parameters.get('msg_type') + content = tool_parameters.get('content') + + res = client.send_webhook_message(webhook, msg_type, content) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.yaml b/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.yaml new file mode 100644 index 00000000000000..8b39ce4874d506 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.yaml @@ -0,0 +1,58 @@ +identity: + name: send_webhook_message + author: Doug Lea + label: + en_US: Send Webhook Message + zh_Hans: 使用自定义机器人发送飞书消息 +description: + human: + en_US: Send webhook message + zh_Hans: 使用自定义机器人发送飞书消息 + llm: A tool for sending Lark messages using a custom robot. +parameters: + - name: webhook + type: string + required: true + label: + en_US: webhook + zh_Hans: webhook 的地址 + human_description: + en_US: The address of the webhook + zh_Hans: webhook 的地址 + llm_description: webhook 的地址 + form: llm + + - name: msg_type + type: string + required: true + options: + - value: text + label: + en_US: text + zh_Hans: 文本 + - value: interactive + label: + en_US: message card + zh_Hans: 消息卡片 + label: + en_US: Message type + zh_Hans: 消息类型 + human_description: + en_US: Message type, optional values are, text (text), interactive (message card). + zh_Hans: 消息类型,可选值有:text(文本)、interactive(消息卡片)。 + llm_description: 消息类型,可选值有:text(文本)、interactive(消息卡片)。 + form: llm + + - name: content + type: string + required: true + label: + en_US: Message content + zh_Hans: 消息内容 + human_description: + en_US: Message content + zh_Hans: | + 消息内容,JSON 结构序列化后的字符串。不同 msg_type 对应不同内容, + 具体格式说明参考:https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json + llm_description: 消息内容,JSON 结构序列化后的字符串。不同 msg_type 对应不同内容。 + form: llm diff --git a/api/core/tools/provider/builtin/gitlab/_assets/gitlab.svg b/api/core/tools/provider/builtin/gitlab/_assets/gitlab.svg new file mode 100644 index 00000000000000..07734077d5d300 --- /dev/null +++ b/api/core/tools/provider/builtin/gitlab/_assets/gitlab.svg @@ -0,0 +1,2 @@ + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/gitlab/gitlab.py b/api/core/tools/provider/builtin/gitlab/gitlab.py new file mode 100644 index 00000000000000..0c13ec662a4f98 --- /dev/null +++ b/api/core/tools/provider/builtin/gitlab/gitlab.py @@ -0,0 +1,34 @@ +from typing import Any + +import requests + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class GitlabProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + try: + if 'access_tokens' not in credentials or not credentials.get('access_tokens'): + raise ToolProviderCredentialValidationError("Gitlab Access Tokens is required.") + + if 'site_url' not in credentials or not credentials.get('site_url'): + site_url = 'https://gitlab.com' + else: + site_url = credentials.get('site_url') + + try: + headers = { + "Content-Type": "application/vnd.text+json", + "Authorization": f"Bearer {credentials.get('access_tokens')}", + } + + response = requests.get( + url= f"{site_url}/api/v4/user", + headers=headers) + if response.status_code != 200: + raise ToolProviderCredentialValidationError((response.json()).get('message')) + except Exception as e: + raise ToolProviderCredentialValidationError("Gitlab Access Tokens is invalid. {}".format(e)) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/gitlab/gitlab.yaml b/api/core/tools/provider/builtin/gitlab/gitlab.yaml new file mode 100644 index 00000000000000..22d7ebf73ac2aa --- /dev/null +++ b/api/core/tools/provider/builtin/gitlab/gitlab.yaml @@ -0,0 +1,38 @@ +identity: + author: Leo.Wang + name: gitlab + label: + en_US: GitLab + zh_Hans: GitLab + description: + en_US: GitLab plugin, API v4 only. + zh_Hans: 用于获取GitLab内容的插件,目前仅支持 API v4。 + icon: gitlab.svg +credentials_for_provider: + access_tokens: + type: secret-input + required: true + label: + en_US: GitLab access token + zh_Hans: GitLab access token + placeholder: + en_US: Please input your GitLab access token + zh_Hans: 请输入你的 GitLab access token + help: + en_US: Get your GitLab access token from GitLab + zh_Hans: 从 GitLab 获取您的 access token + url: https://docs.gitlab.com/16.9/ee/api/oauth2.html + site_url: + type: text-input + required: false + default: 'https://gitlab.com' + label: + en_US: GitLab site url + zh_Hans: GitLab site url + placeholder: + en_US: Please input your GitLab site url + zh_Hans: 请输入你的 GitLab site url + help: + en_US: Find your GitLab url + zh_Hans: 找到你的 GitLab url + url: https://gitlab.com/help diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py new file mode 100644 index 00000000000000..880d722bda8e2f --- /dev/null +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py @@ -0,0 +1,111 @@ +import json +from datetime import datetime, timedelta +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class GitlabCommitsTool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + project = tool_parameters.get('project', '') + employee = tool_parameters.get('employee', '') + start_time = tool_parameters.get('start_time', '') + end_time = tool_parameters.get('end_time', '') + change_type = tool_parameters.get('change_type', 'all') + + if not project: + return self.create_text_message('Project is required') + + if not start_time: + start_time = (datetime.utcnow() - timedelta(days=1)).isoformat() + if not end_time: + end_time = datetime.utcnow().isoformat() + + access_token = self.runtime.credentials.get('access_tokens') + site_url = self.runtime.credentials.get('site_url') + + if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'): + return self.create_text_message("Gitlab API Access Tokens is required.") + if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'): + site_url = 'https://gitlab.com' + + # Get commit content + result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time, change_type) + + return [self.create_json_message(item) for item in result] + + def fetch(self,user_id: str, site_url: str, access_token: str, project: str, employee: str = None, start_time: str = '', end_time: str = '', change_type: str = '') -> list[dict[str, Any]]: + domain = site_url + headers = {"PRIVATE-TOKEN": access_token} + results = [] + + try: + # Get all of projects + url = f"{domain}/api/v4/projects" + response = requests.get(url, headers=headers) + response.raise_for_status() + projects = response.json() + + filtered_projects = [p for p in projects if project == "*" or p['name'] == project] + + for project in filtered_projects: + project_id = project['id'] + project_name = project['name'] + print(f"Project: {project_name}") + + # Get all of proejct commits + commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits" + params = { + 'since': start_time, + 'until': end_time + } + if employee: + params['author'] = employee + + commits_response = requests.get(commits_url, headers=headers, params=params) + commits_response.raise_for_status() + commits = commits_response.json() + + for commit in commits: + commit_sha = commit['id'] + author_name = commit['author_name'] + + diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff" + diff_response = requests.get(diff_url, headers=headers) + diff_response.raise_for_status() + diffs = diff_response.json() + + for diff in diffs: + # Caculate code lines of changed + added_lines = diff['diff'].count('\n+') + removed_lines = diff['diff'].count('\n-') + total_changes = added_lines + removed_lines + + if change_type == "new": + if added_lines > 1: + final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if line.startswith('+') and not line.startswith('+++')]) + results.append({ + "commit_sha": commit_sha, + "author_name": author_name, + "diff": final_code + }) + else: + if total_changes > 1: + final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if (line.startswith('+') or line.startswith('-')) and not line.startswith('+++') and not line.startswith('---')]) + final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code + results.append({ + "commit_sha": commit_sha, + "author_name": author_name, + "diff": final_code_escaped + }) + except requests.RequestException as e: + print(f"Error fetching data from GitLab: {e}") + + return results \ No newline at end of file diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml new file mode 100644 index 00000000000000..dd4e31d6633d37 --- /dev/null +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml @@ -0,0 +1,77 @@ +identity: + name: gitlab_commits + author: Leo.Wang + label: + en_US: GitLab Commits + zh_Hans: GitLab 提交内容查询 +description: + human: + en_US: A tool for query GitLab commits, Input should be a exists username or projec. + zh_Hans: 一个用于查询 GitLab 代码提交内容的工具,输入的内容应该是一个已存在的用户名或者项目名。 + llm: A tool for query GitLab commits, Input should be a exists username or project. +parameters: + - name: username + type: string + required: false + label: + en_US: username + zh_Hans: 员工用户名 + human_description: + en_US: username + zh_Hans: 员工用户名 + llm_description: User name for GitLab + form: llm + - name: project + type: string + required: true + label: + en_US: project + zh_Hans: 项目名 + human_description: + en_US: project + zh_Hans: 项目名 + llm_description: project for GitLab + form: llm + - name: start_time + type: string + required: false + label: + en_US: start_time + zh_Hans: 开始时间 + human_description: + en_US: start_time + zh_Hans: 开始时间 + llm_description: Start time for GitLab + form: llm + - name: end_time + type: string + required: false + label: + en_US: end_time + zh_Hans: 结束时间 + human_description: + en_US: end_time + zh_Hans: 结束时间 + llm_description: End time for GitLab + form: llm + - name: change_type + type: select + required: false + options: + - value: all + label: + en_US: all + zh_Hans: 所有 + - value: new + label: + en_US: new + zh_Hans: 新增 + default: all + label: + en_US: change_type + zh_Hans: 变更类型 + human_description: + en_US: change_type + zh_Hans: 变更类型 + llm_description: Content change type for GitLab + form: llm diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py new file mode 100644 index 00000000000000..7fa1d0d1124bab --- /dev/null +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py @@ -0,0 +1,95 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class GitlabFilesTool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + project = tool_parameters.get('project', '') + branch = tool_parameters.get('branch', '') + path = tool_parameters.get('path', '') + + + if not project: + return self.create_text_message('Project is required') + if not branch: + return self.create_text_message('Branch is required') + + if not path: + return self.create_text_message('Path is required') + + access_token = self.runtime.credentials.get('access_tokens') + site_url = self.runtime.credentials.get('site_url') + + if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'): + return self.create_text_message("Gitlab API Access Tokens is required.") + if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'): + site_url = 'https://gitlab.com' + + # Get project ID from project name + project_id = self.get_project_id(site_url, access_token, project) + if not project_id: + return self.create_text_message(f"Project '{project}' not found.") + + # Get commit content + result = self.fetch(user_id, project_id, site_url, access_token, branch, path) + + return [self.create_json_message(item) for item in result] + + def extract_project_name_and_path(self, path: str) -> tuple[str, str]: + parts = path.split('/', 1) + if len(parts) < 2: + return None, None + return parts[0], parts[1] + + def get_project_id(self, site_url: str, access_token: str, project_name: str) -> Union[str, None]: + headers = {"PRIVATE-TOKEN": access_token} + try: + url = f"{site_url}/api/v4/projects?search={project_name}" + response = requests.get(url, headers=headers) + response.raise_for_status() + projects = response.json() + for project in projects: + if project['name'] == project_name: + return project['id'] + except requests.RequestException as e: + print(f"Error fetching project ID from GitLab: {e}") + return None + + def fetch(self,user_id: str, project_id: str, site_url: str, access_token: str, branch: str, path: str = None) -> list[dict[str, Any]]: + domain = site_url + headers = {"PRIVATE-TOKEN": access_token} + results = [] + + try: + # List files and directories in the given path + url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}" + response = requests.get(url, headers=headers) + response.raise_for_status() + items = response.json() + + for item in items: + item_path = item['path'] + if item['type'] == 'tree': # It's a directory + results.extend(self.fetch(project_id, site_url, access_token, branch, item_path)) + else: # It's a file + file_url = f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}" + file_response = requests.get(file_url, headers=headers) + file_response.raise_for_status() + file_content = file_response.text + results.append({ + "path": item_path, + "branch": branch, + "content": file_content + }) + except requests.RequestException as e: + print(f"Error fetching data from GitLab: {e}") + + return results \ No newline at end of file diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml new file mode 100644 index 00000000000000..d99b6254c1b99c --- /dev/null +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml @@ -0,0 +1,45 @@ +identity: + name: gitlab_files + author: Leo.Wang + label: + en_US: GitLab Files + zh_Hans: GitLab 文件获取 +description: + human: + en_US: A tool for query GitLab files, Input should be branch and a exists file or directory path. + zh_Hans: 一个用于查询 GitLab 文件的工具,输入的内容应该是分支和一个已存在文件或者文件夹路径。 + llm: A tool for query GitLab files, Input should be a exists file or directory path. +parameters: + - name: project + type: string + required: true + label: + en_US: project + zh_Hans: 项目 + human_description: + en_US: project + zh_Hans: 项目 + llm_description: Project for GitLab + form: llm + - name: branch + type: string + required: true + label: + en_US: branch + zh_Hans: 分支 + human_description: + en_US: branch + zh_Hans: 分支 + llm_description: Branch for GitLab + form: llm + - name: path + type: string + required: true + label: + en_US: path + zh_Hans: 文件路径 + human_description: + en_US: path + zh_Hans: 文件路径 + llm_description: File path for GitLab + form: llm diff --git a/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py b/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py new file mode 100644 index 00000000000000..0d018e3ca27af5 --- /dev/null +++ b/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py @@ -0,0 +1,43 @@ +from typing import Any + +from core.helper import ssrf_proxy +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class JinaTokenizerTool(BuiltinTool): + _jina_tokenizer_endpoint = 'https://tokenize.jina.ai/' + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> ToolInvokeMessage: + content = tool_parameters['content'] + body = { + "content": content + } + + headers = { + 'Content-Type': 'application/json' + } + + if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'): + headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key') + + if tool_parameters.get('return_chunks', False): + body['return_chunks'] = True + + if tool_parameters.get('return_tokens', False): + body['return_tokens'] = True + + if tokenizer := tool_parameters.get('tokenizer'): + body['tokenizer'] = tokenizer + + response = ssrf_proxy.post( + self._jina_tokenizer_endpoint, + headers=headers, + json=body, + ) + + return self.create_json_message(response.json()) diff --git a/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.yaml b/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.yaml new file mode 100644 index 00000000000000..62a5c7e7bacd75 --- /dev/null +++ b/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.yaml @@ -0,0 +1,70 @@ +identity: + name: jina_tokenizer + author: hjlarry + label: + en_US: JinaTokenizer +description: + human: + en_US: Free API to tokenize text and segment long text into chunks. + zh_Hans: 免费的API可以将文本tokenize,也可以将长文本分割成多个部分。 + llm: Free API to tokenize text and segment long text into chunks. +parameters: + - name: content + type: string + required: true + label: + en_US: Content + zh_Hans: 内容 + llm_description: the content which need to tokenize or segment + form: llm + - name: return_tokens + type: boolean + required: false + label: + en_US: Return the tokens + zh_Hans: 是否返回tokens + human_description: + en_US: Return the tokens and their corresponding ids in the response. + zh_Hans: 返回tokens及其对应的ids。 + form: form + - name: return_chunks + type: boolean + label: + en_US: Return the chunks + zh_Hans: 是否分块 + human_description: + en_US: Chunking the input into semantically meaningful segments while handling a wide variety of text types and edge cases based on common structural cues. + zh_Hans: 将输入分块为具有语义意义的片段,同时根据常见的结构线索处理各种文本类型和边缘情况。 + form: form + - name: tokenizer + type: select + options: + - value: cl100k_base + label: + en_US: cl100k_base + - value: o200k_base + label: + en_US: o200k_base + - value: p50k_base + label: + en_US: p50k_base + - value: r50k_base + label: + en_US: r50k_base + - value: p50k_edit + label: + en_US: p50k_edit + - value: gpt2 + label: + en_US: gpt2 + label: + en_US: Tokenizer + human_description: + en_US: | + · cl100k_base --- gpt-4, gpt-3.5-turbo, gpt-3.5 + · o200k_base --- gpt-4o, gpt-4o-mini + · p50k_base --- text-davinci-003, text-davinci-002 + · r50k_base --- text-davinci-001, text-curie-001 + · p50k_edit --- text-davinci-edit-001, code-davinci-edit-001 + · gpt2 --- gpt-2 + form: form diff --git a/api/core/tools/provider/builtin/json_process/tools/insert.py b/api/core/tools/provider/builtin/json_process/tools/insert.py index 27e34f1ff3fda9..48d1bdcab48885 100644 --- a/api/core/tools/provider/builtin/json_process/tools/insert.py +++ b/api/core/tools/provider/builtin/json_process/tools/insert.py @@ -36,21 +36,26 @@ def _invoke(self, # get create path create_path = tool_parameters.get('create_path', False) + # get value decode. + # if true, it will be decoded to an dict + value_decode = tool_parameters.get('value_decode', False) + ensure_ascii = tool_parameters.get('ensure_ascii', True) try: - result = self._insert(content, query, new_value, ensure_ascii, index, create_path) + result = self._insert(content, query, new_value, ensure_ascii, value_decode, index, create_path) return self.create_text_message(str(result)) except Exception: return self.create_text_message('Failed to insert JSON content') - def _insert(self, origin_json, query, new_value, ensure_ascii: bool, index=None, create_path=False): + def _insert(self, origin_json, query, new_value, ensure_ascii: bool, value_decode: bool, index=None, create_path=False): try: input_data = json.loads(origin_json) expr = parse(query) - try: - new_value = json.loads(new_value) - except json.JSONDecodeError: - new_value = new_value + if value_decode is True: + try: + new_value = json.loads(new_value) + except json.JSONDecodeError: + return "Cannot decode new value to json object" matches = expr.find(input_data) diff --git a/api/core/tools/provider/builtin/json_process/tools/insert.yaml b/api/core/tools/provider/builtin/json_process/tools/insert.yaml index 63e7816455cb68..21b51312dab6b3 100644 --- a/api/core/tools/provider/builtin/json_process/tools/insert.yaml +++ b/api/core/tools/provider/builtin/json_process/tools/insert.yaml @@ -47,10 +47,22 @@ parameters: pt_BR: New Value human_description: en_US: New Value - zh_Hans: 新值 + zh_Hans: 插入的新值 pt_BR: New Value llm_description: New Value to insert form: llm + - name: value_decode + type: boolean + default: false + label: + en_US: Decode Value + zh_Hans: 解码值 + pt_BR: Decode Value + human_description: + en_US: Whether to decode the value to a JSON object + zh_Hans: 是否将值解码为 JSON 对象 + pt_BR: Whether to decode the value to a JSON object + form: form - name: create_path type: select required: true diff --git a/api/core/tools/provider/builtin/json_process/tools/replace.py b/api/core/tools/provider/builtin/json_process/tools/replace.py index be696bce0e0a2c..b19198aa938942 100644 --- a/api/core/tools/provider/builtin/json_process/tools/replace.py +++ b/api/core/tools/provider/builtin/json_process/tools/replace.py @@ -35,6 +35,10 @@ def _invoke(self, if not replace_model: return self.create_text_message('Invalid parameter replace_model') + # get value decode. + # if true, it will be decoded to an dict + value_decode = tool_parameters.get('value_decode', False) + ensure_ascii = tool_parameters.get('ensure_ascii', True) try: if replace_model == 'pattern': @@ -42,17 +46,17 @@ def _invoke(self, replace_pattern = tool_parameters.get('replace_pattern', '') if not replace_pattern: return self.create_text_message('Invalid parameter replace_pattern') - result = self._replace_pattern(content, query, replace_pattern, replace_value, ensure_ascii) + result = self._replace_pattern(content, query, replace_pattern, replace_value, ensure_ascii, value_decode) elif replace_model == 'key': result = self._replace_key(content, query, replace_value, ensure_ascii) elif replace_model == 'value': - result = self._replace_value(content, query, replace_value, ensure_ascii) + result = self._replace_value(content, query, replace_value, ensure_ascii, value_decode) return self.create_text_message(str(result)) except Exception: return self.create_text_message('Failed to replace JSON content') # Replace pattern - def _replace_pattern(self, content: str, query: str, replace_pattern: str, replace_value: str, ensure_ascii: bool) -> str: + def _replace_pattern(self, content: str, query: str, replace_pattern: str, replace_value: str, ensure_ascii: bool, value_decode: bool) -> str: try: input_data = json.loads(content) expr = parse(query) @@ -61,6 +65,12 @@ def _replace_pattern(self, content: str, query: str, replace_pattern: str, repla for match in matches: new_value = match.value.replace(replace_pattern, replace_value) + if value_decode is True: + try: + new_value = json.loads(new_value) + except json.JSONDecodeError: + return "Cannot decode replace value to json object" + match.full_path.update(input_data, new_value) return json.dumps(input_data, ensure_ascii=ensure_ascii) @@ -92,10 +102,15 @@ def _replace_key(self, content: str, query: str, replace_value: str, ensure_asci return str(e) # Replace value - def _replace_value(self, content: str, query: str, replace_value: str, ensure_ascii: bool) -> str: + def _replace_value(self, content: str, query: str, replace_value: str, ensure_ascii: bool, value_decode: bool) -> str: try: input_data = json.loads(content) expr = parse(query) + if value_decode is True: + try: + replace_value = json.loads(replace_value) + except json.JSONDecodeError: + return "Cannot decode replace value to json object" matches = expr.find(input_data) diff --git a/api/core/tools/provider/builtin/json_process/tools/replace.yaml b/api/core/tools/provider/builtin/json_process/tools/replace.yaml index cf4b1dc63f45b8..ae238b1fbcd05e 100644 --- a/api/core/tools/provider/builtin/json_process/tools/replace.yaml +++ b/api/core/tools/provider/builtin/json_process/tools/replace.yaml @@ -60,10 +60,22 @@ parameters: pt_BR: Replace Value human_description: en_US: New Value - zh_Hans: New Value + zh_Hans: 新值 pt_BR: New Value llm_description: New Value to replace form: llm + - name: value_decode + type: boolean + default: false + label: + en_US: Decode Value + zh_Hans: 解码值 + pt_BR: Decode Value + human_description: + en_US: Whether to decode the value to a JSON object (Does not apply to replace key) + zh_Hans: 是否将值解码为 JSON 对象 (不适用于键替换) + pt_BR: Whether to decode the value to a JSON object (Does not apply to replace key) + form: form - name: replace_model type: select required: true diff --git a/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py new file mode 100644 index 00000000000000..b753be47917559 --- /dev/null +++ b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py @@ -0,0 +1,73 @@ +from novita_client import ( + Txt2ImgV3Embedding, + Txt2ImgV3HiresFix, + Txt2ImgV3LoRA, + Txt2ImgV3Refiner, + V3TaskImage, +) + + +class NovitaAiToolBase: + def _extract_loras(self, loras_str: str): + if not loras_str: + return [] + + loras_ori_list = lora_str.strip().split(';') + result_list = [] + for lora_str in loras_ori_list: + lora_info = lora_str.strip().split(',') + lora = Txt2ImgV3LoRA( + model_name=lora_info[0].strip(), + strength=float(lora_info[1]), + ) + result_list.append(lora) + + return result_list + + def _extract_embeddings(self, embeddings_str: str): + if not embeddings_str: + return [] + + embeddings_ori_list = embeddings_str.strip().split(';') + result_list = [] + for embedding_str in embeddings_ori_list: + embedding = Txt2ImgV3Embedding( + model_name=embedding_str.strip() + ) + result_list.append(embedding) + + return result_list + + def _extract_hires_fix(self, hires_fix_str: str): + hires_fix_info = hires_fix_str.strip().split(',') + if 'upscaler' in hires_fix_info: + hires_fix = Txt2ImgV3HiresFix( + target_width=int(hires_fix_info[0]), + target_height=int(hires_fix_info[1]), + strength=float(hires_fix_info[2]), + upscaler=hires_fix_info[3].strip() + ) + else: + hires_fix = Txt2ImgV3HiresFix( + target_width=int(hires_fix_info[0]), + target_height=int(hires_fix_info[1]), + strength=float(hires_fix_info[2]) + ) + + return hires_fix + + def _extract_refiner(self, switch_at: str): + refiner = Txt2ImgV3Refiner( + switch_at=float(switch_at) + ) + return refiner + + def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool: + """ + is hit nsfw + """ + if image.nsfw_detection_result is None: + return False + if image.nsfw_detection_result.valid and image.nsfw_detection_result.confidence >= confidence_threshold: + return True + return False diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py index c9524d6a66f4d8..5fef3d2da71bb9 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py @@ -4,19 +4,15 @@ from novita_client import ( NovitaClient, - Txt2ImgV3Embedding, - Txt2ImgV3HiresFix, - Txt2ImgV3LoRA, - Txt2ImgV3Refiner, - V3TaskImage, ) from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.novitaai._novita_tool_base import NovitaAiToolBase from core.tools.tool.builtin_tool import BuiltinTool -class NovitaAiTxt2ImgTool(BuiltinTool): +class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase): def _invoke(self, user_id: str, tool_parameters: dict[str, Any], @@ -73,65 +69,19 @@ def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: # process loras if 'loras' in res_parameters: - loras_ori_list = res_parameters.get('loras').strip().split(';') - locals_list = [] - for lora_str in loras_ori_list: - lora_info = lora_str.strip().split(',') - lora = Txt2ImgV3LoRA( - model_name=lora_info[0].strip(), - strength=float(lora_info[1]), - ) - locals_list.append(lora) - - res_parameters['loras'] = locals_list + res_parameters['loras'] = self._extract_loras(res_parameters.get('loras')) # process embeddings if 'embeddings' in res_parameters: - embeddings_ori_list = res_parameters.get('embeddings').strip().split(';') - locals_list = [] - for embedding_str in embeddings_ori_list: - embedding = Txt2ImgV3Embedding( - model_name=embedding_str.strip() - ) - locals_list.append(embedding) - - res_parameters['embeddings'] = locals_list + res_parameters['embeddings'] = self._extract_embeddings(res_parameters.get('embeddings')) # process hires_fix if 'hires_fix' in res_parameters: - hires_fix_ori = res_parameters.get('hires_fix') - hires_fix_info = hires_fix_ori.strip().split(',') - if 'upscaler' in hires_fix_info: - hires_fix = Txt2ImgV3HiresFix( - target_width=int(hires_fix_info[0]), - target_height=int(hires_fix_info[1]), - strength=float(hires_fix_info[2]), - upscaler=hires_fix_info[3].strip() - ) - else: - hires_fix = Txt2ImgV3HiresFix( - target_width=int(hires_fix_info[0]), - target_height=int(hires_fix_info[1]), - strength=float(hires_fix_info[2]) - ) - - res_parameters['hires_fix'] = hires_fix - - if 'refiner_switch_at' in res_parameters: - refiner = Txt2ImgV3Refiner( - switch_at=float(res_parameters.get('refiner_switch_at')) - ) - del res_parameters['refiner_switch_at'] - res_parameters['refiner'] = refiner + res_parameters['hires_fix'] = self._extract_hires_fix(res_parameters.get('hires_fix')) - return res_parameters + # process refiner + if 'refiner_switch_at' in res_parameters: + res_parameters['refiner'] = self._extract_refiner(res_parameters.get('refiner_switch_at')) + del res_parameters['refiner_switch_at'] - def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool: - """ - is hit nsfw - """ - if image.nsfw_detection_result is None: - return False - if image.nsfw_detection_result.valid and image.nsfw_detection_result.confidence >= confidence_threshold: - return True - return False + return res_parameters diff --git a/api/core/tools/provider/builtin/onebot/_assets/icon.ico b/api/core/tools/provider/builtin/onebot/_assets/icon.ico new file mode 100644 index 00000000000000..1b07e965b9910b Binary files /dev/null and b/api/core/tools/provider/builtin/onebot/_assets/icon.ico differ diff --git a/api/core/tools/provider/builtin/onebot/onebot.py b/api/core/tools/provider/builtin/onebot/onebot.py new file mode 100644 index 00000000000000..42f321e919745e --- /dev/null +++ b/api/core/tools/provider/builtin/onebot/onebot.py @@ -0,0 +1,12 @@ +from typing import Any + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class OneBotProvider(BuiltinToolProviderController): + + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + + if not credentials.get("ob11_http_url"): + raise ToolProviderCredentialValidationError('OneBot HTTP URL is required.') diff --git a/api/core/tools/provider/builtin/onebot/onebot.yaml b/api/core/tools/provider/builtin/onebot/onebot.yaml new file mode 100644 index 00000000000000..1922adc4de4d56 --- /dev/null +++ b/api/core/tools/provider/builtin/onebot/onebot.yaml @@ -0,0 +1,35 @@ +identity: + author: RockChinQ + name: onebot + label: + en_US: OneBot v11 Protocol + zh_Hans: OneBot v11 协议 + description: + en_US: Unofficial OneBot v11 Protocol Tool + zh_Hans: 非官方 OneBot v11 协议工具 + icon: icon.ico +credentials_for_provider: + ob11_http_url: + type: text-input + required: true + label: + en_US: HTTP URL + zh_Hans: HTTP URL + description: + en_US: Forward HTTP URL of OneBot v11 + zh_Hans: OneBot v11 正向 HTTP URL + help: + en_US: Fill this with the HTTP URL of your OneBot server + zh_Hans: 请在你的 OneBot 协议端开启 正向 HTTP 并填写其 URL + access_token: + type: secret-input + required: false + label: + en_US: Access Token + zh_Hans: 访问令牌 + description: + en_US: Access Token for OneBot v11 Protocol + zh_Hans: OneBot 协议访问令牌 + help: + en_US: Fill this if you set a access token in your OneBot server + zh_Hans: 如果你在 OneBot 服务器中设置了 access token,请填写此项 diff --git a/api/core/tools/provider/builtin/onebot/tools/__init__.py b/api/core/tools/provider/builtin/onebot/tools/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py new file mode 100644 index 00000000000000..2a1a9f86de5b44 --- /dev/null +++ b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py @@ -0,0 +1,64 @@ +from typing import Any, Union + +import requests +from yarl import URL + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class SendGroupMsg(BuiltinTool): + """OneBot v11 Tool: Send Group Message""" + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + # Get parameters + send_group_id = tool_parameters.get('group_id', '') + + message = tool_parameters.get('message', '') + if not message: + return self.create_json_message( + { + 'error': 'Message is empty.' + } + ) + + auto_escape = tool_parameters.get('auto_escape', False) + + try: + url = URL(self.runtime.credentials['ob11_http_url']) / 'send_group_msg' + + resp = requests.post( + url, + json={ + 'group_id': send_group_id, + 'message': message, + 'auto_escape': auto_escape + }, + headers={ + 'Authorization': 'Bearer ' + self.runtime.credentials['access_token'] + } + ) + + if resp.status_code != 200: + return self.create_json_message( + { + 'error': f'Failed to send group message: {resp.text}' + } + ) + + return self.create_json_message( + { + 'response': resp.json() + } + ) + except Exception as e: + return self.create_json_message( + { + 'error': f'Failed to send group message: {e}' + } + ) diff --git a/api/core/tools/provider/builtin/onebot/tools/send_group_msg.yaml b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.yaml new file mode 100644 index 00000000000000..64beaa85457a3a --- /dev/null +++ b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.yaml @@ -0,0 +1,46 @@ +identity: + name: send_group_msg + author: RockChinQ + label: + en_US: Send Group Message + zh_Hans: 发送群消息 +description: + human: + en_US: Send a message to a group + zh_Hans: 发送消息到群聊 + llm: A tool for sending a message segment to a group +parameters: + - name: group_id + type: number + required: true + label: + en_US: Target Group ID + zh_Hans: 目标群 ID + human_description: + en_US: The group ID of the target group + zh_Hans: 目标群的群 ID + llm_description: The group ID of the target group + form: llm + - name: message + type: string + required: true + label: + en_US: Message + zh_Hans: 消息 + human_description: + en_US: The message to send + zh_Hans: 要发送的消息。支持 CQ码(需要同时设置 auto_escape 为 true) + llm_description: The message to send + form: llm + - name: auto_escape + type: boolean + required: false + default: false + label: + en_US: Auto Escape + zh_Hans: 自动转义 + human_description: + en_US: If true, the message will be treated as a CQ code for parsing, otherwise it will be treated as plain text for direct sending. Since Dify currently does not support passing Object-format message chains, developers can send complex message components through CQ codes. + zh_Hans: 若为 true 则会把 message 视为 CQ 码解析,否则视为 纯文本 直接发送。由于 Dify 目前不支持传入 Object格式 的消息,故开发者可以通过 CQ 码来发送复杂消息组件。 + llm_description: If true, the message will be treated as a CQ code for parsing, otherwise it will be treated as plain text for direct sending. + form: form diff --git a/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py new file mode 100644 index 00000000000000..8ef4d72ab63b44 --- /dev/null +++ b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py @@ -0,0 +1,64 @@ +from typing import Any, Union + +import requests +from yarl import URL + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class SendPrivateMsg(BuiltinTool): + """OneBot v11 Tool: Send Private Message""" + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + # Get parameters + send_user_id = tool_parameters.get('user_id', '') + + message = tool_parameters.get('message', '') + if not message: + return self.create_json_message( + { + 'error': 'Message is empty.' + } + ) + + auto_escape = tool_parameters.get('auto_escape', False) + + try: + url = URL(self.runtime.credentials['ob11_http_url']) / 'send_private_msg' + + resp = requests.post( + url, + json={ + 'user_id': send_user_id, + 'message': message, + 'auto_escape': auto_escape + }, + headers={ + 'Authorization': 'Bearer ' + self.runtime.credentials['access_token'] + } + ) + + if resp.status_code != 200: + return self.create_json_message( + { + 'error': f'Failed to send private message: {resp.text}' + } + ) + + return self.create_json_message( + { + 'response': resp.json() + } + ) + except Exception as e: + return self.create_json_message( + { + 'error': f'Failed to send private message: {e}' + } + ) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/onebot/tools/send_private_msg.yaml b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.yaml new file mode 100644 index 00000000000000..8200ce4a83f4e2 --- /dev/null +++ b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.yaml @@ -0,0 +1,46 @@ +identity: + name: send_private_msg + author: RockChinQ + label: + en_US: Send Private Message + zh_Hans: 发送私聊消息 +description: + human: + en_US: Send a private message to a user + zh_Hans: 发送私聊消息给用户 + llm: A tool for sending a message segment to a user in private chat +parameters: + - name: user_id + type: number + required: true + label: + en_US: Target User ID + zh_Hans: 目标用户 ID + human_description: + en_US: The user ID of the target user + zh_Hans: 目标用户的用户 ID + llm_description: The user ID of the target user + form: llm + - name: message + type: string + required: true + label: + en_US: Message + zh_Hans: 消息 + human_description: + en_US: The message to send + zh_Hans: 要发送的消息。支持 CQ码(需要同时设置 auto_escape 为 true) + llm_description: The message to send + form: llm + - name: auto_escape + type: boolean + required: false + default: false + label: + en_US: Auto Escape + zh_Hans: 自动转义 + human_description: + en_US: If true, the message will be treated as a CQ code for parsing, otherwise it will be treated as plain text for direct sending. Since Dify currently does not support passing Object-format message chains, developers can send complex message components through CQ codes. + zh_Hans: 若为 true 则会把 message 视为 CQ 码解析,否则视为 纯文本 直接发送。由于 Dify 目前不支持传入 Object格式 的消息,故开发者可以通过 CQ 码来发送复杂消息组件。 + llm_description: If true, the message will be treated as a CQ code for parsing, otherwise it will be treated as plain text for direct sending. + form: form diff --git a/api/core/tools/provider/builtin/perplexity/_assets/icon.svg b/api/core/tools/provider/builtin/perplexity/_assets/icon.svg new file mode 100644 index 00000000000000..c2974c142fc622 --- /dev/null +++ b/api/core/tools/provider/builtin/perplexity/_assets/icon.svg @@ -0,0 +1,3 @@ + + + diff --git a/api/core/tools/provider/builtin/perplexity/perplexity.py b/api/core/tools/provider/builtin/perplexity/perplexity.py new file mode 100644 index 00000000000000..ff91edf18df4c0 --- /dev/null +++ b/api/core/tools/provider/builtin/perplexity/perplexity.py @@ -0,0 +1,46 @@ +from typing import Any + +import requests + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.perplexity.tools.perplexity_search import PERPLEXITY_API_URL +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class PerplexityProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + headers = { + "Authorization": f"Bearer {credentials.get('perplexity_api_key')}", + "Content-Type": "application/json" + } + + payload = { + "model": "llama-3.1-sonar-small-128k-online", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "Hello" + } + ], + "max_tokens": 5, + "temperature": 0.1, + "top_p": 0.9, + "stream": False + } + + try: + response = requests.post(PERPLEXITY_API_URL, json=payload, headers=headers) + response.raise_for_status() + except requests.RequestException as e: + raise ToolProviderCredentialValidationError( + f"Failed to validate Perplexity API key: {str(e)}" + ) + + if response.status_code != 200: + raise ToolProviderCredentialValidationError( + f"Perplexity API key is invalid. Status code: {response.status_code}" + ) diff --git a/api/core/tools/provider/builtin/perplexity/perplexity.yaml b/api/core/tools/provider/builtin/perplexity/perplexity.yaml new file mode 100644 index 00000000000000..c0b504f300c45a --- /dev/null +++ b/api/core/tools/provider/builtin/perplexity/perplexity.yaml @@ -0,0 +1,26 @@ +identity: + author: Dify + name: perplexity + label: + en_US: Perplexity + zh_Hans: Perplexity + description: + en_US: Perplexity.AI + zh_Hans: Perplexity.AI + icon: icon.svg + tags: + - search +credentials_for_provider: + perplexity_api_key: + type: secret-input + required: true + label: + en_US: Perplexity API key + zh_Hans: Perplexity API key + placeholder: + en_US: Please input your Perplexity API key + zh_Hans: 请输入你的 Perplexity API key + help: + en_US: Get your Perplexity API key from Perplexity + zh_Hans: 从 Perplexity 获取您的 Perplexity API key + url: https://www.perplexity.ai/settings/api diff --git a/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.py b/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.py new file mode 100644 index 00000000000000..5b1a263f9b2218 --- /dev/null +++ b/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.py @@ -0,0 +1,72 @@ +import json +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +PERPLEXITY_API_URL = "https://api.perplexity.ai/chat/completions" + +class PerplexityAITool(BuiltinTool): + def _parse_response(self, response: dict) -> dict: + """Parse the response from Perplexity AI API""" + if 'choices' in response and len(response['choices']) > 0: + message = response['choices'][0]['message'] + return { + 'content': message.get('content', ''), + 'role': message.get('role', ''), + 'citations': response.get('citations', []) + } + else: + return {'content': 'Unable to get a valid response', 'role': 'assistant', 'citations': []} + + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + headers = { + "Authorization": f"Bearer {self.runtime.credentials['perplexity_api_key']}", + "Content-Type": "application/json" + } + + payload = { + "model": tool_parameters.get('model', 'llama-3.1-sonar-small-128k-online'), + "messages": [ + { + "role": "system", + "content": "Be precise and concise." + }, + { + "role": "user", + "content": tool_parameters['query'] + } + ], + "max_tokens": tool_parameters.get('max_tokens', 4096), + "temperature": tool_parameters.get('temperature', 0.7), + "top_p": tool_parameters.get('top_p', 1), + "top_k": tool_parameters.get('top_k', 5), + "presence_penalty": tool_parameters.get('presence_penalty', 0), + "frequency_penalty": tool_parameters.get('frequency_penalty', 1), + "stream": False + } + + if 'search_recency_filter' in tool_parameters: + payload['search_recency_filter'] = tool_parameters['search_recency_filter'] + if 'return_citations' in tool_parameters: + payload['return_citations'] = tool_parameters['return_citations'] + if 'search_domain_filter' in tool_parameters: + if isinstance(tool_parameters['search_domain_filter'], str): + payload['search_domain_filter'] = [tool_parameters['search_domain_filter']] + elif isinstance(tool_parameters['search_domain_filter'], list): + payload['search_domain_filter'] = tool_parameters['search_domain_filter'] + + + response = requests.post(url=PERPLEXITY_API_URL, json=payload, headers=headers) + response.raise_for_status() + valuable_res = self._parse_response(response.json()) + + return [ + self.create_json_message(valuable_res), + self.create_text_message(json.dumps(valuable_res, ensure_ascii=False, indent=2)) + ] diff --git a/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.yaml b/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.yaml new file mode 100644 index 00000000000000..02a645df335aaf --- /dev/null +++ b/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.yaml @@ -0,0 +1,178 @@ +identity: + name: perplexity + author: Dify + label: + en_US: Perplexity Search +description: + human: + en_US: Search information using Perplexity AI's language models. + llm: This tool is used to search information using Perplexity AI's language models. +parameters: + - name: query + type: string + required: true + label: + en_US: Query + zh_Hans: 查询 + human_description: + en_US: The text query to be processed by the AI model. + zh_Hans: 要由 AI 模型处理的文本查询。 + form: llm + - name: model + type: select + required: false + label: + en_US: Model Name + zh_Hans: 模型名称 + human_description: + en_US: The Perplexity AI model to use for generating the response. + zh_Hans: 用于生成响应的 Perplexity AI 模型。 + form: form + default: "llama-3.1-sonar-small-128k-online" + options: + - value: llama-3.1-sonar-small-128k-online + label: + en_US: llama-3.1-sonar-small-128k-online + zh_Hans: llama-3.1-sonar-small-128k-online + - value: llama-3.1-sonar-large-128k-online + label: + en_US: llama-3.1-sonar-large-128k-online + zh_Hans: llama-3.1-sonar-large-128k-online + - value: llama-3.1-sonar-huge-128k-online + label: + en_US: llama-3.1-sonar-huge-128k-online + zh_Hans: llama-3.1-sonar-huge-128k-online + - name: max_tokens + type: number + required: false + label: + en_US: Max Tokens + zh_Hans: 最大令牌数 + pt_BR: Máximo de Tokens + human_description: + en_US: The maximum number of tokens to generate in the response. + zh_Hans: 在响应中生成的最大令牌数。 + pt_BR: O número máximo de tokens a serem gerados na resposta. + form: form + default: 4096 + min: 1 + max: 4096 + - name: temperature + type: number + required: false + label: + en_US: Temperature + zh_Hans: 温度 + pt_BR: Temperatura + human_description: + en_US: Controls randomness in the output. Lower values make the output more focused and deterministic. + zh_Hans: 控制输出的随机性。较低的值使输出更加集中和确定。 + form: form + default: 0.7 + min: 0 + max: 1 + - name: top_k + type: number + required: false + label: + en_US: Top K + zh_Hans: 取样数量 + human_description: + en_US: The number of top results to consider for response generation. + zh_Hans: 用于生成响应的顶部结果数量。 + form: form + default: 5 + min: 1 + max: 100 + - name: top_p + type: number + required: false + label: + en_US: Top P + zh_Hans: Top P + human_description: + en_US: Controls diversity via nucleus sampling. + zh_Hans: 通过核心采样控制多样性。 + form: form + default: 1 + min: 0.1 + max: 1 + step: 0.1 + - name: presence_penalty + type: number + required: false + label: + en_US: Presence Penalty + zh_Hans: 存在惩罚 + human_description: + en_US: Positive values penalize new tokens based on whether they appear in the text so far. + zh_Hans: 正值会根据新词元是否已经出现在文本中来对其进行惩罚。 + form: form + default: 0 + min: -1.0 + max: 1.0 + step: 0.1 + - name: frequency_penalty + type: number + required: false + label: + en_US: Frequency Penalty + zh_Hans: 频率惩罚 + human_description: + en_US: Positive values penalize new tokens based on their existing frequency in the text so far. + zh_Hans: 正值会根据新词元在文本中已经出现的频率来对其进行惩罚。 + form: form + default: 1 + min: 0.1 + max: 1.0 + step: 0.1 + - name: return_citations + type: boolean + required: false + label: + en_US: Return Citations + zh_Hans: 返回引用 + human_description: + en_US: Whether to return citations in the response. + zh_Hans: 是否在响应中返回引用。 + form: form + default: true + - name: search_domain_filter + type: string + required: false + label: + en_US: Search Domain Filter + zh_Hans: 搜索域过滤器 + human_description: + en_US: Domain to filter the search results. + zh_Hans: 用于过滤搜索结果的域名。 + form: form + default: "" + - name: search_recency_filter + type: select + required: false + label: + en_US: Search Recency Filter + zh_Hans: 搜索时间过滤器 + human_description: + en_US: Filter for search results based on recency. + zh_Hans: 基于时间筛选搜索结果。 + form: form + default: "month" + options: + - value: day + label: + en_US: Day + zh_Hans: 天 + - value: week + label: + en_US: Week + zh_Hans: 周 + - value: month + label: + en_US: Month + zh_Hans: 月 + - value: year + label: + en_US: Year + zh_Hans: 年 diff --git a/api/core/tools/provider/builtin/regex/_assets/icon.svg b/api/core/tools/provider/builtin/regex/_assets/icon.svg new file mode 100644 index 00000000000000..0231a2b4aa9da2 --- /dev/null +++ b/api/core/tools/provider/builtin/regex/_assets/icon.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/regex/regex.py b/api/core/tools/provider/builtin/regex/regex.py new file mode 100644 index 00000000000000..d38ae1b292675f --- /dev/null +++ b/api/core/tools/provider/builtin/regex/regex.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.regex.tools.regex_extract import RegexExpressionTool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class RegexProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + try: + RegexExpressionTool().invoke( + user_id='', + tool_parameters={ + 'content': '1+(2+3)*4', + 'expression': r'(\d+)', + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/regex/regex.yaml b/api/core/tools/provider/builtin/regex/regex.yaml new file mode 100644 index 00000000000000..d05776f214e8d2 --- /dev/null +++ b/api/core/tools/provider/builtin/regex/regex.yaml @@ -0,0 +1,15 @@ +identity: + author: zhuhao + name: regex + label: + en_US: Regex + zh_Hans: 正则表达式提取 + pt_BR: Regex + description: + en_US: A tool for regex extraction. + zh_Hans: 一个用于正则表达式内容提取的工具。 + pt_BR: A tool for regex extraction. + icon: icon.svg + tags: + - utilities + - productivity diff --git a/api/core/tools/provider/builtin/regex/tools/regex_extract.py b/api/core/tools/provider/builtin/regex/tools/regex_extract.py new file mode 100644 index 00000000000000..5d8f013d0d012c --- /dev/null +++ b/api/core/tools/provider/builtin/regex/tools/regex_extract.py @@ -0,0 +1,27 @@ +import re +from typing import Any, Union + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class RegexExpressionTool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + # get expression + content = tool_parameters.get('content', '').strip() + if not content: + return self.create_text_message('Invalid content') + expression = tool_parameters.get('expression', '').strip() + if not expression: + return self.create_text_message('Invalid expression') + try: + result = re.findall(expression, content) + return self.create_text_message(str(result)) + except Exception as e: + return self.create_text_message(f'Failed to extract result, error: {str(e)}') \ No newline at end of file diff --git a/api/core/tools/provider/builtin/regex/tools/regex_extract.yaml b/api/core/tools/provider/builtin/regex/tools/regex_extract.yaml new file mode 100644 index 00000000000000..de4100def176c9 --- /dev/null +++ b/api/core/tools/provider/builtin/regex/tools/regex_extract.yaml @@ -0,0 +1,38 @@ +identity: + name: regex_extract + author: zhuhao + label: + en_US: Regex Extract + zh_Hans: 正则表达式内容提取 + pt_BR: Regex Extract +description: + human: + en_US: A tool for extracting matching content using regular expressions. + zh_Hans: 一个用于利用正则表达式提取匹配内容结果的工具。 + pt_BR: A tool for extracting matching content using regular expressions. + llm: A tool for extracting matching content using regular expressions. +parameters: + - name: content + type: string + required: true + label: + en_US: Content to be extracted + zh_Hans: 内容 + pt_BR: Content to be extracted + human_description: + en_US: Content to be extracted + zh_Hans: 内容 + pt_BR: Content to be extracted + form: llm + - name: expression + type: string + required: true + label: + en_US: Regular expression + zh_Hans: 正则表达式 + pt_BR: Regular expression + human_description: + en_US: Regular expression + zh_Hans: 正则表达式 + pt_BR: Regular expression + form: llm diff --git a/api/core/tools/provider/builtin/searxng/docker/settings.yml b/api/core/tools/provider/builtin/searxng/docker/settings.yml new file mode 100644 index 00000000000000..18e18688002cbc --- /dev/null +++ b/api/core/tools/provider/builtin/searxng/docker/settings.yml @@ -0,0 +1,2501 @@ +general: + # Debug mode, only for development. Is overwritten by ${SEARXNG_DEBUG} + debug: false + # displayed name + instance_name: "searxng" + # For example: https://example.com/privacy + privacypolicy_url: false + # use true to use your own donation page written in searx/info/en/donate.md + # use false to disable the donation link + donation_url: false + # mailto:contact@example.com + contact_url: false + # record stats + enable_metrics: true + +brand: + new_issue_url: https://github.com/searxng/searxng/issues/new + docs_url: https://docs.searxng.org/ + public_instances: https://searx.space + wiki_url: https://github.com/searxng/searxng/wiki + issue_url: https://github.com/searxng/searxng/issues + # custom: + # maintainer: "Jon Doe" + # # Custom entries in the footer: [title]: [link] + # links: + # Uptime: https://uptime.searxng.org/history/darmarit-org + # About: "https://searxng.org" + +search: + # Filter results. 0: None, 1: Moderate, 2: Strict + safe_search: 0 + # Existing autocomplete backends: "dbpedia", "duckduckgo", "google", "yandex", "mwmbl", + # "seznam", "startpage", "stract", "swisscows", "qwant", "wikipedia" - leave blank to turn it off + # by default. + autocomplete: "" + # minimun characters to type before autocompleter starts + autocomplete_min: 4 + # Default search language - leave blank to detect from browser information or + # use codes from 'languages.py' + default_lang: "auto" + # max_page: 0 # if engine supports paging, 0 means unlimited numbers of pages + # Available languages + # languages: + # - all + # - en + # - en-US + # - de + # - it-IT + # - fr + # - fr-BE + # ban time in seconds after engine errors + ban_time_on_fail: 5 + # max ban time in seconds after engine errors + max_ban_time_on_fail: 120 + suspended_times: + # Engine suspension time after error (in seconds; set to 0 to disable) + # For error "Access denied" and "HTTP error [402, 403]" + SearxEngineAccessDenied: 86400 + # For error "CAPTCHA" + SearxEngineCaptcha: 86400 + # For error "Too many request" and "HTTP error 429" + SearxEngineTooManyRequests: 3600 + # Cloudflare CAPTCHA + cf_SearxEngineCaptcha: 1296000 + cf_SearxEngineAccessDenied: 86400 + # ReCAPTCHA + recaptcha_SearxEngineCaptcha: 604800 + + # remove format to deny access, use lower case. + # formats: [html, csv, json, rss] + formats: + - html + - json + +server: + # Is overwritten by ${SEARXNG_PORT} and ${SEARXNG_BIND_ADDRESS} + port: 8888 + bind_address: "127.0.0.1" + # public URL of the instance, to ensure correct inbound links. Is overwritten + # by ${SEARXNG_URL}. + base_url: http://0.0.0.0:8081/ # "http://example.com/location" + # rate limit the number of request on the instance, block some bots. + # Is overwritten by ${SEARXNG_LIMITER} + limiter: false + # enable features designed only for public instances. + # Is overwritten by ${SEARXNG_PUBLIC_INSTANCE} + public_instance: false + + # If your instance owns a /etc/searxng/settings.yml file, then set the following + # values there. + + secret_key: "772ba36386fb56d0f8fe818941552dabbe69220d4c0eb4a385a5729cdbc20c2d" # Is overwritten by ${SEARXNG_SECRET} + # Proxy image results through SearXNG. Is overwritten by ${SEARXNG_IMAGE_PROXY} + image_proxy: false + # 1.0 and 1.1 are supported + http_protocol_version: "1.0" + # POST queries are more secure as they don't show up in history but may cause + # problems when using Firefox containers + method: "POST" + default_http_headers: + X-Content-Type-Options: nosniff + X-Download-Options: noopen + X-Robots-Tag: noindex, nofollow + Referrer-Policy: no-referrer + +redis: + # URL to connect redis database. Is overwritten by ${SEARXNG_REDIS_URL}. + # https://docs.searxng.org/admin/settings/settings_redis.html#settings-redis + url: false + +ui: + # Custom static path - leave it blank if you didn't change + static_path: "" + # Is overwritten by ${SEARXNG_STATIC_USE_HASH}. + static_use_hash: false + # Custom templates path - leave it blank if you didn't change + templates_path: "" + # query_in_title: When true, the result page's titles contains the query + # it decreases the privacy, since the browser can records the page titles. + query_in_title: false + # infinite_scroll: When true, automatically loads the next page when scrolling to bottom of the current page. + infinite_scroll: false + # ui theme + default_theme: simple + # center the results ? + center_alignment: false + # URL prefix of the internet archive, don't forget trailing slash (if needed). + # cache_url: "https://webcache.googleusercontent.com/search?q=cache:" + # Default interface locale - leave blank to detect from browser information or + # use codes from the 'locales' config section + default_locale: "" + # Open result links in a new tab by default + # results_on_new_tab: false + theme_args: + # style of simple theme: auto, light, dark + simple_style: auto + # Perform search immediately if a category selected. + # Disable to select multiple categories at once and start the search manually. + search_on_category_select: true + # Hotkeys: default or vim + hotkeys: default + +# Lock arbitrary settings on the preferences page. To find the ID of the user +# setting you want to lock, check the ID of the form on the page "preferences". +# +# preferences: +# lock: +# - language +# - autocomplete +# - method +# - query_in_title + +# searx supports result proxification using an external service: +# https://github.com/asciimoo/morty uncomment below section if you have running +# morty proxy the key is base64 encoded (keep the !!binary notation) +# Note: since commit af77ec3, morty accepts a base64 encoded key. +# +# result_proxy: +# url: http://127.0.0.1:3000/ +# # the key is a base64 encoded string, the YAML !!binary prefix is optional +# key: !!binary "your_morty_proxy_key" +# # [true|false] enable the "proxy" button next to each result +# proxify_results: true + +# communication with search engines +# +outgoing: + # default timeout in seconds, can be override by engine + request_timeout: 3.0 + # the maximum timeout in seconds + # max_request_timeout: 10.0 + # suffix of searx_useragent, could contain information like an email address + # to the administrator + useragent_suffix: "" + # The maximum number of concurrent connections that may be established. + pool_connections: 100 + # Allow the connection pool to maintain keep-alive connections below this + # point. + pool_maxsize: 20 + # See https://www.python-httpx.org/http2/ + enable_http2: true + # uncomment below section if you want to use a custom server certificate + # see https://www.python-httpx.org/advanced/#changing-the-verification-defaults + # and https://www.python-httpx.org/compatibility/#ssl-configuration + # verify: ~/.mitmproxy/mitmproxy-ca-cert.cer + # + # uncomment below section if you want to use a proxyq see: SOCKS proxies + # https://2.python-requests.org/en/latest/user/advanced/#proxies + # are also supported: see + # https://2.python-requests.org/en/latest/user/advanced/#socks + # + # proxies: + # all://: + # - http://host.docker.internal:1080 + # + # using_tor_proxy: true + # + # Extra seconds to add in order to account for the time taken by the proxy + # + # extra_proxy_timeout: 10 + # + # uncomment below section only if you have more than one network interface + # which can be the source of outgoing search requests + # + # source_ips: + # - 1.1.1.1 + # - 1.1.1.2 + # - fe80::/126 + +# External plugin configuration, for more details see +# https://docs.searxng.org/dev/plugins.html +# +# plugins: +# - plugin1 +# - plugin2 +# - ... + +# Comment or un-comment plugin to activate / deactivate by default. +# +# enabled_plugins: +# # these plugins are enabled if nothing is configured .. +# - 'Hash plugin' +# - 'Self Information' +# - 'Tracker URL remover' +# - 'Ahmia blacklist' # activation depends on outgoing.using_tor_proxy +# # these plugins are disabled if nothing is configured .. +# - 'Hostnames plugin' # see 'hostnames' configuration below +# - 'Basic Calculator' +# - 'Open Access DOI rewrite' +# - 'Tor check plugin' +# # Read the docs before activate: auto-detection of the language could be +# # detrimental to users expectations / users can activate the plugin in the +# # preferences if they want. +# - 'Autodetect search language' + +# Configuration of the "Hostnames plugin": +# +# hostnames: +# replace: +# '(.*\.)?youtube\.com$': 'invidious.example.com' +# '(.*\.)?youtu\.be$': 'invidious.example.com' +# '(.*\.)?reddit\.com$': 'teddit.example.com' +# '(.*\.)?redd\.it$': 'teddit.example.com' +# '(www\.)?twitter\.com$': 'nitter.example.com' +# remove: +# - '(.*\.)?facebook.com$' +# low_priority: +# - '(.*\.)?google(\..*)?$' +# high_priority: +# - '(.*\.)?wikipedia.org$' +# +# Alternatively you can use external files for configuring the "Hostnames plugin": +# +# hostnames: +# replace: 'rewrite-hosts.yml' +# +# Content of 'rewrite-hosts.yml' (place the file in the same directory as 'settings.yml'): +# '(.*\.)?youtube\.com$': 'invidious.example.com' +# '(.*\.)?youtu\.be$': 'invidious.example.com' +# + +checker: + # disable checker when in debug mode + off_when_debug: true + + # use "scheduling: false" to disable scheduling + # scheduling: interval or int + + # to activate the scheduler: + # * uncomment "scheduling" section + # * add "cache2 = name=searxngcache,items=2000,blocks=2000,blocksize=4096,bitmap=1" + # to your uwsgi.ini + + # scheduling: + # start_after: [300, 1800] # delay to start the first run of the checker + # every: [86400, 90000] # how often the checker runs + + # additional tests: only for the YAML anchors (see the engines section) + # + additional_tests: + rosebud: &test_rosebud + matrix: + query: rosebud + lang: en + result_container: + - not_empty + - ['one_title_contains', 'citizen kane'] + test: + - unique_results + + android: &test_android + matrix: + query: ['android'] + lang: ['en', 'de', 'fr', 'zh-CN'] + result_container: + - not_empty + - ['one_title_contains', 'google'] + test: + - unique_results + + # tests: only for the YAML anchors (see the engines section) + tests: + infobox: &tests_infobox + infobox: + matrix: + query: ["linux", "new york", "bbc"] + result_container: + - has_infobox + +categories_as_tabs: + general: + images: + videos: + news: + map: + music: + it: + science: + files: + social media: + +engines: + - name: 9gag + engine: 9gag + shortcut: 9g + disabled: true + + - name: alpine linux packages + engine: alpinelinux + disabled: true + shortcut: alp + + - name: annas archive + engine: annas_archive + disabled: true + shortcut: aa + + # - name: annas articles + # engine: annas_archive + # shortcut: aaa + # # https://docs.searxng.org/dev/engines/online/annas_archive.html + # aa_content: 'magazine' # book_fiction, book_unknown, book_nonfiction, book_comic + # aa_ext: 'pdf' # pdf, epub, .. + # aa_sort: oldest' # newest, oldest, largest, smallest + + - name: apk mirror + engine: apkmirror + timeout: 4.0 + shortcut: apkm + disabled: true + + - name: apple app store + engine: apple_app_store + shortcut: aps + disabled: true + + # Requires Tor + - name: ahmia + engine: ahmia + categories: onions + enable_http: true + shortcut: ah + + - name: anaconda + engine: xpath + paging: true + first_page_num: 0 + search_url: https://anaconda.org/search?q={query}&page={pageno} + results_xpath: //tbody/tr + url_xpath: ./td/h5/a[last()]/@href + title_xpath: ./td/h5 + content_xpath: ./td[h5]/text() + categories: it + timeout: 6.0 + shortcut: conda + disabled: true + + - name: arch linux wiki + engine: archlinux + shortcut: al + + - name: artic + engine: artic + shortcut: arc + timeout: 4.0 + + - name: arxiv + engine: arxiv + shortcut: arx + timeout: 4.0 + + - name: ask + engine: ask + shortcut: ask + disabled: true + + # tmp suspended: dh key too small + # - name: base + # engine: base + # shortcut: bs + + - name: bandcamp + engine: bandcamp + shortcut: bc + categories: music + + - name: wikipedia + engine: wikipedia + shortcut: wp + # add "list" to the array to get results in the results list + display_type: ["infobox"] + base_url: 'https://{language}.wikipedia.org/' + categories: [general] + + - name: bilibili + engine: bilibili + shortcut: bil + disabled: true + + - name: bing + engine: bing + shortcut: bi + disabled: false + + - name: bing images + engine: bing_images + shortcut: bii + + - name: bing news + engine: bing_news + shortcut: bin + + - name: bing videos + engine: bing_videos + shortcut: biv + + - name: bitbucket + engine: xpath + paging: true + search_url: https://bitbucket.org/repo/all/{pageno}?name={query} + url_xpath: //article[@class="repo-summary"]//a[@class="repo-link"]/@href + title_xpath: //article[@class="repo-summary"]//a[@class="repo-link"] + content_xpath: //article[@class="repo-summary"]/p + categories: [it, repos] + timeout: 4.0 + disabled: true + shortcut: bb + about: + website: https://bitbucket.org/ + wikidata_id: Q2493781 + official_api_documentation: https://developer.atlassian.com/bitbucket + use_official_api: false + require_api_key: false + results: HTML + + - name: bpb + engine: bpb + shortcut: bpb + disabled: true + + - name: btdigg + engine: btdigg + shortcut: bt + disabled: true + + - name: openverse + engine: openverse + categories: images + shortcut: opv + + - name: media.ccc.de + engine: ccc_media + shortcut: c3tv + # We don't set language: de here because media.ccc.de is not just + # for a German audience. It contains many English videos and many + # German videos have English subtitles. + disabled: true + + - name: chefkoch + engine: chefkoch + shortcut: chef + # to show premium or plus results too: + # skip_premium: false + + # - name: core.ac.uk + # engine: core + # categories: science + # shortcut: cor + # # get your API key from: https://core.ac.uk/api-keys/register/ + # api_key: 'unset' + + - name: cppreference + engine: cppreference + shortcut: cpp + paging: false + disabled: true + + - name: crossref + engine: crossref + shortcut: cr + timeout: 30 + disabled: true + + - name: crowdview + engine: json_engine + shortcut: cv + categories: general + paging: false + search_url: https://crowdview-next-js.onrender.com/api/search-v3?query={query} + results_query: results + url_query: link + title_query: title + content_query: snippet + disabled: true + about: + website: https://crowdview.ai/ + + - name: yep + engine: yep + shortcut: yep + categories: general + search_type: web + timeout: 5 + disabled: true + + - name: yep images + engine: yep + shortcut: yepi + categories: images + search_type: images + disabled: true + + - name: yep news + engine: yep + shortcut: yepn + categories: news + search_type: news + disabled: true + + - name: curlie + engine: xpath + shortcut: cl + categories: general + disabled: true + paging: true + lang_all: '' + search_url: https://curlie.org/search?q={query}&lang={lang}&start={pageno}&stime=92452189 + page_size: 20 + results_xpath: //div[@id="site-list-content"]/div[@class="site-item"] + url_xpath: ./div[@class="title-and-desc"]/a/@href + title_xpath: ./div[@class="title-and-desc"]/a/div + content_xpath: ./div[@class="title-and-desc"]/div[@class="site-descr"] + about: + website: https://curlie.org/ + wikidata_id: Q60715723 + use_official_api: false + require_api_key: false + results: HTML + + - name: currency + engine: currency_convert + categories: general + shortcut: cc + + - name: bahnhof + engine: json_engine + search_url: https://www.bahnhof.de/api/stations/search/{query} + url_prefix: https://www.bahnhof.de/ + url_query: slug + title_query: name + content_query: state + shortcut: bf + disabled: true + about: + website: https://www.bahn.de + wikidata_id: Q22811603 + use_official_api: false + require_api_key: false + results: JSON + language: de + tests: + bahnhof: + matrix: + query: berlin + lang: en + result_container: + - not_empty + - ['one_title_contains', 'Berlin Hauptbahnhof'] + test: + - unique_results + + - name: deezer + engine: deezer + shortcut: dz + disabled: true + + - name: destatis + engine: destatis + shortcut: destat + disabled: true + + - name: deviantart + engine: deviantart + shortcut: da + timeout: 3.0 + + - name: ddg definitions + engine: duckduckgo_definitions + shortcut: ddd + weight: 2 + disabled: true + tests: *tests_infobox + + # cloudflare protected + # - name: digbt + # engine: digbt + # shortcut: dbt + # timeout: 6.0 + # disabled: true + + - name: docker hub + engine: docker_hub + shortcut: dh + categories: [it, packages] + + - name: encyclosearch + engine: json_engine + shortcut: es + categories: general + paging: true + search_url: https://encyclosearch.org/encyclosphere/search?q={query}&page={pageno}&resultsPerPage=15 + results_query: Results + url_query: SourceURL + title_query: Title + content_query: Description + disabled: true + about: + website: https://encyclosearch.org + official_api_documentation: https://encyclosearch.org/docs/#/rest-api + use_official_api: true + require_api_key: false + results: JSON + + - name: erowid + engine: xpath + paging: true + first_page_num: 0 + page_size: 30 + search_url: https://www.erowid.org/search.php?q={query}&s={pageno} + url_xpath: //dl[@class="results-list"]/dt[@class="result-title"]/a/@href + title_xpath: //dl[@class="results-list"]/dt[@class="result-title"]/a/text() + content_xpath: //dl[@class="results-list"]/dd[@class="result-details"] + categories: [] + shortcut: ew + disabled: true + about: + website: https://www.erowid.org/ + wikidata_id: Q1430691 + official_api_documentation: + use_official_api: false + require_api_key: false + results: HTML + + # - name: elasticsearch + # shortcut: es + # engine: elasticsearch + # base_url: http://localhost:9200 + # username: elastic + # password: changeme + # index: my-index + # # available options: match, simple_query_string, term, terms, custom + # query_type: match + # # if query_type is set to custom, provide your query here + # #custom_query_json: {"query":{"match_all": {}}} + # #show_metadata: false + # disabled: true + + - name: wikidata + engine: wikidata + shortcut: wd + timeout: 3.0 + weight: 2 + # add "list" to the array to get results in the results list + display_type: ["infobox"] + tests: *tests_infobox + categories: [general] + + - name: duckduckgo + engine: duckduckgo + shortcut: ddg + + - name: duckduckgo images + engine: duckduckgo_extra + categories: [images, web] + ddg_category: images + shortcut: ddi + disabled: true + + - name: duckduckgo videos + engine: duckduckgo_extra + categories: [videos, web] + ddg_category: videos + shortcut: ddv + disabled: true + + - name: duckduckgo news + engine: duckduckgo_extra + categories: [news, web] + ddg_category: news + shortcut: ddn + disabled: true + + - name: duckduckgo weather + engine: duckduckgo_weather + shortcut: ddw + disabled: true + + - name: apple maps + engine: apple_maps + shortcut: apm + disabled: true + timeout: 5.0 + + - name: emojipedia + engine: emojipedia + timeout: 4.0 + shortcut: em + disabled: true + + - name: tineye + engine: tineye + shortcut: tin + timeout: 9.0 + disabled: true + + - name: etymonline + engine: xpath + paging: true + search_url: https://etymonline.com/search?page={pageno}&q={query} + url_xpath: //a[contains(@class, "word__name--")]/@href + title_xpath: //a[contains(@class, "word__name--")] + content_xpath: //section[contains(@class, "word__defination")] + first_page_num: 1 + shortcut: et + categories: [dictionaries] + about: + website: https://www.etymonline.com/ + wikidata_id: Q1188617 + official_api_documentation: + use_official_api: false + require_api_key: false + results: HTML + + # - name: ebay + # engine: ebay + # shortcut: eb + # base_url: 'https://www.ebay.com' + # disabled: true + # timeout: 5 + + - name: 1x + engine: www1x + shortcut: 1x + timeout: 3.0 + disabled: true + + - name: fdroid + engine: fdroid + shortcut: fd + disabled: true + + - name: findthatmeme + engine: findthatmeme + shortcut: ftm + disabled: true + + - name: flickr + categories: images + shortcut: fl + # You can use the engine using the official stable API, but you need an API + # key, see: https://www.flickr.com/services/apps/create/ + # engine: flickr + # api_key: 'apikey' # required! + # Or you can use the html non-stable engine, activated by default + engine: flickr_noapi + + - name: free software directory + engine: mediawiki + shortcut: fsd + categories: [it, software wikis] + base_url: https://directory.fsf.org/ + search_type: title + timeout: 5.0 + disabled: true + about: + website: https://directory.fsf.org/ + wikidata_id: Q2470288 + + # - name: freesound + # engine: freesound + # shortcut: fnd + # disabled: true + # timeout: 15.0 + # API key required, see: https://freesound.org/docs/api/overview.html + # api_key: MyAPIkey + + - name: frinkiac + engine: frinkiac + shortcut: frk + disabled: true + + - name: fyyd + engine: fyyd + shortcut: fy + timeout: 8.0 + disabled: true + + - name: geizhals + engine: geizhals + shortcut: geiz + disabled: true + + - name: genius + engine: genius + shortcut: gen + + - name: gentoo + engine: mediawiki + shortcut: ge + categories: ["it", "software wikis"] + base_url: "https://wiki.gentoo.org/" + api_path: "api.php" + search_type: text + timeout: 10 + + - name: gitlab + engine: json_engine + paging: true + search_url: https://gitlab.com/api/v4/projects?search={query}&page={pageno} + url_query: web_url + title_query: name_with_namespace + content_query: description + page_size: 20 + categories: [it, repos] + shortcut: gl + timeout: 10.0 + disabled: true + about: + website: https://about.gitlab.com/ + wikidata_id: Q16639197 + official_api_documentation: https://docs.gitlab.com/ee/api/ + use_official_api: false + require_api_key: false + results: JSON + + - name: github + engine: github + shortcut: gh + + - name: codeberg + # https://docs.searxng.org/dev/engines/online/gitea.html + engine: gitea + base_url: https://codeberg.org + shortcut: cb + disabled: true + + - name: gitea.com + engine: gitea + base_url: https://gitea.com + shortcut: gitea + disabled: true + + - name: goodreads + engine: goodreads + shortcut: good + timeout: 4.0 + disabled: true + + - name: google + engine: google + shortcut: go + # additional_tests: + # android: *test_android + + - name: google images + engine: google_images + shortcut: goi + # additional_tests: + # android: *test_android + # dali: + # matrix: + # query: ['Dali Christ'] + # lang: ['en', 'de', 'fr', 'zh-CN'] + # result_container: + # - ['one_title_contains', 'Salvador'] + + - name: google news + engine: google_news + shortcut: gon + # additional_tests: + # android: *test_android + + - name: google videos + engine: google_videos + shortcut: gov + # additional_tests: + # android: *test_android + + - name: google scholar + engine: google_scholar + shortcut: gos + + - name: google play apps + engine: google_play + categories: [files, apps] + shortcut: gpa + play_categ: apps + disabled: true + + - name: google play movies + engine: google_play + categories: videos + shortcut: gpm + play_categ: movies + disabled: true + + - name: material icons + engine: material_icons + categories: images + shortcut: mi + disabled: true + + - name: gpodder + engine: json_engine + shortcut: gpod + timeout: 4.0 + paging: false + search_url: https://gpodder.net/search.json?q={query} + url_query: url + title_query: title + content_query: description + page_size: 19 + categories: music + disabled: true + about: + website: https://gpodder.net + wikidata_id: Q3093354 + official_api_documentation: https://gpoddernet.readthedocs.io/en/latest/api/ + use_official_api: false + requires_api_key: false + results: JSON + + - name: habrahabr + engine: xpath + paging: true + search_url: https://habr.com/en/search/page{pageno}/?q={query} + results_xpath: //article[contains(@class, "tm-articles-list__item")] + url_xpath: .//a[@class="tm-title__link"]/@href + title_xpath: .//a[@class="tm-title__link"] + content_xpath: .//div[contains(@class, "article-formatted-body")] + categories: it + timeout: 4.0 + disabled: true + shortcut: habr + about: + website: https://habr.com/ + wikidata_id: Q4494434 + official_api_documentation: https://habr.com/en/docs/help/api/ + use_official_api: false + require_api_key: false + results: HTML + + - name: hackernews + engine: hackernews + shortcut: hn + disabled: true + + - name: hex + engine: hex + shortcut: hex + disabled: true + # Valid values: name inserted_at updated_at total_downloads recent_downloads + sort_criteria: "recent_downloads" + page_size: 10 + + - name: crates.io + engine: crates + shortcut: crates + disabled: true + timeout: 6.0 + + - name: hoogle + engine: xpath + search_url: https://hoogle.haskell.org/?hoogle={query} + results_xpath: '//div[@class="result"]' + title_xpath: './/div[@class="ans"]//a' + url_xpath: './/div[@class="ans"]//a/@href' + content_xpath: './/div[@class="from"]' + page_size: 20 + categories: [it, packages] + shortcut: ho + about: + website: https://hoogle.haskell.org/ + wikidata_id: Q34010 + official_api_documentation: https://hackage.haskell.org/api + use_official_api: false + require_api_key: false + results: JSON + + - name: imdb + engine: imdb + shortcut: imdb + timeout: 6.0 + disabled: true + + - name: imgur + engine: imgur + shortcut: img + disabled: true + + - name: ina + engine: ina + shortcut: in + timeout: 6.0 + disabled: true + + - name: invidious + engine: invidious + # Instanes will be selected randomly, see https://api.invidious.io/ for + # instances that are stable (good uptime) and close to you. + base_url: + - https://invidious.io.lol + - https://invidious.fdn.fr + - https://yt.artemislena.eu + - https://invidious.tiekoetter.com + - https://invidious.flokinet.to + - https://vid.puffyan.us + - https://invidious.privacydev.net + - https://inv.tux.pizza + shortcut: iv + timeout: 3.0 + disabled: true + + - name: jisho + engine: jisho + shortcut: js + timeout: 3.0 + disabled: true + + - name: kickass + engine: kickass + base_url: + - https://kickasstorrents.to + - https://kickasstorrents.cr + - https://kickasstorrent.cr + - https://kickass.sx + - https://kat.am + shortcut: kc + timeout: 4.0 + disabled: true + + - name: lemmy communities + engine: lemmy + lemmy_type: Communities + shortcut: leco + + - name: lemmy users + engine: lemmy + network: lemmy communities + lemmy_type: Users + shortcut: leus + + - name: lemmy posts + engine: lemmy + network: lemmy communities + lemmy_type: Posts + shortcut: lepo + + - name: lemmy comments + engine: lemmy + network: lemmy communities + lemmy_type: Comments + shortcut: lecom + + - name: library genesis + engine: xpath + # search_url: https://libgen.is/search.php?req={query} + search_url: https://libgen.rs/search.php?req={query} + url_xpath: //a[contains(@href,"book/index.php?md5")]/@href + title_xpath: //a[contains(@href,"book/")]/text()[1] + content_xpath: //td/a[1][contains(@href,"=author")]/text() + categories: files + timeout: 7.0 + disabled: true + shortcut: lg + about: + website: https://libgen.fun/ + wikidata_id: Q22017206 + official_api_documentation: + use_official_api: false + require_api_key: false + results: HTML + + - name: z-library + engine: zlibrary + shortcut: zlib + categories: files + timeout: 7.0 + disabled: true + + - name: library of congress + engine: loc + shortcut: loc + categories: images + + - name: libretranslate + engine: libretranslate + # https://github.com/LibreTranslate/LibreTranslate?tab=readme-ov-file#mirrors + base_url: + - https://translate.terraprint.co + - https://trans.zillyhuhn.com + # api_key: abc123 + shortcut: lt + disabled: true + + - name: lingva + engine: lingva + shortcut: lv + # set lingva instance in url, by default it will use the official instance + # url: https://lingva.thedaviddelta.com + + - name: lobste.rs + engine: xpath + search_url: https://lobste.rs/search?q={query}&what=stories&order=relevance + results_xpath: //li[contains(@class, "story")] + url_xpath: .//a[@class="u-url"]/@href + title_xpath: .//a[@class="u-url"] + content_xpath: .//a[@class="domain"] + categories: it + shortcut: lo + timeout: 5.0 + disabled: true + about: + website: https://lobste.rs/ + wikidata_id: Q60762874 + official_api_documentation: + use_official_api: false + require_api_key: false + results: HTML + + - name: mastodon users + engine: mastodon + mastodon_type: accounts + base_url: https://mastodon.social + shortcut: mau + + - name: mastodon hashtags + engine: mastodon + mastodon_type: hashtags + base_url: https://mastodon.social + shortcut: mah + + # - name: matrixrooms + # engine: mrs + # # https://docs.searxng.org/dev/engines/online/mrs.html + # # base_url: https://mrs-api-host + # shortcut: mtrx + # disabled: true + + - name: mdn + shortcut: mdn + engine: json_engine + categories: [it] + paging: true + search_url: https://developer.mozilla.org/api/v1/search?q={query}&page={pageno} + results_query: documents + url_query: mdn_url + url_prefix: https://developer.mozilla.org + title_query: title + content_query: summary + about: + website: https://developer.mozilla.org + wikidata_id: Q3273508 + official_api_documentation: null + use_official_api: false + require_api_key: false + results: JSON + + - name: metacpan + engine: metacpan + shortcut: cpan + disabled: true + number_of_results: 20 + + # - name: meilisearch + # engine: meilisearch + # shortcut: mes + # enable_http: true + # base_url: http://localhost:7700 + # index: my-index + + - name: mixcloud + engine: mixcloud + shortcut: mc + + # MongoDB engine + # Required dependency: pymongo + # - name: mymongo + # engine: mongodb + # shortcut: md + # exact_match_only: false + # host: '127.0.0.1' + # port: 27017 + # enable_http: true + # results_per_page: 20 + # database: 'business' + # collection: 'reviews' # name of the db collection + # key: 'name' # key in the collection to search for + + - name: mozhi + engine: mozhi + base_url: + - https://mozhi.aryak.me + - https://translate.bus-hit.me + - https://nyc1.mz.ggtyler.dev + # mozhi_engine: google - see https://mozhi.aryak.me for supported engines + timeout: 4.0 + shortcut: mz + disabled: true + + - name: mwmbl + engine: mwmbl + # api_url: https://api.mwmbl.org + shortcut: mwm + disabled: true + + - name: npm + engine: npm + shortcut: npm + timeout: 5.0 + disabled: true + + - name: nyaa + engine: nyaa + shortcut: nt + disabled: true + + - name: mankier + engine: json_engine + search_url: https://www.mankier.com/api/v2/mans/?q={query} + results_query: results + url_query: url + title_query: name + content_query: description + categories: it + shortcut: man + about: + website: https://www.mankier.com/ + official_api_documentation: https://www.mankier.com/api + use_official_api: true + require_api_key: false + results: JSON + + # read https://docs.searxng.org/dev/engines/online/mullvad_leta.html + # - name: mullvadleta + # engine: mullvad_leta + # leta_engine: google # choose one of the following: google, brave + # use_cache: true # Only 100 non-cache searches per day, suggested only for private instances + # search_url: https://leta.mullvad.net + # categories: [general, web] + # shortcut: ml + + - name: odysee + engine: odysee + shortcut: od + disabled: true + + - name: openairedatasets + engine: json_engine + paging: true + search_url: https://api.openaire.eu/search/datasets?format=json&page={pageno}&size=10&title={query} + results_query: response/results/result + url_query: metadata/oaf:entity/oaf:result/children/instance/webresource/url/$ + title_query: metadata/oaf:entity/oaf:result/title/$ + content_query: metadata/oaf:entity/oaf:result/description/$ + content_html_to_text: true + categories: "science" + shortcut: oad + timeout: 5.0 + about: + website: https://www.openaire.eu/ + wikidata_id: Q25106053 + official_api_documentation: https://api.openaire.eu/ + use_official_api: false + require_api_key: false + results: JSON + + - name: openairepublications + engine: json_engine + paging: true + search_url: https://api.openaire.eu/search/publications?format=json&page={pageno}&size=10&title={query} + results_query: response/results/result + url_query: metadata/oaf:entity/oaf:result/children/instance/webresource/url/$ + title_query: metadata/oaf:entity/oaf:result/title/$ + content_query: metadata/oaf:entity/oaf:result/description/$ + content_html_to_text: true + categories: science + shortcut: oap + timeout: 5.0 + about: + website: https://www.openaire.eu/ + wikidata_id: Q25106053 + official_api_documentation: https://api.openaire.eu/ + use_official_api: false + require_api_key: false + results: JSON + + - name: openmeteo + engine: open_meteo + shortcut: om + disabled: true + + # - name: opensemanticsearch + # engine: opensemantic + # shortcut: oss + # base_url: 'http://localhost:8983/solr/opensemanticsearch/' + + - name: openstreetmap + engine: openstreetmap + shortcut: osm + + - name: openrepos + engine: xpath + paging: true + search_url: https://openrepos.net/search/node/{query}?page={pageno} + url_xpath: //li[@class="search-result"]//h3[@class="title"]/a/@href + title_xpath: //li[@class="search-result"]//h3[@class="title"]/a + content_xpath: //li[@class="search-result"]//div[@class="search-snippet-info"]//p[@class="search-snippet"] + categories: files + timeout: 4.0 + disabled: true + shortcut: or + about: + website: https://openrepos.net/ + wikidata_id: + official_api_documentation: + use_official_api: false + require_api_key: false + results: HTML + + - name: packagist + engine: json_engine + paging: true + search_url: https://packagist.org/search.json?q={query}&page={pageno} + results_query: results + url_query: url + title_query: name + content_query: description + categories: [it, packages] + disabled: true + timeout: 5.0 + shortcut: pack + about: + website: https://packagist.org + wikidata_id: Q108311377 + official_api_documentation: https://packagist.org/apidoc + use_official_api: true + require_api_key: false + results: JSON + + - name: pdbe + engine: pdbe + shortcut: pdb + # Hide obsolete PDB entries. Default is not to hide obsolete structures + # hide_obsolete: false + + - name: photon + engine: photon + shortcut: ph + + - name: pinterest + engine: pinterest + shortcut: pin + + - name: piped + engine: piped + shortcut: ppd + categories: videos + piped_filter: videos + timeout: 3.0 + + # URL to use as link and for embeds + frontend_url: https://srv.piped.video + # Instance will be selected randomly, for more see https://piped-instances.kavin.rocks/ + backend_url: + - https://pipedapi.kavin.rocks + - https://pipedapi-libre.kavin.rocks + - https://pipedapi.adminforge.de + + - name: piped.music + engine: piped + network: piped + shortcut: ppdm + categories: music + piped_filter: music_songs + timeout: 3.0 + + - name: piratebay + engine: piratebay + shortcut: tpb + # You may need to change this URL to a proxy if piratebay is blocked in your + # country + url: https://thepiratebay.org/ + timeout: 3.0 + + - name: pixiv + shortcut: pv + engine: pixiv + disabled: true + inactive: true + pixiv_image_proxies: + - https://pximg.example.org + # A proxy is required to load the images. Hosting an image proxy server + # for Pixiv: + # --> https://pixivfe.pages.dev/hosting-image-proxy-server/ + # Proxies from public instances. Ask the public instances owners if they + # agree to receive traffic from SearXNG! + # --> https://codeberg.org/VnPower/PixivFE#instances + # --> https://github.com/searxng/searxng/pull/3192#issuecomment-1941095047 + # image proxy of https://pixiv.cat + # - https://i.pixiv.cat + # image proxy of https://www.pixiv.pics + # - https://pximg.cocomi.eu.org + # image proxy of https://pixivfe.exozy.me + # - https://pximg.exozy.me + # image proxy of https://pixivfe.ducks.party + # - https://pixiv.ducks.party + # image proxy of https://pixiv.perennialte.ch + # - https://pximg.perennialte.ch + + - name: podcastindex + engine: podcastindex + shortcut: podcast + + # Required dependency: psychopg2 + # - name: postgresql + # engine: postgresql + # database: postgres + # username: postgres + # password: postgres + # limit: 10 + # query_str: 'SELECT * from my_table WHERE my_column = %(query)s' + # shortcut : psql + + - name: presearch + engine: presearch + search_type: search + categories: [general, web] + shortcut: ps + timeout: 4.0 + disabled: true + + - name: presearch images + engine: presearch + network: presearch + search_type: images + categories: [images, web] + timeout: 4.0 + shortcut: psimg + disabled: true + + - name: presearch videos + engine: presearch + network: presearch + search_type: videos + categories: [general, web] + timeout: 4.0 + shortcut: psvid + disabled: true + + - name: presearch news + engine: presearch + network: presearch + search_type: news + categories: [news, web] + timeout: 4.0 + shortcut: psnews + disabled: true + + - name: pub.dev + engine: xpath + shortcut: pd + search_url: https://pub.dev/packages?q={query}&page={pageno} + paging: true + results_xpath: //div[contains(@class,"packages-item")] + url_xpath: ./div/h3/a/@href + title_xpath: ./div/h3/a + content_xpath: ./div/div/div[contains(@class,"packages-description")]/span + categories: [packages, it] + timeout: 3.0 + disabled: true + first_page_num: 1 + about: + website: https://pub.dev/ + official_api_documentation: https://pub.dev/help/api + use_official_api: false + require_api_key: false + results: HTML + + - name: pubmed + engine: pubmed + shortcut: pub + timeout: 3.0 + + - name: pypi + shortcut: pypi + engine: pypi + + - name: qwant + qwant_categ: web + engine: qwant + disabled: true + shortcut: qw + categories: [general, web] + additional_tests: + rosebud: *test_rosebud + + - name: qwant news + qwant_categ: news + engine: qwant + shortcut: qwn + categories: news + network: qwant + + - name: qwant images + qwant_categ: images + engine: qwant + shortcut: qwi + categories: [images, web] + network: qwant + + - name: qwant videos + qwant_categ: videos + engine: qwant + shortcut: qwv + categories: [videos, web] + network: qwant + + # - name: library + # engine: recoll + # shortcut: lib + # base_url: 'https://recoll.example.org/' + # search_dir: '' + # mount_prefix: /export + # dl_prefix: 'https://download.example.org' + # timeout: 30.0 + # categories: files + # disabled: true + + # - name: recoll library reference + # engine: recoll + # base_url: 'https://recoll.example.org/' + # search_dir: reference + # mount_prefix: /export + # dl_prefix: 'https://download.example.org' + # shortcut: libr + # timeout: 30.0 + # categories: files + # disabled: true + + - name: radio browser + engine: radio_browser + shortcut: rb + + - name: reddit + engine: reddit + shortcut: re + page_size: 25 + disabled: true + + - name: rottentomatoes + engine: rottentomatoes + shortcut: rt + disabled: true + + # Required dependency: redis + # - name: myredis + # shortcut : rds + # engine: redis_server + # exact_match_only: false + # host: '127.0.0.1' + # port: 6379 + # enable_http: true + # password: '' + # db: 0 + + # tmp suspended: bad certificate + # - name: scanr structures + # shortcut: scs + # engine: scanr_structures + # disabled: true + + - name: searchmysite + engine: xpath + shortcut: sms + categories: general + paging: true + search_url: https://searchmysite.net/search/?q={query}&page={pageno} + results_xpath: //div[contains(@class,'search-result')] + url_xpath: .//a[contains(@class,'result-link')]/@href + title_xpath: .//span[contains(@class,'result-title-txt')]/text() + content_xpath: ./p[@id='result-hightlight'] + disabled: true + about: + website: https://searchmysite.net + + - name: sepiasearch + engine: sepiasearch + shortcut: sep + + - name: soundcloud + engine: soundcloud + shortcut: sc + + - name: stackoverflow + engine: stackexchange + shortcut: st + api_site: 'stackoverflow' + categories: [it, q&a] + + - name: askubuntu + engine: stackexchange + shortcut: ubuntu + api_site: 'askubuntu' + categories: [it, q&a] + + - name: internetarchivescholar + engine: internet_archive_scholar + shortcut: ias + timeout: 15.0 + + - name: superuser + engine: stackexchange + shortcut: su + api_site: 'superuser' + categories: [it, q&a] + + - name: discuss.python + engine: discourse + shortcut: dpy + base_url: 'https://discuss.python.org' + categories: [it, q&a] + disabled: true + + - name: caddy.community + engine: discourse + shortcut: caddy + base_url: 'https://caddy.community' + categories: [it, q&a] + disabled: true + + - name: pi-hole.community + engine: discourse + shortcut: pi + categories: [it, q&a] + base_url: 'https://discourse.pi-hole.net' + disabled: true + + - name: searchcode code + engine: searchcode_code + shortcut: scc + disabled: true + + # - name: searx + # engine: searx_engine + # shortcut: se + # instance_urls : + # - http://127.0.0.1:8888/ + # - ... + # disabled: true + + - name: semantic scholar + engine: semantic_scholar + disabled: true + shortcut: se + + # Spotify needs API credentials + # - name: spotify + # engine: spotify + # shortcut: stf + # api_client_id: ******* + # api_client_secret: ******* + + # - name: solr + # engine: solr + # shortcut: slr + # base_url: http://localhost:8983 + # collection: collection_name + # sort: '' # sorting: asc or desc + # field_list: '' # comma separated list of field names to display on the UI + # default_fields: '' # default field to query + # query_fields: '' # query fields + # enable_http: true + + # - name: springer nature + # engine: springer + # # get your API key from: https://dev.springernature.com/signup + # # working API key, for test & debug: "a69685087d07eca9f13db62f65b8f601" + # api_key: 'unset' + # shortcut: springer + # timeout: 15.0 + + - name: startpage + engine: startpage + shortcut: sp + timeout: 6.0 + disabled: true + additional_tests: + rosebud: *test_rosebud + + - name: tokyotoshokan + engine: tokyotoshokan + shortcut: tt + timeout: 6.0 + disabled: true + + - name: solidtorrents + engine: solidtorrents + shortcut: solid + timeout: 4.0 + base_url: + - https://solidtorrents.to + - https://bitsearch.to + + # For this demo of the sqlite engine download: + # https://liste.mediathekview.de/filmliste-v2.db.bz2 + # and unpack into searx/data/filmliste-v2.db + # Query to test: "!demo concert" + # + # - name: demo + # engine: sqlite + # shortcut: demo + # categories: general + # result_template: default.html + # database: searx/data/filmliste-v2.db + # query_str: >- + # SELECT title || ' (' || time(duration, 'unixepoch') || ')' AS title, + # COALESCE( NULLIF(url_video_hd,''), NULLIF(url_video_sd,''), url_video) AS url, + # description AS content + # FROM film + # WHERE title LIKE :wildcard OR description LIKE :wildcard + # ORDER BY duration DESC + + - name: tagesschau + engine: tagesschau + # when set to false, display URLs from Tagesschau, and not the actual source + # (e.g. NDR, WDR, SWR, HR, ...) + use_source_url: true + shortcut: ts + disabled: true + + - name: tmdb + engine: xpath + paging: true + categories: movies + search_url: https://www.themoviedb.org/search?page={pageno}&query={query} + results_xpath: //div[contains(@class,"movie") or contains(@class,"tv")]//div[contains(@class,"card")] + url_xpath: .//div[contains(@class,"poster")]/a/@href + thumbnail_xpath: .//img/@src + title_xpath: .//div[contains(@class,"title")]//h2 + content_xpath: .//div[contains(@class,"overview")] + shortcut: tm + disabled: true + + # Requires Tor + - name: torch + engine: xpath + paging: true + search_url: + http://xmh57jrknzkhv6y3ls3ubitzfqnkrwxhopf5aygthi7d6rplyvk3noyd.onion/cgi-bin/omega/omega?P={query}&DEFAULTOP=and + results_xpath: //table//tr + url_xpath: ./td[2]/a + title_xpath: ./td[2]/b + content_xpath: ./td[2]/small + categories: onions + enable_http: true + shortcut: tch + + # torznab engine lets you query any torznab compatible indexer. Using this + # engine in combination with Jackett opens the possibility to query a lot of + # public and private indexers directly from SearXNG. More details at: + # https://docs.searxng.org/dev/engines/online/torznab.html + # + # - name: Torznab EZTV + # engine: torznab + # shortcut: eztv + # base_url: http://localhost:9117/api/v2.0/indexers/eztv/results/torznab + # enable_http: true # if using localhost + # api_key: xxxxxxxxxxxxxxx + # show_magnet_links: true + # show_torrent_files: false + # # https://github.com/Jackett/Jackett/wiki/Jackett-Categories + # torznab_categories: # optional + # - 2000 + # - 5000 + + # tmp suspended - too slow, too many errors + # - name: urbandictionary + # engine : xpath + # search_url : https://www.urbandictionary.com/define.php?term={query} + # url_xpath : //*[@class="word"]/@href + # title_xpath : //*[@class="def-header"] + # content_xpath: //*[@class="meaning"] + # shortcut: ud + + - name: unsplash + engine: unsplash + shortcut: us + + - name: yandex music + engine: yandex_music + shortcut: ydm + disabled: true + # https://yandex.com/support/music/access.html + inactive: true + + - name: yahoo + engine: yahoo + shortcut: yh + disabled: true + + - name: yahoo news + engine: yahoo_news + shortcut: yhn + + - name: youtube + shortcut: yt + # You can use the engine using the official stable API, but you need an API + # key See: https://console.developers.google.com/project + # + # engine: youtube_api + # api_key: 'apikey' # required! + # + # Or you can use the html non-stable engine, activated by default + engine: youtube_noapi + + - name: dailymotion + engine: dailymotion + shortcut: dm + + - name: vimeo + engine: vimeo + shortcut: vm + disabled: true + + - name: wiby + engine: json_engine + paging: true + search_url: https://wiby.me/json/?q={query}&p={pageno} + url_query: URL + title_query: Title + content_query: Snippet + categories: [general, web] + shortcut: wib + disabled: true + about: + website: https://wiby.me/ + + - name: alexandria + engine: json_engine + shortcut: alx + categories: general + paging: true + search_url: https://api.alexandria.org/?a=1&q={query}&p={pageno} + results_query: results + title_query: title + url_query: url + content_query: snippet + timeout: 1.5 + disabled: true + about: + website: https://alexandria.org/ + official_api_documentation: https://github.com/alexandria-org/alexandria-api/raw/master/README.md + use_official_api: true + require_api_key: false + results: JSON + + - name: wikibooks + engine: mediawiki + weight: 0.5 + shortcut: wb + categories: [general, wikimedia] + base_url: "https://{language}.wikibooks.org/" + search_type: text + disabled: true + about: + website: https://www.wikibooks.org/ + wikidata_id: Q367 + + - name: wikinews + engine: mediawiki + shortcut: wn + categories: [news, wikimedia] + base_url: "https://{language}.wikinews.org/" + search_type: text + srsort: create_timestamp_desc + about: + website: https://www.wikinews.org/ + wikidata_id: Q964 + + - name: wikiquote + engine: mediawiki + weight: 0.5 + shortcut: wq + categories: [general, wikimedia] + base_url: "https://{language}.wikiquote.org/" + search_type: text + disabled: true + additional_tests: + rosebud: *test_rosebud + about: + website: https://www.wikiquote.org/ + wikidata_id: Q369 + + - name: wikisource + engine: mediawiki + weight: 0.5 + shortcut: ws + categories: [general, wikimedia] + base_url: "https://{language}.wikisource.org/" + search_type: text + disabled: true + about: + website: https://www.wikisource.org/ + wikidata_id: Q263 + + - name: wikispecies + engine: mediawiki + shortcut: wsp + categories: [general, science, wikimedia] + base_url: "https://species.wikimedia.org/" + search_type: text + disabled: true + about: + website: https://species.wikimedia.org/ + wikidata_id: Q13679 + tests: + wikispecies: + matrix: + query: "Campbell, L.I. et al. 2011: MicroRNAs" + lang: en + result_container: + - not_empty + - ['one_title_contains', 'Tardigrada'] + test: + - unique_results + + - name: wiktionary + engine: mediawiki + shortcut: wt + categories: [dictionaries, wikimedia] + base_url: "https://{language}.wiktionary.org/" + search_type: text + about: + website: https://www.wiktionary.org/ + wikidata_id: Q151 + + - name: wikiversity + engine: mediawiki + weight: 0.5 + shortcut: wv + categories: [general, wikimedia] + base_url: "https://{language}.wikiversity.org/" + search_type: text + disabled: true + about: + website: https://www.wikiversity.org/ + wikidata_id: Q370 + + - name: wikivoyage + engine: mediawiki + weight: 0.5 + shortcut: wy + categories: [general, wikimedia] + base_url: "https://{language}.wikivoyage.org/" + search_type: text + disabled: true + about: + website: https://www.wikivoyage.org/ + wikidata_id: Q373 + + - name: wikicommons.images + engine: wikicommons + shortcut: wc + categories: images + search_type: images + number_of_results: 10 + + - name: wikicommons.videos + engine: wikicommons + shortcut: wcv + categories: videos + search_type: videos + number_of_results: 10 + + - name: wikicommons.audio + engine: wikicommons + shortcut: wca + categories: music + search_type: audio + number_of_results: 10 + + - name: wikicommons.files + engine: wikicommons + shortcut: wcf + categories: files + search_type: files + number_of_results: 10 + + - name: wolframalpha + shortcut: wa + # You can use the engine using the official stable API, but you need an API + # key. See: https://products.wolframalpha.com/api/ + # + # engine: wolframalpha_api + # api_key: '' + # + # Or you can use the html non-stable engine, activated by default + engine: wolframalpha_noapi + timeout: 6.0 + categories: general + disabled: true + + - name: dictzone + engine: dictzone + shortcut: dc + + - name: mymemory translated + engine: translated + shortcut: tl + timeout: 5.0 + # You can use without an API key, but you are limited to 1000 words/day + # See: https://mymemory.translated.net/doc/usagelimits.php + # api_key: '' + + # Required dependency: mysql-connector-python + # - name: mysql + # engine: mysql_server + # database: mydatabase + # username: user + # password: pass + # limit: 10 + # query_str: 'SELECT * from mytable WHERE fieldname=%(query)s' + # shortcut: mysql + + - name: 1337x + engine: 1337x + shortcut: 1337x + disabled: true + + - name: duden + engine: duden + shortcut: du + disabled: true + + - name: seznam + shortcut: szn + engine: seznam + disabled: true + + # - name: deepl + # engine: deepl + # shortcut: dpl + # # You can use the engine using the official stable API, but you need an API key + # # See: https://www.deepl.com/pro-api?cta=header-pro-api + # api_key: '' # required! + # timeout: 5.0 + # disabled: true + + - name: mojeek + shortcut: mjk + engine: mojeek + categories: [general, web] + disabled: true + + - name: mojeek images + shortcut: mjkimg + engine: mojeek + categories: [images, web] + search_type: images + paging: false + disabled: true + + - name: mojeek news + shortcut: mjknews + engine: mojeek + categories: [news, web] + search_type: news + paging: false + disabled: true + + - name: moviepilot + engine: moviepilot + shortcut: mp + disabled: true + + - name: naver + shortcut: nvr + categories: [general, web] + engine: xpath + paging: true + search_url: https://search.naver.com/search.naver?where=webkr&sm=osp_hty&ie=UTF-8&query={query}&start={pageno} + url_xpath: //a[@class="link_tit"]/@href + title_xpath: //a[@class="link_tit"] + content_xpath: //div[@class="total_dsc_wrap"]/a + first_page_num: 1 + page_size: 10 + disabled: true + about: + website: https://www.naver.com/ + wikidata_id: Q485639 + official_api_documentation: https://developers.naver.com/docs/nmt/examples/ + use_official_api: false + require_api_key: false + results: HTML + language: ko + + - name: rubygems + shortcut: rbg + engine: xpath + paging: true + search_url: https://rubygems.org/search?page={pageno}&query={query} + results_xpath: /html/body/main/div/a[@class="gems__gem"] + url_xpath: ./@href + title_xpath: ./span/h2 + content_xpath: ./span/p + suggestion_xpath: /html/body/main/div/div[@class="search__suggestions"]/p/a + first_page_num: 1 + categories: [it, packages] + disabled: true + about: + website: https://rubygems.org/ + wikidata_id: Q1853420 + official_api_documentation: https://guides.rubygems.org/rubygems-org-api/ + use_official_api: false + require_api_key: false + results: HTML + + - name: peertube + engine: peertube + shortcut: ptb + paging: true + # alternatives see: https://instances.joinpeertube.org/instances + # base_url: https://tube.4aem.com + categories: videos + disabled: true + timeout: 6.0 + + - name: mediathekviewweb + engine: mediathekviewweb + shortcut: mvw + disabled: true + + - name: yacy + # https://docs.searxng.org/dev/engines/online/yacy.html + engine: yacy + categories: general + search_type: text + base_url: + - https://yacy.searchlab.eu + # see https://github.com/searxng/searxng/pull/3631#issuecomment-2240903027 + # - https://search.kyun.li + # - https://yacy.securecomcorp.eu + # - https://yacy.myserv.ca + # - https://yacy.nsupdate.info + # - https://yacy.electroncash.de + shortcut: ya + disabled: true + # if you aren't using HTTPS for your local yacy instance disable https + # enable_http: false + search_mode: 'global' + # timeout can be reduced in 'local' search mode + timeout: 5.0 + + - name: yacy images + engine: yacy + network: yacy + categories: images + search_type: image + shortcut: yai + disabled: true + # timeout can be reduced in 'local' search mode + timeout: 5.0 + + - name: rumble + engine: rumble + shortcut: ru + base_url: https://rumble.com/ + paging: true + categories: videos + disabled: true + + - name: livespace + engine: livespace + shortcut: ls + categories: videos + disabled: true + timeout: 5.0 + + - name: wordnik + engine: wordnik + shortcut: def + base_url: https://www.wordnik.com/ + categories: [dictionaries] + timeout: 5.0 + + - name: woxikon.de synonyme + engine: xpath + shortcut: woxi + categories: [dictionaries] + timeout: 5.0 + disabled: true + search_url: https://synonyme.woxikon.de/synonyme/{query}.php + url_xpath: //div[@class="upper-synonyms"]/a/@href + content_xpath: //div[@class="synonyms-list-group"] + title_xpath: //div[@class="upper-synonyms"]/a + no_result_for_http_status: [404] + about: + website: https://www.woxikon.de/ + wikidata_id: # No Wikidata ID + use_official_api: false + require_api_key: false + results: HTML + language: de + + - name: seekr news + engine: seekr + shortcut: senews + categories: news + seekr_category: news + disabled: true + + - name: seekr images + engine: seekr + network: seekr news + shortcut: seimg + categories: images + seekr_category: images + disabled: true + + - name: seekr videos + engine: seekr + network: seekr news + shortcut: sevid + categories: videos + seekr_category: videos + disabled: true + + - name: sjp.pwn + engine: sjp + shortcut: sjp + base_url: https://sjp.pwn.pl/ + timeout: 5.0 + disabled: true + + - name: stract + engine: stract + shortcut: str + disabled: true + + - name: svgrepo + engine: svgrepo + shortcut: svg + timeout: 10.0 + disabled: true + + - name: tootfinder + engine: tootfinder + shortcut: toot + + - name: voidlinux + engine: voidlinux + shortcut: void + disabled: true + + - name: wallhaven + engine: wallhaven + # api_key: abcdefghijklmnopqrstuvwxyz + shortcut: wh + + # wikimini: online encyclopedia for children + # The fulltext and title parameter is necessary for Wikimini because + # sometimes it will not show the results and redirect instead + - name: wikimini + engine: xpath + shortcut: wkmn + search_url: https://fr.wikimini.org/w/index.php?search={query}&title=Sp%C3%A9cial%3ASearch&fulltext=Search + url_xpath: //li/div[@class="mw-search-result-heading"]/a/@href + title_xpath: //li//div[@class="mw-search-result-heading"]/a + content_xpath: //li/div[@class="searchresult"] + categories: general + disabled: true + about: + website: https://wikimini.org/ + wikidata_id: Q3568032 + use_official_api: false + require_api_key: false + results: HTML + language: fr + + - name: wttr.in + engine: wttr + shortcut: wttr + timeout: 9.0 + + - name: yummly + engine: yummly + shortcut: yum + disabled: true + + - name: brave + engine: brave + shortcut: br + time_range_support: true + paging: true + categories: [general, web] + brave_category: search + # brave_spellcheck: true + + - name: brave.images + engine: brave + network: brave + shortcut: brimg + categories: [images, web] + brave_category: images + + - name: brave.videos + engine: brave + network: brave + shortcut: brvid + categories: [videos, web] + brave_category: videos + + - name: brave.news + engine: brave + network: brave + shortcut: brnews + categories: news + brave_category: news + + # - name: brave.goggles + # engine: brave + # network: brave + # shortcut: brgog + # time_range_support: true + # paging: true + # categories: [general, web] + # brave_category: goggles + # Goggles: # required! This should be a URL ending in .goggle + + - name: lib.rs + shortcut: lrs + engine: lib_rs + disabled: true + + - name: sourcehut + shortcut: srht + engine: xpath + paging: true + search_url: https://sr.ht/projects?page={pageno}&search={query} + results_xpath: (//div[@class="event-list"])[1]/div[@class="event"] + url_xpath: ./h4/a[2]/@href + title_xpath: ./h4/a[2] + content_xpath: ./p + first_page_num: 1 + categories: [it, repos] + disabled: true + about: + website: https://sr.ht + wikidata_id: Q78514485 + official_api_documentation: https://man.sr.ht/ + use_official_api: false + require_api_key: false + results: HTML + + - name: goo + shortcut: goo + engine: xpath + paging: true + search_url: https://search.goo.ne.jp/web.jsp?MT={query}&FR={pageno}0 + url_xpath: //div[@class="result"]/p[@class='title fsL1']/a/@href + title_xpath: //div[@class="result"]/p[@class='title fsL1']/a + content_xpath: //p[contains(@class,'url fsM')]/following-sibling::p + first_page_num: 0 + categories: [general, web] + disabled: true + timeout: 4.0 + about: + website: https://search.goo.ne.jp + wikidata_id: Q249044 + use_official_api: false + require_api_key: false + results: HTML + language: ja + + - name: bt4g + engine: bt4g + shortcut: bt4g + + - name: pkg.go.dev + engine: pkg_go_dev + shortcut: pgo + disabled: true + +# Doku engine lets you access to any Doku wiki instance: +# A public one or a privete/corporate one. +# - name: ubuntuwiki +# engine: doku +# shortcut: uw +# base_url: 'https://doc.ubuntu-fr.org' + +# Be careful when enabling this engine if you are +# running a public instance. Do not expose any sensitive +# information. You can restrict access by configuring a list +# of access tokens under tokens. +# - name: git grep +# engine: command +# command: ['git', 'grep', '{{QUERY}}'] +# shortcut: gg +# tokens: [] +# disabled: true +# delimiter: +# chars: ':' +# keys: ['filepath', 'code'] + +# Be careful when enabling this engine if you are +# running a public instance. Do not expose any sensitive +# information. You can restrict access by configuring a list +# of access tokens under tokens. +# - name: locate +# engine: command +# command: ['locate', '{{QUERY}}'] +# shortcut: loc +# tokens: [] +# disabled: true +# delimiter: +# chars: ' ' +# keys: ['line'] + +# Be careful when enabling this engine if you are +# running a public instance. Do not expose any sensitive +# information. You can restrict access by configuring a list +# of access tokens under tokens. +# - name: find +# engine: command +# command: ['find', '.', '-name', '{{QUERY}}'] +# query_type: path +# shortcut: fnd +# tokens: [] +# disabled: true +# delimiter: +# chars: ' ' +# keys: ['line'] + +# Be careful when enabling this engine if you are +# running a public instance. Do not expose any sensitive +# information. You can restrict access by configuring a list +# of access tokens under tokens. +# - name: pattern search in files +# engine: command +# command: ['fgrep', '{{QUERY}}'] +# shortcut: fgr +# tokens: [] +# disabled: true +# delimiter: +# chars: ' ' +# keys: ['line'] + +# Be careful when enabling this engine if you are +# running a public instance. Do not expose any sensitive +# information. You can restrict access by configuring a list +# of access tokens under tokens. +# - name: regex search in files +# engine: command +# command: ['grep', '{{QUERY}}'] +# shortcut: gr +# tokens: [] +# disabled: true +# delimiter: +# chars: ' ' +# keys: ['line'] + +doi_resolvers: + oadoi.org: 'https://oadoi.org/' + doi.org: 'https://doi.org/' + doai.io: 'https://dissem.in/' + sci-hub.se: 'https://sci-hub.se/' + sci-hub.st: 'https://sci-hub.st/' + sci-hub.ru: 'https://sci-hub.ru/' + +default_doi_resolver: 'oadoi.org' diff --git a/api/core/tools/provider/builtin/searxng/docker/uwsgi.ini b/api/core/tools/provider/builtin/searxng/docker/uwsgi.ini new file mode 100644 index 00000000000000..9db3d762649fc5 --- /dev/null +++ b/api/core/tools/provider/builtin/searxng/docker/uwsgi.ini @@ -0,0 +1,54 @@ +[uwsgi] +# Who will run the code +uid = searxng +gid = searxng + +# Number of workers (usually CPU count) +# default value: %k (= number of CPU core, see Dockerfile) +workers = %k + +# Number of threads per worker +# default value: 4 (see Dockerfile) +threads = 4 + +# The right granted on the created socket +chmod-socket = 666 + +# Plugin to use and interpreter config +single-interpreter = true +master = true +plugin = python3 +lazy-apps = true +enable-threads = 4 + +# Module to import +module = searx.webapp + +# Virtualenv and python path +pythonpath = /usr/local/searxng/ +chdir = /usr/local/searxng/searx/ + +# automatically set processes name to something meaningful +auto-procname = true + +# Disable request logging for privacy +disable-logging = true +log-5xx = true + +# Set the max size of a request (request-body excluded) +buffer-size = 8192 + +# No keep alive +# See https://github.com/searx/searx-docker/issues/24 +add-header = Connection: close + +# Follow SIGTERM convention +# See https://github.com/searxng/searxng/issues/3427 +die-on-term + +# uwsgi serves the static files +static-map = /static=/usr/local/searxng/searx/static +# expires set to one day +static-expires = /* 86400 +static-gzip-all = True +offload-threads = 4 diff --git a/api/core/tools/provider/builtin/searxng/searxng.py b/api/core/tools/provider/builtin/searxng/searxng.py index 24b94b5ca4a391..ab354003e6f567 100644 --- a/api/core/tools/provider/builtin/searxng/searxng.py +++ b/api/core/tools/provider/builtin/searxng/searxng.py @@ -17,8 +17,7 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: tool_parameters={ "query": "SearXNG", "limit": 1, - "search_type": "page", - "result_type": "link" + "search_type": "general" }, ) except Exception as e: diff --git a/api/core/tools/provider/builtin/searxng/searxng.yaml b/api/core/tools/provider/builtin/searxng/searxng.yaml index 64bd428280d835..9554c93d5a0c53 100644 --- a/api/core/tools/provider/builtin/searxng/searxng.yaml +++ b/api/core/tools/provider/builtin/searxng/searxng.yaml @@ -6,21 +6,18 @@ identity: zh_Hans: SearXNG description: en_US: A free internet metasearch engine. - zh_Hans: 开源互联网元搜索引擎 + zh_Hans: 开源免费的互联网元搜索引擎 icon: icon.svg tags: - search - productivity credentials_for_provider: searxng_base_url: - type: secret-input + type: text-input required: true label: en_US: SearXNG base URL zh_Hans: SearXNG base URL - help: - en_US: Please input your SearXNG base URL - zh_Hans: 请输入您的 SearXNG base URL placeholder: en_US: Please input your SearXNG base URL zh_Hans: 请输入您的 SearXNG base URL diff --git a/api/core/tools/provider/builtin/searxng/tools/searxng_search.py b/api/core/tools/provider/builtin/searxng/tools/searxng_search.py index 5d12553629abfe..dc835a8e8cbd5b 100644 --- a/api/core/tools/provider/builtin/searxng/tools/searxng_search.py +++ b/api/core/tools/provider/builtin/searxng/tools/searxng_search.py @@ -1,4 +1,3 @@ -import json from typing import Any import requests @@ -7,90 +6,11 @@ from core.tools.tool.builtin_tool import BuiltinTool -class SearXNGSearchResults(dict): - """Wrapper for search results.""" - - def __init__(self, data: str): - super().__init__(json.loads(data)) - self.__dict__ = self - - @property - def results(self) -> Any: - return self.get("results", []) - - class SearXNGSearchTool(BuiltinTool): """ Tool for performing a search using SearXNG engine. """ - SEARCH_TYPE: dict[str, str] = { - "page": "general", - "news": "news", - "image": "images", - # "video": "videos", - # "file": "files" - } - LINK_FILED: dict[str, str] = { - "page": "url", - "news": "url", - "image": "img_src", - # "video": "iframe_src", - # "file": "magnetlink" - } - TEXT_FILED: dict[str, str] = { - "page": "content", - "news": "content", - "image": "img_src", - # "video": "iframe_src", - # "file": "magnetlink" - } - - def _invoke_query(self, user_id: str, host: str, query: str, search_type: str, result_type: str, topK: int = 5) -> list[dict]: - """Run query and return the results.""" - - search_type = search_type.lower() - if search_type not in self.SEARCH_TYPE.keys(): - search_type= "page" - - response = requests.get(host, params={ - "q": query, - "format": "json", - "categories": self.SEARCH_TYPE[search_type] - }) - - if response.status_code != 200: - raise Exception(f'Error {response.status_code}: {response.text}') - - search_results = SearXNGSearchResults(response.text).results[:topK] - - if result_type == 'link': - results = [] - if search_type == "page" or search_type == "news": - for r in search_results: - results.append(self.create_text_message( - text=f'{r["title"]}: {r.get(self.LINK_FILED[search_type], "")}' - )) - elif search_type == "image": - for r in search_results: - results.append(self.create_image_message( - image=r.get(self.LINK_FILED[search_type], "") - )) - else: - for r in search_results: - results.append(self.create_link_message( - link=r.get(self.LINK_FILED[search_type], "") - )) - - return results - else: - text = '' - for i, r in enumerate(search_results): - text += f'{i+1}: {r["title"]} - {r.get(self.TEXT_FILED[search_type], "")}\n' - - return self.create_text_message(text=self.summary(user_id=user_id, content=text)) - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: """ Invoke the SearXNG search tool. @@ -103,23 +23,21 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation. """ - host = self.runtime.credentials.get('searxng_base_url', None) + host = self.runtime.credentials.get('searxng_base_url') if not host: raise Exception('SearXNG api is required') - - query = tool_parameters.get('query') - if not query: - return self.create_text_message('Please input query') - - num_results = min(tool_parameters.get('num_results', 5), 20) - search_type = tool_parameters.get('search_type', 'page') or 'page' - result_type = tool_parameters.get('result_type', 'text') or 'text' - return self._invoke_query( - user_id=user_id, - host=host, - query=query, - search_type=search_type, - result_type=result_type, - topK=num_results - ) + response = requests.get(host, params={ + "q": tool_parameters.get('query'), + "format": "json", + "categories": tool_parameters.get('search_type', 'general') + }) + + if response.status_code != 200: + raise Exception(f'Error {response.status_code}: {response.text}') + + res = response.json().get("results", []) + if not res: + return self.create_text_message(f"No results found, get response: {response.content}") + + return [self.create_json_message(item) for item in res] diff --git a/api/core/tools/provider/builtin/searxng/tools/searxng_search.yaml b/api/core/tools/provider/builtin/searxng/tools/searxng_search.yaml index 0edf1744f4b2f4..a5e448a30375b4 100644 --- a/api/core/tools/provider/builtin/searxng/tools/searxng_search.yaml +++ b/api/core/tools/provider/builtin/searxng/tools/searxng_search.yaml @@ -1,13 +1,13 @@ identity: name: searxng_search - author: Tice + author: Junytang label: en_US: SearXNG Search zh_Hans: SearXNG 搜索 description: human: - en_US: Perform searches on SearXNG and get results. - zh_Hans: 在 SearXNG 上进行搜索并获取结果。 + en_US: SearXNG is a free internet metasearch engine which aggregates results from more than 70 search services. + zh_Hans: SearXNG 是一个免费的互联网元搜索引擎,它从70多个不同的搜索服务中聚合搜索结果。 llm: Perform searches on SearXNG and get results. parameters: - name: query @@ -16,9 +16,6 @@ parameters: label: en_US: Query string zh_Hans: 查询语句 - human_description: - en_US: The search query. - zh_Hans: 搜索查询语句。 llm_description: Key words for searching form: llm - name: search_type @@ -27,63 +24,46 @@ parameters: label: en_US: search type zh_Hans: 搜索类型 - pt_BR: search type - human_description: - en_US: search type for page, news or image. - zh_Hans: 选择搜索的类型:网页,新闻,图片。 - pt_BR: search type for page, news or image. - default: Page + default: general options: - - value: Page + - value: general label: - en_US: Page - zh_Hans: 网页 - pt_BR: Page - - value: News + en_US: General + zh_Hans: 综合 + - value: images + label: + en_US: Images + zh_Hans: 图片 + - value: videos + label: + en_US: Videos + zh_Hans: 视频 + - value: news label: en_US: News zh_Hans: 新闻 - pt_BR: News - - value: Image + - value: map label: - en_US: Image - zh_Hans: 图片 - pt_BR: Image - form: form - - name: num_results - type: number - required: true - label: - en_US: Number of query results - zh_Hans: 返回查询数量 - human_description: - en_US: The number of query results. - zh_Hans: 返回查询结果的数量。 - form: form - default: 5 - min: 1 - max: 20 - - name: result_type - type: select - required: true - label: - en_US: result type - zh_Hans: 结果类型 - pt_BR: result type - human_description: - en_US: return a list of links or texts. - zh_Hans: 返回一个连接列表还是纯文本内容。 - pt_BR: return a list of links or texts. - default: text - options: - - value: link + en_US: Map + zh_Hans: 地图 + - value: music + label: + en_US: Music + zh_Hans: 音乐 + - value: it + label: + en_US: It + zh_Hans: 信息技术 + - value: science + label: + en_US: Science + zh_Hans: 科学 + - value: files label: - en_US: Link - zh_Hans: 链接 - pt_BR: Link - - value: text + en_US: Files + zh_Hans: 文件 + - value: social_media label: - en_US: Text - zh_Hans: 文本 - pt_BR: Text + en_US: Social Media + zh_Hans: 社交媒体 form: form diff --git a/api/core/tools/provider/builtin/siliconflow/_assets/icon.svg b/api/core/tools/provider/builtin/siliconflow/_assets/icon.svg new file mode 100644 index 00000000000000..ad6b384f7acd21 --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/_assets/icon.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/siliconflow/siliconflow.py b/api/core/tools/provider/builtin/siliconflow/siliconflow.py new file mode 100644 index 00000000000000..0df78280df3091 --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/siliconflow.py @@ -0,0 +1,19 @@ +import requests + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class SiliconflowProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + url = "https://api.siliconflow.cn/v1/models" + headers = { + "accept": "application/json", + "authorization": f"Bearer {credentials.get('siliconFlow_api_key')}", + } + + response = requests.get(url, headers=headers) + if response.status_code != 200: + raise ToolProviderCredentialValidationError( + "SiliconFlow API key is invalid" + ) diff --git a/api/core/tools/provider/builtin/siliconflow/siliconflow.yaml b/api/core/tools/provider/builtin/siliconflow/siliconflow.yaml new file mode 100644 index 00000000000000..46be99f262f211 --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/siliconflow.yaml @@ -0,0 +1,21 @@ +identity: + author: hjlarry + name: siliconflow + label: + en_US: SiliconFlow + zh_CN: 硅基流动 + description: + en_US: The image generation API provided by SiliconFlow includes Flux and Stable Diffusion models. + zh_CN: 硅基流动提供的图片生成 API,包含 Flux 和 Stable Diffusion 模型。 + icon: icon.svg + tags: + - image +credentials_for_provider: + siliconFlow_api_key: + type: secret-input + required: true + label: + en_US: SiliconFlow API Key + placeholder: + en_US: Please input your SiliconFlow API key + url: https://cloud.siliconflow.cn/account/ak diff --git a/api/core/tools/provider/builtin/siliconflow/tools/flux.py b/api/core/tools/provider/builtin/siliconflow/tools/flux.py new file mode 100644 index 00000000000000..ed9f4be574703a --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/tools/flux.py @@ -0,0 +1,44 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +FLUX_URL = ( + "https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image" +) + + +class FluxTool(BuiltinTool): + + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + headers = { + "accept": "application/json", + "content-type": "application/json", + "authorization": f"Bearer {self.runtime.credentials['siliconFlow_api_key']}", + } + + payload = { + "prompt": tool_parameters.get("prompt"), + "image_size": tool_parameters.get("image_size", "1024x1024"), + "seed": tool_parameters.get("seed"), + "num_inference_steps": tool_parameters.get("num_inference_steps", 20), + } + + response = requests.post(FLUX_URL, json=payload, headers=headers) + if response.status_code != 200: + return self.create_text_message(f"Got Error Response:{response.text}") + + res = response.json() + result = [self.create_json_message(res)] + for image in res.get("images", []): + result.append( + self.create_image_message( + image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value + ) + ) + return result diff --git a/api/core/tools/provider/builtin/siliconflow/tools/flux.yaml b/api/core/tools/provider/builtin/siliconflow/tools/flux.yaml new file mode 100644 index 00000000000000..2a0698700c1577 --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/tools/flux.yaml @@ -0,0 +1,73 @@ +identity: + name: flux + author: hjlarry + label: + en_US: Flux + icon: icon.svg +description: + human: + en_US: Generate image via SiliconFlow's flux schnell. + llm: This tool is used to generate image from prompt via SiliconFlow's flux schnell model. +parameters: + - name: prompt + type: string + required: true + label: + en_US: prompt + zh_Hans: 提示词 + human_description: + en_US: The text prompt used to generate the image. + zh_Hans: 用于生成图片的文字提示词 + llm_description: this prompt text will be used to generate image. + form: llm + - name: image_size + type: select + required: true + options: + - value: 1024x1024 + label: + en_US: 1024x1024 + - value: 768x1024 + label: + en_US: 768x1024 + - value: 576x1024 + label: + en_US: 576x1024 + - value: 512x1024 + label: + en_US: 512x1024 + - value: 1024x576 + label: + en_US: 1024x576 + - value: 768x512 + label: + en_US: 768x512 + default: 1024x1024 + label: + en_US: Choose Image Size + zh_Hans: 选择生成的图片大小 + form: form + - name: num_inference_steps + type: number + required: true + default: 20 + min: 1 + max: 100 + label: + en_US: Num Inference Steps + zh_Hans: 生成图片的步数 + form: form + human_description: + en_US: The number of inference steps to perform. More steps produce higher quality but take longer. + zh_Hans: 执行的推理步骤数量。更多的步骤可以产生更高质量的结果,但需要更长的时间。 + - name: seed + type: number + min: 0 + max: 9999999999 + label: + en_US: Seed + zh_Hans: 种子 + human_description: + en_US: The same seed and prompt can produce similar images. + zh_Hans: 相同的种子和提示可以产生相似的图像。 + form: form diff --git a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py new file mode 100644 index 00000000000000..e8134a6565b281 --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py @@ -0,0 +1,51 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +SDURL = { + "sd_3": "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-3-medium/text-to-image", + "sd_xl": "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-xl-base-1.0/text-to-image", +} + + +class StableDiffusionTool(BuiltinTool): + + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + headers = { + "accept": "application/json", + "content-type": "application/json", + "authorization": f"Bearer {self.runtime.credentials['siliconFlow_api_key']}", + } + + model = tool_parameters.get("model", "sd_3") + url = SDURL.get(model) + + payload = { + "prompt": tool_parameters.get("prompt"), + "negative_prompt": tool_parameters.get("negative_prompt", ""), + "image_size": tool_parameters.get("image_size", "1024x1024"), + "batch_size": tool_parameters.get("batch_size", 1), + "seed": tool_parameters.get("seed"), + "guidance_scale": tool_parameters.get("guidance_scale", 7.5), + "num_inference_steps": tool_parameters.get("num_inference_steps", 20), + } + + response = requests.post(url, json=payload, headers=headers) + if response.status_code != 200: + return self.create_text_message(f"Got Error Response:{response.text}") + + res = response.json() + result = [self.create_json_message(res)] + for image in res.get("images", []): + result.append( + self.create_image_message( + image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value + ) + ) + return result diff --git a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.yaml b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.yaml new file mode 100644 index 00000000000000..dce10adc87f222 --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.yaml @@ -0,0 +1,121 @@ +identity: + name: stable_diffusion + author: hjlarry + label: + en_US: Stable Diffusion + icon: icon.svg +description: + human: + en_US: Generate image via SiliconFlow's stable diffusion model. + llm: This tool is used to generate image from prompt via SiliconFlow's stable diffusion model. +parameters: + - name: prompt + type: string + required: true + label: + en_US: prompt + zh_Hans: 提示词 + human_description: + en_US: The text prompt used to generate the image. + zh_Hans: 用于生成图片的文字提示词 + llm_description: this prompt text will be used to generate image. + form: llm + - name: negative_prompt + type: string + label: + en_US: negative prompt + zh_Hans: 负面提示词 + human_description: + en_US: Describe what you don't want included in the image. + zh_Hans: 描述您不希望包含在图片中的内容。 + llm_description: Describe what you don't want included in the image. + form: llm + - name: model + type: select + required: true + options: + - value: sd_3 + label: + en_US: Stable Diffusion 3 + - value: sd_xl + label: + en_US: Stable Diffusion XL + default: sd_3 + label: + en_US: Choose Image Model + zh_Hans: 选择生成图片的模型 + form: form + - name: image_size + type: select + required: true + options: + - value: 1024x1024 + label: + en_US: 1024x1024 + - value: 1024x2048 + label: + en_US: 1024x2048 + - value: 1152x2048 + label: + en_US: 1152x2048 + - value: 1536x1024 + label: + en_US: 1536x1024 + - value: 1536x2048 + label: + en_US: 1536x2048 + - value: 2048x1152 + label: + en_US: 2048x1152 + default: 1024x1024 + label: + en_US: Choose Image Size + zh_Hans: 选择生成图片的大小 + form: form + - name: batch_size + type: number + required: true + default: 1 + min: 1 + max: 4 + label: + en_US: Number Images + zh_Hans: 生成图片的数量 + form: form + - name: guidance_scale + type: number + required: true + default: 7.5 + min: 0 + max: 100 + label: + en_US: Guidance Scale + zh_Hans: 与提示词紧密性 + human_description: + en_US: Classifier Free Guidance. How close you want the model to stick to your prompt when looking for a related image to show you. + zh_Hans: 无分类器引导。您希望模型在寻找相关图片向您展示时,与您的提示保持多紧密的关联度。 + form: form + - name: num_inference_steps + type: number + required: true + default: 20 + min: 1 + max: 100 + label: + en_US: Num Inference Steps + zh_Hans: 生成图片的步数 + human_description: + en_US: The number of inference steps to perform. More steps produce higher quality but take longer. + zh_Hans: 执行的推理步骤数量。更多的步骤可以产生更高质量的结果,但需要更长的时间。 + form: form + - name: seed + type: number + min: 0 + max: 9999999999 + label: + en_US: Seed + zh_Hans: 种子 + human_description: + en_US: The same seed and prompt can produce similar images. + zh_Hans: 相同的种子和提示可以产生相似的图像。 + form: form diff --git a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py index 0c5ebc23ac5c51..4be9207d66abcc 100644 --- a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py @@ -27,7 +27,7 @@ "seed_resize_from_w": -1, # Samplers - # "sampler_name": "DPM++ 2M", + "sampler_name": "DPM++ 2M", # "scheduler": "", # "sampler_index": "Automatic", @@ -178,6 +178,23 @@ def get_sd_models(self) -> list[str]: return [d['model_name'] for d in response.json()] except Exception as e: return [] + + def get_sample_methods(self) -> list[str]: + """ + get sample method + """ + try: + base_url = self.runtime.credentials.get('base_url', None) + if not base_url: + return [] + api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'samplers') + response = get(url=api_url, timeout=(2, 10)) + if response.status_code != 200: + return [] + else: + return [d['name'] for d in response.json()] + except Exception as e: + return [] def img2img(self, base_url: str, tool_parameters: dict[str, Any]) \ -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: @@ -339,7 +356,27 @@ def get_runtime_parameters(self) -> list[ToolParameter]: label=I18nObject(en_US=i, zh_Hans=i) ) for i in models]) ) + except: pass - + + sample_methods = self.get_sample_methods() + if len(sample_methods) != 0: + parameters.append( + ToolParameter(name='sampler_name', + label=I18nObject(en_US='Sampling method', zh_Hans='Sampling method'), + human_description=I18nObject( + en_US='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion', + zh_Hans='Stable Diffusion 的Sampling method,您可以查看 Stable Diffusion 的官方文档', + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion', + required=True, + default=sample_methods[0], + options=[ToolParameterOption( + value=i, + label=I18nObject(en_US=i, zh_Hans=i) + ) for i in sample_methods]) + ) return parameters diff --git a/api/core/tools/provider/builtin/stepfun/__init__.py b/api/core/tools/provider/builtin/stepfun/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/tools/provider/builtin/stepfun/_assets/icon.png b/api/core/tools/provider/builtin/stepfun/_assets/icon.png new file mode 100644 index 00000000000000..85b96d0c74c24c Binary files /dev/null and b/api/core/tools/provider/builtin/stepfun/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/stepfun/stepfun.py b/api/core/tools/provider/builtin/stepfun/stepfun.py new file mode 100644 index 00000000000000..e809b04546aef5 --- /dev/null +++ b/api/core/tools/provider/builtin/stepfun/stepfun.py @@ -0,0 +1,25 @@ +from typing import Any + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.stepfun.tools.image import StepfunTool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class StepfunProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + try: + StepfunTool().fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ).invoke( + user_id='', + tool_parameters={ + "prompt": "cute girl, blue eyes, white hair, anime style", + "size": "1024x1024", + "n": 1 + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stepfun/stepfun.yaml b/api/core/tools/provider/builtin/stepfun/stepfun.yaml new file mode 100644 index 00000000000000..1f841ec369b5c3 --- /dev/null +++ b/api/core/tools/provider/builtin/stepfun/stepfun.yaml @@ -0,0 +1,46 @@ +identity: + author: Stepfun + name: stepfun + label: + en_US: Image-1X + zh_Hans: 阶跃星辰绘画 + pt_BR: Image-1X + description: + en_US: Image-1X + zh_Hans: 阶跃星辰绘画 + pt_BR: Image-1X + icon: icon.png + tags: + - image + - productivity +credentials_for_provider: + stepfun_api_key: + type: secret-input + required: true + label: + en_US: Stepfun API key + zh_Hans: 阶跃星辰API key + pt_BR: Stepfun API key + help: + en_US: Please input your stepfun API key + zh_Hans: 请输入你的阶跃星辰 API key + pt_BR: Please input your stepfun API key + placeholder: + en_US: Please input your stepfun API key + zh_Hans: 请输入你的阶跃星辰 API key + pt_BR: Please input your stepfun API key + stepfun_base_url: + type: text-input + required: false + label: + en_US: Stepfun base URL + zh_Hans: 阶跃星辰 base URL + pt_BR: Stepfun base URL + help: + en_US: Please input your Stepfun base URL + zh_Hans: 请输入你的阶跃星辰 base URL + pt_BR: Please input your Stepfun base URL + placeholder: + en_US: Please input your Stepfun base URL + zh_Hans: 请输入你的阶跃星辰 base URL + pt_BR: Please input your Stepfun base URL diff --git a/api/core/tools/provider/builtin/stepfun/tools/image.py b/api/core/tools/provider/builtin/stepfun/tools/image.py new file mode 100644 index 00000000000000..5e544aada63b40 --- /dev/null +++ b/api/core/tools/provider/builtin/stepfun/tools/image.py @@ -0,0 +1,72 @@ +import random +from typing import Any, Union + +from openai import OpenAI +from yarl import URL + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class StepfunTool(BuiltinTool): + """ Stepfun Image Generation Tool """ + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + base_url = self.runtime.credentials.get('stepfun_base_url', None) + if not base_url: + base_url = None + else: + base_url = str(URL(base_url) / 'v1') + + client = OpenAI( + api_key=self.runtime.credentials['stepfun_api_key'], + base_url=base_url, + ) + + extra_body = {} + model = tool_parameters.get('model', 'step-1x-medium') + if not model: + return self.create_text_message('Please input model name') + # prompt + prompt = tool_parameters.get('prompt', '') + if not prompt: + return self.create_text_message('Please input prompt') + + seed = tool_parameters.get('seed', 0) + if seed > 0: + extra_body['seed'] = seed + steps = tool_parameters.get('steps', 0) + if steps > 0: + extra_body['steps'] = steps + negative_prompt = tool_parameters.get('negative_prompt', '') + if negative_prompt: + extra_body['negative_prompt'] = negative_prompt + + # call openapi stepfun model + response = client.images.generate( + prompt=prompt, + model=model, + size=tool_parameters.get('size', '1024x1024'), + n=tool_parameters.get('n', 1), + extra_body= extra_body + ) + print(response) + + result = [] + for image in response.data: + result.append(self.create_image_message(image=image.url)) + result.append(self.create_json_message({ + "url": image.url, + })) + return result + + @staticmethod + def _generate_random_id(length=8): + characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' + random_id = ''.join(random.choices(characters, k=length)) + return random_id diff --git a/api/core/tools/provider/builtin/stepfun/tools/image.yaml b/api/core/tools/provider/builtin/stepfun/tools/image.yaml new file mode 100644 index 00000000000000..dcc5bd2db2f5ba --- /dev/null +++ b/api/core/tools/provider/builtin/stepfun/tools/image.yaml @@ -0,0 +1,158 @@ +identity: + name: stepfun + author: Stepfun + label: + en_US: step-1x + zh_Hans: 阶跃星辰绘画 + pt_BR: step-1x + description: + en_US: step-1x is a powerful drawing tool by stepfun, you can draw the image based on your prompt + zh_Hans: step-1x 系列是阶跃星辰提供的强大的绘画工具,它可以根据您的提示词绘制出您想要的图像。 + pt_BR: step-1x is a powerful drawing tool by stepfun, you can draw the image based on your prompt +description: + human: + en_US: step-1x is a text to image tool + zh_Hans: step-1x 是一个文本/图像到图像的工具 + pt_BR: step-1x is a text to image tool + llm: step-1x is a tool used to generate images from text or image +parameters: + - name: prompt + type: string + required: true + label: + en_US: Prompt + zh_Hans: 提示词 + pt_BR: Prompt + human_description: + en_US: Image prompt, you can check the official documentation of step-1x + zh_Hans: 图像提示词,您可以查看 step-1x 的官方文档 + pt_BR: Image prompt, you can check the official documentation of step-1x + llm_description: Image prompt of step-1x you should describe the image you want to generate as a list of words as possible as detailed + form: llm + - name: model + type: select + required: false + human_description: + en_US: used for selecting the model name + zh_Hans: 用于选择模型的名字 + pt_BR: used for selecting the model name + label: + en_US: Model Name + zh_Hans: 模型名字 + pt_BR: Model Name + form: form + options: + - value: step-1x-turbo + label: + en_US: turbo + zh_Hans: turbo + pt_BR: turbo + - value: step-1x-medium + label: + en_US: medium + zh_Hans: medium + pt_BR: medium + - value: step-1x-large + label: + en_US: large + zh_Hans: large + pt_BR: large + default: step-1x-medium + - name: size + type: select + required: false + human_description: + en_US: used for selecting the image size + zh_Hans: 用于选择图像大小 + pt_BR: used for selecting the image size + label: + en_US: Image size + zh_Hans: 图像大小 + pt_BR: Image size + form: form + options: + - value: 256x256 + label: + en_US: 256x256 + zh_Hans: 256x256 + pt_BR: 256x256 + - value: 512x512 + label: + en_US: 512x512 + zh_Hans: 512x512 + pt_BR: 512x512 + - value: 768x768 + label: + en_US: 768x768 + zh_Hans: 768x768 + pt_BR: 768x768 + - value: 1024x1024 + label: + en_US: 1024x1024 + zh_Hans: 1024x1024 + pt_BR: 1024x1024 + - value: 1280x800 + label: + en_US: 1280x800 + zh_Hans: 1280x800 + pt_BR: 1280x800 + - value: 800x1280 + label: + en_US: 800x1280 + zh_Hans: 800x1280 + pt_BR: 800x1280 + default: 1024x1024 + - name: n + type: number + required: true + human_description: + en_US: used for selecting the number of images + zh_Hans: 用于选择图像数量 + pt_BR: used for selecting the number of images + label: + en_US: Number of images + zh_Hans: 图像数量 + pt_BR: Number of images + form: form + default: 1 + min: 1 + max: 10 + - name: seed + type: number + required: false + label: + en_US: seed + zh_Hans: seed + pt_BR: seed + human_description: + en_US: seed + zh_Hans: seed + pt_BR: seed + form: form + default: 10 + - name: steps + type: number + required: false + label: + en_US: Steps + zh_Hans: Steps + pt_BR: Steps + human_description: + en_US: Steps + zh_Hans: Steps + pt_BR: Steps + form: form + default: 10 + - name: negative_prompt + type: string + required: false + label: + en_US: Negative prompt + zh_Hans: Negative prompt + pt_BR: Negative prompt + human_description: + en_US: Negative prompt + zh_Hans: Negative prompt + pt_BR: Negative prompt + form: form + default: (worst quality:1.3), (nsfw), low quality diff --git a/api/core/tools/provider/workflow_tool_provider.py b/api/core/tools/provider/workflow_tool_provider.py index f7911fea1db18d..f14abac76777da 100644 --- a/api/core/tools/provider/workflow_tool_provider.py +++ b/api/core/tools/provider/workflow_tool_provider.py @@ -1,6 +1,6 @@ from typing import Optional -from core.app.app_config.entities import VariableEntity +from core.app.app_config.entities import VariableEntity, VariableEntityType from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( @@ -18,6 +18,13 @@ from models.tools import WorkflowToolProvider from models.workflow import Workflow +VARIABLE_TO_PARAMETER_TYPE_MAPPING = { + VariableEntityType.TEXT_INPUT: ToolParameter.ToolParameterType.STRING, + VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING, + VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT, + VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER, +} + class WorkflowToolProviderController(ToolProviderController): provider_id: str @@ -28,7 +35,7 @@ def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderCont if not app: raise ValueError('app not found') - + controller = WorkflowToolProviderController(**{ 'identity': { 'author': db_provider.user.name if db_provider.user_id and db_provider.user else '', @@ -46,7 +53,7 @@ def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderCont 'credentials_schema': {}, 'provider_id': db_provider.id or '', }) - + # init tools controller.tools = [controller._get_db_provider_tool(db_provider, app)] @@ -56,7 +63,7 @@ def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderCont @property def provider_type(self) -> ToolProviderType: return ToolProviderType.WORKFLOW - + def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool: """ get db provider tool @@ -93,23 +100,11 @@ def fetch_workflow_variable(variable_name: str) -> VariableEntity: if variable: parameter_type = None options = None - if variable.type in [ - VariableEntity.Type.TEXT_INPUT, - VariableEntity.Type.PARAGRAPH, - ]: - parameter_type = ToolParameter.ToolParameterType.STRING - elif variable.type in [ - VariableEntity.Type.SELECT - ]: - parameter_type = ToolParameter.ToolParameterType.SELECT - elif variable.type in [ - VariableEntity.Type.NUMBER - ]: - parameter_type = ToolParameter.ToolParameterType.NUMBER - else: + if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING: raise ValueError(f'unsupported variable type {variable.type}') - - if variable.type == VariableEntity.Type.SELECT and variable.options: + parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type] + + if variable.type == VariableEntityType.SELECT and variable.options: options = [ ToolParameterOption( value=option, @@ -200,7 +195,7 @@ def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]: """ if self.tools is not None: return self.tools - + db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == self.provider_id, @@ -208,11 +203,11 @@ def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]: if not db_providers: return [] - + self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)] return self.tools - + def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: """ get tool by name @@ -226,5 +221,5 @@ def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: for tool in self.tools: if tool.identity.name == tool_name: return tool - + return None diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index 69e3dfa0612e8a..38f10032e2282b 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -144,7 +144,7 @@ def do_http_request(self, url: str, method: str, headers: dict[str, Any], path_params[parameter['name']] = value elif parameter['in'] == 'query': - params[parameter['name']] = value + if value !='': params[parameter['name']] = value elif parameter['in'] == 'cookie': cookies[parameter['name']] = value diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index 1a0933af160365..7cb7c033bbe9f1 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -177,10 +177,12 @@ def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_document dataset_id=dataset.id, query=query, top_k=self.top_k, - score_threshold=retrieval_model['score_threshold'] + score_threshold=retrieval_model.get('score_threshold', .0) if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model['reranking_model'] + reranking_model=retrieval_model.get('reranking_model', None) if retrieval_model['reranking_enable'] else None, + reranking_mode=retrieval_model.get('reranking_mode') + if retrieval_model.get('reranking_mode') else 'reranking_model', weights=retrieval_model.get('weights', None), ) diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index 397ff7966e46f2..a7e70af6286544 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -14,6 +14,7 @@ 'reranking_provider_name': '', 'reranking_model_name': '' }, + 'reranking_mode': 'reranking_model', 'top_k': 2, 'score_threshold_enabled': False } @@ -71,14 +72,16 @@ def _run(self, query: str) -> str: else: if self.top_k > 0: # retrieval source - documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'), dataset_id=dataset.id, query=query, top_k=self.top_k, - score_threshold=retrieval_model['score_threshold'] + score_threshold=retrieval_model.get('score_threshold', .0) if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model['reranking_model'] + reranking_model=retrieval_model.get('reranking_model', None) if retrieval_model['reranking_enable'] else None, + reranking_mode=retrieval_model.get('reranking_mode') + if retrieval_model.get('reranking_mode') else 'reranking_model', weights=retrieval_model.get('weights', None), ) else: diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 5d561911d12564..d990131b5fbbfd 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -2,13 +2,12 @@ from collections.abc import Mapping from copy import deepcopy from enum import Enum -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from pydantic import BaseModel, ConfigDict, field_validator from pydantic_core.core_schema import ValidationInfo from core.app.entities.app_invoke_entities import InvokeFrom -from core.file.file_obj import FileVar from core.tools.entities.tool_entities import ( ToolDescription, ToolIdentity, @@ -23,6 +22,9 @@ from core.tools.tool_file_manager import ToolFileManager from core.tools.utils.tool_parameter_converter import ToolParameterConverter +if TYPE_CHECKING: + from core.file.file_obj import FileVar + class Tool(BaseModel, ABC): identity: Optional[ToolIdentity] = None @@ -76,7 +78,7 @@ def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool': description=self.description.model_copy() if self.description else None, runtime=Tool.Runtime(**runtime), ) - + @abstractmethod def tool_provider_type(self) -> ToolProviderType: """ @@ -84,7 +86,7 @@ def tool_provider_type(self) -> ToolProviderType: :return: the tool provider type """ - + def load_variables(self, variables: ToolRuntimeVariablePool): """ load variables from database @@ -99,7 +101,7 @@ def set_image_variable(self, variable_name: str, image_key: str) -> None: """ if not self.variables: return - + self.variables.set_file(self.identity.name, variable_name, image_key) def set_text_variable(self, variable_name: str, text: str) -> None: @@ -108,9 +110,9 @@ def set_text_variable(self, variable_name: str, text: str) -> None: """ if not self.variables: return - + self.variables.set_text(self.identity.name, variable_name, text) - + def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]: """ get a variable @@ -120,14 +122,14 @@ def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]: """ if not self.variables: return None - + if isinstance(name, Enum): name = name.value - + for variable in self.variables.pool: if variable.name == name: return variable - + return None def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]: @@ -138,9 +140,9 @@ def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]: """ if not self.variables: return None - + return self.get_variable(self.VARIABLE_KEY.IMAGE) - + def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]: """ get a variable file @@ -151,7 +153,7 @@ def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]: variable = self.get_variable(name) if not variable: return None - + if not isinstance(variable, ToolRuntimeImageVariable): return None @@ -160,9 +162,9 @@ def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]: file_binary = ToolFileManager.get_file_binary_by_message_file_id(message_file_id) if not file_binary: return None - + return file_binary[0] - + def list_variables(self) -> list[ToolRuntimeVariable]: """ list all variables @@ -171,9 +173,9 @@ def list_variables(self) -> list[ToolRuntimeVariable]: """ if not self.variables: return [] - + return self.variables.pool - + def list_default_image_variables(self) -> list[ToolRuntimeVariable]: """ list all image variables @@ -182,9 +184,9 @@ def list_default_image_variables(self) -> list[ToolRuntimeVariable]: """ if not self.variables: return [] - + result = [] - + for variable in self.variables.pool: if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value): result.append(variable) @@ -225,7 +227,7 @@ def _transform_tool_parameters_type(self, tool_parameters: Mapping[str, Any]) -> @abstractmethod def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: pass - + def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: """ validate the credentials @@ -244,7 +246,7 @@ def get_runtime_parameters(self) -> list[ToolParameter]: :return: the runtime parameters """ return self.parameters or [] - + def get_all_runtime_parameters(self) -> list[ToolParameter]: """ get all runtime parameters @@ -278,7 +280,7 @@ def get_all_runtime_parameters(self) -> list[ToolParameter]: parameters.append(parameter) return parameters - + def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage: """ create an image message @@ -286,18 +288,18 @@ def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessa :param image: the url of the image :return: the image message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, - message=image, + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, + message=image, save_as=save_as) - - def create_file_var_message(self, file_var: FileVar) -> ToolInvokeMessage: + + def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage: return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR, message='', meta={ 'file_var': file_var }, save_as='') - + def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage: """ create a link message @@ -305,10 +307,10 @@ def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage :param link: the url of the link :return: the link message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, - message=link, + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, + message=link, save_as=save_as) - + def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage: """ create a text message @@ -321,7 +323,7 @@ def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage message=text, save_as=save_as ) - + def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage: """ create a blob message diff --git a/api/core/tools/tool/workflow_tool.py b/api/core/tools/tool/workflow_tool.py index 071081303c3b2a..12e498e76d8cd5 100644 --- a/api/core/tools/tool/workflow_tool.py +++ b/api/core/tools/tool/workflow_tool.py @@ -72,6 +72,7 @@ def _invoke( result.append(self.create_file_var_message(file)) result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False))) + result.append(self.create_json_message(outputs)) return result diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index d7ddb40e6be310..4a0188af49d6ff 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -10,14 +10,11 @@ from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.module_import_helper import load_single_subclass_from_source +from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ( - ApiProviderAuthType, - ToolInvokeFrom, - ToolParameter, -) +from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter from core.tools.errors import ToolProviderNotFoundError from core.tools.provider.api_tool_provider import ApiToolProviderController from core.tools.provider.builtin._positions import BuiltinToolProviderSort @@ -26,10 +23,7 @@ from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.tool import Tool from core.tools.tool_label_manager import ToolLabelManager -from core.tools.utils.configuration import ( - ToolConfigurationManager, - ToolParameterConfigurationManager, -) +from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager from core.tools.utils.tool_parameter_converter import ToolParameterConverter from core.workflow.nodes.tool.entities import ToolEntity from extensions.ext_database import db @@ -38,6 +32,7 @@ logger = logging.getLogger(__name__) + class ToolManager: _builtin_provider_lock = Lock() _builtin_providers = {} @@ -107,7 +102,7 @@ def get_tool_runtime(cls, provider_type: str, tenant_id: str, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \ - -> Union[BuiltinTool, ApiTool]: + -> Union[BuiltinTool, ApiTool]: """ get the tool runtime @@ -346,7 +341,7 @@ def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, Non provider_class = load_single_subclass_from_source( module_name=f'core.tools.provider.builtin.{provider}.{provider}', script_path=path.join(path.dirname(path.realpath(__file__)), - 'provider', 'builtin', provider, f'{provider}.py'), + 'provider', 'builtin', provider, f'{provider}.py'), parent_type=BuiltinToolProviderController) provider: BuiltinToolProviderController = provider_class() cls._builtin_providers[provider.identity.name] = provider @@ -414,6 +409,15 @@ def user_list_providers(cls, user_id: str, tenant_id: str, typ: UserToolProvider # append builtin providers for provider in builtin_providers: + # handle include, exclude + if is_filtered( + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, + data=provider, + name_func=lambda x: x.identity.name + ): + continue + user_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider, db_provider=find_db_builtin_provider(provider.identity.name), @@ -473,7 +477,7 @@ def user_list_providers(cls, user_id: str, tenant_id: str, typ: UserToolProvider @classmethod def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[ - ApiToolProviderController, dict[str, Any]]: + ApiToolProviderController, dict[str, Any]]: """ get the api provider @@ -593,4 +597,5 @@ def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> else: raise ValueError(f"provider type {provider_type} not found") + ToolManager.load_builtin_providers_cache() diff --git a/api/core/tools/utils/feishu_api_utils.py b/api/core/tools/utils/feishu_api_utils.py new file mode 100644 index 00000000000000..e6b288868f47b0 --- /dev/null +++ b/api/core/tools/utils/feishu_api_utils.py @@ -0,0 +1,143 @@ +import httpx + +from extensions.ext_redis import redis_client + + +class FeishuRequest: + def __init__(self, app_id: str, app_secret: str): + self.app_id = app_id + self.app_secret = app_secret + + @property + def tenant_access_token(self): + feishu_tenant_access_token = f"tools:{self.app_id}:feishu_tenant_access_token" + if redis_client.exists(feishu_tenant_access_token): + return redis_client.get(feishu_tenant_access_token).decode() + res = self.get_tenant_access_token(self.app_id, self.app_secret) + redis_client.setex(feishu_tenant_access_token, res.get("expire"), res.get("tenant_access_token")) + return res.get("tenant_access_token") + + def _send_request(self, url: str, method: str = "post", require_token: bool = True, payload: dict = None, + params: dict = None): + headers = { + "Content-Type": "application/json", + "user-agent": "Dify", + } + if require_token: + headers["tenant-access-token"] = f"{self.tenant_access_token}" + res = httpx.request(method=method, url=url, headers=headers, json=payload, params=params, timeout=30).json() + if res.get("code") != 0: + raise Exception(res) + return res + + def get_tenant_access_token(self, app_id: str, app_secret: str) -> dict: + """ + API url: https://open.feishu.cn/document/server-docs/authentication-management/access-token/tenant_access_token_internal + Example Response: + { + "code": 0, + "msg": "ok", + "tenant_access_token": "t-caecc734c2e3328a62489fe0648c4b98779515d3", + "expire": 7200 + } + """ + url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/access_token/get_tenant_access_token" + payload = { + "app_id": app_id, + "app_secret": app_secret + } + res = self._send_request(url, require_token=False, payload=payload) + return res + + def create_document(self, title: str, content: str, folder_token: str) -> dict: + """ + API url: https://open.larkoffice.com/document/server-docs/docs/docs/docx-v1/document/create + Example Response: + { + "data": { + "title": "title", + "url": "https://svi136aogf123.feishu.cn/docx/VWbvd4fEdoW0WSxaY1McQTz8n7d", + "type": "docx", + "token": "VWbvd4fEdoW0WSxaY1McQTz8n7d" + }, + "log_id": "021721281231575fdbddc0200ff00060a9258ec0000103df61b5d", + "code": 0, + "msg": "创建飞书文档成功,请查看" + } + """ + url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/document/create_document" + payload = { + "title": title, + "content": content, + "folder_token": folder_token, + } + res = self._send_request(url, payload=payload) + return res.get("data") + + def write_document(self, document_id: str, content: str, position: str = "start") -> dict: + url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/document/write_document" + payload = { + "document_id": document_id, + "content": content, + "position": position + } + res = self._send_request(url, payload=payload) + return res.get("data") + + def get_document_raw_content(self, document_id: str) -> dict: + """ + API url: https://open.larkoffice.com/document/server-docs/docs/docs/docx-v1/document/raw_content + Example Response: + { + "code": 0, + "msg": "success", + "data": { + "content": "云文档\n多人实时协同,插入一切元素。不仅是在线文档,更是强大的创作和互动工具\n云文档:专为协作而生\n" + } + } + """ + params = { + "document_id": document_id, + } + url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/document/get_document_raw_content" + res = self._send_request(url, method="get", params=params) + return res.get("data").get("content") + + def list_document_block(self, document_id: str, page_token: str, page_size: int = 500) -> dict: + """ + API url: https://open.larkoffice.com/document/server-docs/docs/docs/docx-v1/document/list + """ + url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/document/list_document_block" + params = { + "document_id": document_id, + "page_size": page_size, + "page_token": page_token, + } + res = self._send_request(url, method="get", params=params) + return res.get("data") + + def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict: + """ + API url: https://open.larkoffice.com/document/server-docs/im-v1/message/create + """ + url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/message/send_bot_message" + params = { + "receive_id_type": receive_id_type, + } + payload = { + "receive_id": receive_id, + "msg_type": msg_type, + "content": content, + } + res = self._send_request(url, params=params, payload=payload) + return res.get("data") + + def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict: + url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/message/send_webhook_message" + payload = { + "webhook": webhook, + "msg_type": msg_type, + "content": content, + } + res = self._send_request(url, require_token=False, payload=payload) + return res diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index ef9e5b67ae2ab6..564b9d3e14c15e 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -1,7 +1,7 @@ import logging from mimetypes import guess_extension -from core.file.file_obj import FileTransferMethod, FileType, FileVar +from core.file.file_obj import FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager @@ -27,12 +27,12 @@ def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage], # try to download image try: file = ToolFileManager.create_file_by_url( - user_id=user_id, + user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_url=message.message ) - + url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}' result.append(ToolInvokeMessage( @@ -55,14 +55,14 @@ def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage], # if message is str, encode it to bytes if isinstance(message.message, str): message.message = message.message.encode('utf-8') - + file = ToolFileManager.create_file_by_raw( user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_binary=message.message, mimetype=mimetype ) - + url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype)) # check if file is image @@ -81,7 +81,7 @@ def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage], meta=message.meta.copy() if message.meta is not None else {}, )) elif message.type == ToolInvokeMessage.MessageType.FILE_VAR: - file_var: FileVar = message.meta.get('file_var') + file_var = message.meta.get('file_var') if file_var: if file_var.transfer_method == FileTransferMethod.TOOL_FILE: url = cls.get_tool_file_url(file_var.related_id, file_var.extension) @@ -103,7 +103,7 @@ def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage], result.append(message) return result - + @classmethod def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str: - return f'/files/tools/{tool_file_id}{extension or ".bin"}' \ No newline at end of file + return f'/files/tools/{tool_file_id}{extension or ".bin"}' diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py index 21155a696031f6..f751c430963cda 100644 --- a/api/core/tools/utils/yaml_utils.py +++ b/api/core/tools/utils/yaml_utils.py @@ -26,7 +26,6 @@ def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any raise YAMLError(f'Failed to load YAML file {file_path}: {e}') except Exception as e: if ignore_error: - logger.debug(f'Failed to load YAML file {file_path}: {e}') return default_value else: raise e diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 996aae94c20e27..025453567bfc1b 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -4,13 +4,14 @@ from pydantic import BaseModel -from models.workflow import WorkflowNodeExecutionStatus +from models import WorkflowNodeExecutionStatus class NodeType(Enum): """ Node Types. """ + START = 'start' END = 'end' ANSWER = 'answer' @@ -23,10 +24,12 @@ class NodeType(Enum): HTTP_REQUEST = 'http-request' TOOL = 'tool' VARIABLE_AGGREGATOR = 'variable-aggregator' + # TODO: merge this into VARIABLE_AGGREGATOR VARIABLE_ASSIGNER = 'variable-assigner' LOOP = 'loop' ITERATION = 'iteration' PARAMETER_EXTRACTOR = 'parameter-extractor' + CONVERSATION_VARIABLE_ASSIGNER = 'assigner' @classmethod def value_of(cls, value: str) -> 'NodeType': @@ -42,33 +45,11 @@ def value_of(cls, value: str) -> 'NodeType': raise ValueError(f'invalid node type value {value}') -class SystemVariable(Enum): - """ - System Variables. - """ - QUERY = 'query' - FILES = 'files' - CONVERSATION_ID = 'conversation_id' - USER_ID = 'user_id' - - @classmethod - def value_of(cls, value: str) -> 'SystemVariable': - """ - Get value of given system variable. - - :param value: system variable value - :return: system variable - """ - for system_variable in cls: - if system_variable.value == value: - return system_variable - raise ValueError(f'invalid system variable value {value}') - - class NodeRunMetadataKey(Enum): """ Node Run Metadata Key. """ + TOTAL_TOKENS = 'total_tokens' TOTAL_PRICE = 'total_price' CURRENCY = 'currency' @@ -81,6 +62,7 @@ class NodeRunResult(BaseModel): """ Node Run Result. """ + status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING inputs: Optional[Mapping[str, Any]] = None # node inputs diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index a27b4261e486bc..8120b2ac780021 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -6,21 +6,23 @@ from core.app.segments import Segment, Variable, factory from core.file.file_obj import FileVar -from core.workflow.entities.node_entities import SystemVariable +from core.workflow.enums import SystemVariableKey VariableValue = Union[str, int, float, dict, list, FileVar] -SYSTEM_VARIABLE_NODE_ID = 'sys' -ENVIRONMENT_VARIABLE_NODE_ID = 'env' +SYSTEM_VARIABLE_NODE_ID = "sys" +ENVIRONMENT_VARIABLE_NODE_ID = "env" +CONVERSATION_VARIABLE_NODE_ID = "conversation" class VariablePool: def __init__( self, - system_variables: Mapping[SystemVariable, Any], + system_variables: Mapping[SystemVariableKey, Any], user_inputs: Mapping[str, Any], environment_variables: Sequence[Variable], + conversation_variables: Sequence[Variable] | None = None, ) -> None: # system variables # for example: @@ -44,9 +46,13 @@ def __init__( self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value) # Add environment variables to the variable pool - for var in environment_variables or []: + for var in environment_variables: self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) + # Add conversation variables to the variable pool + for var in conversation_variables or []: + self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var) + def add(self, selector: Sequence[str], value: Any, /) -> None: """ Adds a variable to the variable pool. @@ -62,7 +68,7 @@ def add(self, selector: Sequence[str], value: Any, /) -> None: None """ if len(selector) < 2: - raise ValueError('Invalid selector') + raise ValueError("Invalid selector") if value is None: return @@ -89,13 +95,13 @@ def get(self, selector: Sequence[str], /) -> Segment | None: ValueError: If the selector is invalid. """ if len(selector) < 2: - raise ValueError('Invalid selector') + raise ValueError("Invalid selector") hash_key = hash(tuple(selector[1:])) value = self._variable_dictionary[selector[0]].get(hash_key) return value - @deprecated('This method is deprecated, use `get` instead.') + @deprecated("This method is deprecated, use `get` instead.") def get_any(self, selector: Sequence[str], /) -> Any | None: """ Retrieves the value from the variable pool based on the given selector. @@ -110,7 +116,7 @@ def get_any(self, selector: Sequence[str], /) -> Any | None: ValueError: If the selector is invalid. """ if len(selector) < 2: - raise ValueError('Invalid selector') + raise ValueError("Invalid selector") hash_key = hash(tuple(selector[1:])) value = self._variable_dictionary[selector[0]].get(hash_key) return value.to_object() if value else None diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py new file mode 100644 index 00000000000000..da65f6b1fb95e2 --- /dev/null +++ b/api/core/workflow/enums.py @@ -0,0 +1,13 @@ +from enum import Enum + + +class SystemVariableKey(str, Enum): + """ + System Variables. + """ + + QUERY = "query" + FILES = "files" + CONVERSATION_ID = "conversation_id" + USER_ID = "user_id" + DIALOGUE_COUNT = "dialogue_count" diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index d8c812e7ef1244..3d9cf52771e146 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -8,6 +8,7 @@ from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool +from models import WorkflowNodeExecutionStatus class UserFrom(Enum): @@ -91,14 +92,19 @@ def run(self, variable_pool: VariablePool) -> NodeRunResult: :param variable_pool: variable pool :return: """ - result = self._run( - variable_pool=variable_pool - ) - - self.node_run_result = result - return result - - def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None: + try: + result = self._run( + variable_pool=variable_pool + ) + self.node_run_result = result + return result + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + ) + + def publish_text_chunk(self, text: str, value_selector: list[str] | None = None) -> None: """ Publish text chunk :param text: chunk text diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index fafd43e5bc9d8e..335991ae87c285 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -11,19 +11,10 @@ from core.workflow.nodes.code.entities import CodeNodeData from models.workflow import WorkflowNodeExecutionStatus -MAX_NUMBER = dify_config.CODE_MAX_NUMBER -MIN_NUMBER = dify_config.CODE_MIN_NUMBER -MAX_PRECISION = 20 -MAX_DEPTH = 5 -MAX_STRING_LENGTH = dify_config.CODE_MAX_STRING_LENGTH -MAX_STRING_ARRAY_LENGTH = dify_config.CODE_MAX_STRING_ARRAY_LENGTH -MAX_OBJECT_ARRAY_LENGTH = dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH -MAX_NUMBER_ARRAY_LENGTH = dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH - class CodeNode(BaseNode): _node_data_cls = CodeNodeData - node_type = NodeType.CODE + _node_type = NodeType.CODE @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: @@ -48,8 +39,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: :param variable_pool: variable pool :return: """ - node_data = self.node_data - node_data: CodeNodeData = cast(self._node_data_cls, node_data) + node_data = cast(CodeNodeData, self.node_data) # Get code language code_language = node_data.code_language @@ -68,7 +58,6 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: language=code_language, code=code, inputs=variables, - dependencies=node_data.dependencies ) # Transform result @@ -94,10 +83,14 @@ def _check_string(self, value: str, variable: str) -> str: :return: """ if not isinstance(value, str): - raise ValueError(f"Output variable `{variable}` must be a string") - - if len(value) > MAX_STRING_LENGTH: - raise ValueError(f'The length of output variable `{variable}` must be less than {MAX_STRING_LENGTH} characters') + if isinstance(value, type(None)): + return None + else: + raise ValueError(f"Output variable `{variable}` must be a string") + + if len(value) > dify_config.CODE_MAX_STRING_LENGTH: + raise ValueError(f'The length of output variable `{variable}` must be' + f' less than {dify_config.CODE_MAX_STRING_LENGTH} characters') return value.replace('\x00', '') @@ -109,15 +102,20 @@ def _check_number(self, value: Union[int, float], variable: str) -> Union[int, f :return: """ if not isinstance(value, int | float): - raise ValueError(f"Output variable `{variable}` must be a number") + if isinstance(value, type(None)): + return None + else: + raise ValueError(f"Output variable `{variable}` must be a number") - if value > MAX_NUMBER or value < MIN_NUMBER: - raise ValueError(f'Output variable `{variable}` is out of range, it must be between {MIN_NUMBER} and {MAX_NUMBER}.') + if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER: + raise ValueError(f'Output variable `{variable}` is out of range,' + f' it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}.') if isinstance(value, float): # raise error if precision is too high - if len(str(value).split('.')[1]) > MAX_PRECISION: - raise ValueError(f'Output variable `{variable}` has too high precision, it must be less than {MAX_PRECISION} digits.') + if len(str(value).split('.')[1]) > dify_config.CODE_MAX_PRECISION: + raise ValueError(f'Output variable `{variable}` has too high precision,' + f' it must be less than {dify_config.CODE_MAX_PRECISION} digits.') return value @@ -130,8 +128,8 @@ def _transform_result(self, result: dict, output_schema: Optional[dict[str, Code :param output_schema: output schema :return: """ - if depth > MAX_DEPTH: - raise ValueError("Depth limit reached, object too deep.") + if depth > dify_config.CODE_MAX_DEPTH: + raise ValueError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.") transformed_result = {} if output_schema is None: @@ -157,28 +155,31 @@ def _transform_result(self, result: dict, output_schema: Optional[dict[str, Code elif isinstance(output_value, list): first_element = output_value[0] if len(output_value) > 0 else None if first_element is not None: - if isinstance(first_element, int | float) and all(isinstance(value, int | float) for value in output_value): + if isinstance(first_element, int | float) and all(value is None or isinstance(value, int | float) for value in output_value): for i, value in enumerate(output_value): self._check_number( value=value, variable=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]' ) - elif isinstance(first_element, str) and all(isinstance(value, str) for value in output_value): + elif isinstance(first_element, str) and all(value is None or isinstance(value, str) for value in output_value): for i, value in enumerate(output_value): self._check_string( value=value, variable=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]' ) - elif isinstance(first_element, dict) and all(isinstance(value, dict) for value in output_value): + elif isinstance(first_element, dict) and all(value is None or isinstance(value, dict) for value in output_value): for i, value in enumerate(output_value): - self._transform_result( - result=value, - output_schema=None, - prefix=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]', - depth=depth + 1 - ) + if value is not None: + self._transform_result( + result=value, + output_schema=None, + prefix=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]', + depth=depth + 1 + ) else: raise ValueError(f'Output {prefix}.{output_name} is not a valid array. make sure all elements are of the same type.') + elif isinstance(output_value, type(None)): + pass else: raise ValueError(f'Output {prefix}.{output_name} is not a valid type.') @@ -193,16 +194,19 @@ def _transform_result(self, result: dict, output_schema: Optional[dict[str, Code if output_config.type == 'object': # check if output is object if not isinstance(result.get(output_name), dict): - raise ValueError( - f'Output {prefix}{dot}{output_name} is not an object, got {type(result.get(output_name))} instead.' + if isinstance(result.get(output_name), type(None)): + transformed_result[output_name] = None + else: + raise ValueError( + f'Output {prefix}{dot}{output_name} is not an object, got {type(result.get(output_name))} instead.' + ) + else: + transformed_result[output_name] = self._transform_result( + result=result[output_name], + output_schema=output_config.children, + prefix=f'{prefix}.{output_name}', + depth=depth + 1 ) - - transformed_result[output_name] = self._transform_result( - result=result[output_name], - output_schema=output_config.children, - prefix=f'{prefix}.{output_name}', - depth=depth + 1 - ) elif output_config.type == 'number': # check if number available transformed_result[output_name] = self._check_number( @@ -218,68 +222,83 @@ def _transform_result(self, result: dict, output_schema: Optional[dict[str, Code elif output_config.type == 'array[number]': # check if array of number available if not isinstance(result[output_name], list): - raise ValueError( - f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' - ) - - if len(result[output_name]) > MAX_NUMBER_ARRAY_LENGTH: - raise ValueError( - f'The length of output variable `{prefix}{dot}{output_name}` must be less than {MAX_NUMBER_ARRAY_LENGTH} elements.' - ) + if isinstance(result[output_name], type(None)): + transformed_result[output_name] = None + else: + raise ValueError( + f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' + ) + else: + if len(result[output_name]) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH: + raise ValueError( + f'The length of output variable `{prefix}{dot}{output_name}` must be' + f' less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements.' + ) - transformed_result[output_name] = [ - self._check_number( - value=value, - variable=f'{prefix}{dot}{output_name}[{i}]' - ) - for i, value in enumerate(result[output_name]) - ] + transformed_result[output_name] = [ + self._check_number( + value=value, + variable=f'{prefix}{dot}{output_name}[{i}]' + ) + for i, value in enumerate(result[output_name]) + ] elif output_config.type == 'array[string]': # check if array of string available if not isinstance(result[output_name], list): - raise ValueError( - f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' - ) - - if len(result[output_name]) > MAX_STRING_ARRAY_LENGTH: - raise ValueError( - f'The length of output variable `{prefix}{dot}{output_name}` must be less than {MAX_STRING_ARRAY_LENGTH} elements.' - ) + if isinstance(result[output_name], type(None)): + transformed_result[output_name] = None + else: + raise ValueError( + f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' + ) + else: + if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH: + raise ValueError( + f'The length of output variable `{prefix}{dot}{output_name}` must be' + f' less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements.' + ) - transformed_result[output_name] = [ - self._check_string( - value=value, - variable=f'{prefix}{dot}{output_name}[{i}]' - ) - for i, value in enumerate(result[output_name]) - ] + transformed_result[output_name] = [ + self._check_string( + value=value, + variable=f'{prefix}{dot}{output_name}[{i}]' + ) + for i, value in enumerate(result[output_name]) + ] elif output_config.type == 'array[object]': # check if array of object available if not isinstance(result[output_name], list): - raise ValueError( - f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' - ) - - if len(result[output_name]) > MAX_OBJECT_ARRAY_LENGTH: - raise ValueError( - f'The length of output variable `{prefix}{dot}{output_name}` must be less than {MAX_OBJECT_ARRAY_LENGTH} elements.' - ) - - for i, value in enumerate(result[output_name]): - if not isinstance(value, dict): + if isinstance(result[output_name], type(None)): + transformed_result[output_name] = None + else: + raise ValueError( + f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' + ) + else: + if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH: raise ValueError( - f'Output {prefix}{dot}{output_name}[{i}] is not an object, got {type(value)} instead at index {i}.' + f'The length of output variable `{prefix}{dot}{output_name}` must be' + f' less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements.' ) + + for i, value in enumerate(result[output_name]): + if not isinstance(value, dict): + if isinstance(value, type(None)): + pass + else: + raise ValueError( + f'Output {prefix}{dot}{output_name}[{i}] is not an object, got {type(value)} instead at index {i}.' + ) - transformed_result[output_name] = [ - self._transform_result( - result=value, - output_schema=output_config.children, - prefix=f'{prefix}{dot}{output_name}[{i}]', - depth=depth + 1 - ) - for i, value in enumerate(result[output_name]) - ] + transformed_result[output_name] = [ + None if value is None else self._transform_result( + result=value, + output_schema=output_config.children, + prefix=f'{prefix}{dot}{output_name}[{i}]', + depth=depth + 1 + ) + for i, value in enumerate(result[output_name]) + ] else: raise ValueError(f'Output type {output_config.type} is not supported.') diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index 83a5416d57ce6a..c0701ecccd2132 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -3,7 +3,6 @@ from pydantic import BaseModel from core.helper.code_executor.code_executor import CodeLanguage -from core.helper.code_executor.entities import CodeDependency from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector @@ -16,8 +15,12 @@ class Output(BaseModel): type: Literal['string', 'number', 'object', 'array[string]', 'array[number]', 'array[object]'] children: Optional[dict[str, 'Output']] = None + class Dependency(BaseModel): + name: str + version: str + variables: list[VariableSelector] code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT] code: str outputs: dict[str, Output] - dependencies: Optional[list[CodeDependency]] = None \ No newline at end of file + dependencies: Optional[list[Dependency]] = None \ No newline at end of file diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index 90d644e0e22d3c..c066d469d8c553 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -5,10 +5,6 @@ from configs import dify_config from core.workflow.entities.base_node_data_entities import BaseNodeData -MAX_CONNECT_TIMEOUT = dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT -MAX_READ_TIMEOUT = dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT -MAX_WRITE_TIMEOUT = dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT - class HttpRequestNodeAuthorizationConfig(BaseModel): type: Literal[None, 'basic', 'bearer', 'custom'] @@ -41,9 +37,9 @@ class HttpRequestNodeBody(BaseModel): class HttpRequestNodeTimeout(BaseModel): - connect: int = MAX_CONNECT_TIMEOUT - read: int = MAX_READ_TIMEOUT - write: int = MAX_WRITE_TIMEOUT + connect: int = dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT + read: int = dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT + write: int = dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT class HttpRequestNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py index 3c24c0a018d2b9..d16bff58bda12b 100644 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -18,11 +18,6 @@ ) from core.workflow.utils.variable_template_parser import VariableTemplateParser -MAX_BINARY_SIZE = dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE -READABLE_MAX_BINARY_SIZE = dify_config.HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE -MAX_TEXT_SIZE = dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE -READABLE_MAX_TEXT_SIZE = dify_config.HTTP_REQUEST_NODE_READABLE_MAX_TEXT_SIZE - class HttpExecutorResponse: headers: dict[str, str] @@ -237,16 +232,14 @@ def _validate_and_parse_response(self, response: httpx.Response) -> HttpExecutor else: raise ValueError(f'Invalid response type {type(response)}') - if executor_response.is_file: - if executor_response.size > MAX_BINARY_SIZE: - raise ValueError( - f'File size is too large, max size is {READABLE_MAX_BINARY_SIZE}, but current size is {executor_response.readable_size}.' - ) - else: - if executor_response.size > MAX_TEXT_SIZE: - raise ValueError( - f'Text size is too large, max size is {READABLE_MAX_TEXT_SIZE}, but current size is {executor_response.readable_size}.' - ) + threshold_size = dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE if executor_response.is_file \ + else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE + if executor_response.size > threshold_size: + raise ValueError( + f'{"File" if executor_response.is_file else "Text"} size is too large,' + f' max size is {threshold_size / 1024 / 1024:.2f} MB,' + f' but current size is {executor_response.readable_size}.' + ) return executor_response @@ -337,7 +330,7 @@ def _format_template( if variable is None: raise ValueError(f'Variable {variable_selector.variable} not found') if escape_quotes and isinstance(variable, str): - value = variable.replace('"', '\\"') + value = variable.replace('"', '\\"').replace('\n', '\\n') else: value = variable variable_value_mapping[variable_selector.variable] = value diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index bbe5f9ad43f561..037a7a1848ccfa 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -3,6 +3,7 @@ from os import path from typing import cast +from configs import dify_config from core.app.segments import parser from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.tools.tool_file_manager import ToolFileManager @@ -11,9 +12,6 @@ from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.http_request.entities import ( - MAX_CONNECT_TIMEOUT, - MAX_READ_TIMEOUT, - MAX_WRITE_TIMEOUT, HttpRequestNodeData, HttpRequestNodeTimeout, ) @@ -21,9 +19,9 @@ from models.workflow import WorkflowNodeExecutionStatus HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout( - connect=min(10, MAX_CONNECT_TIMEOUT), - read=min(60, MAX_READ_TIMEOUT), - write=min(20, MAX_WRITE_TIMEOUT), + connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, + read=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, + write=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, ) @@ -43,9 +41,9 @@ def get_default_config(cls, filters: dict | None = None) -> dict: 'body': {'type': 'none'}, 'timeout': { **HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(), - 'max_connect_timeout': MAX_CONNECT_TIMEOUT, - 'max_read_timeout': MAX_READ_TIMEOUT, - 'max_write_timeout': MAX_WRITE_TIMEOUT, + 'max_connect_timeout': dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, + 'max_read_timeout': dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, + 'max_write_timeout': dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, }, }, } @@ -92,17 +90,15 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: }, ) - def _get_request_timeout(self, node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout: + @staticmethod + def _get_request_timeout(node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout: timeout = node_data.timeout if timeout is None: return HTTP_REQUEST_DEFAULT_TIMEOUT timeout.connect = timeout.connect or HTTP_REQUEST_DEFAULT_TIMEOUT.connect - timeout.connect = min(timeout.connect, MAX_CONNECT_TIMEOUT) timeout.read = timeout.read or HTTP_REQUEST_DEFAULT_TIMEOUT.read - timeout.read = min(timeout.read, MAX_READ_TIMEOUT) timeout.write = timeout.write or HTTP_REQUEST_DEFAULT_TIMEOUT.write - timeout.write = min(timeout.write, MAX_WRITE_TIMEOUT) return timeout @classmethod @@ -133,9 +129,6 @@ def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVa """ files = [] mimetype, file_binary = response.extract_file() - # if not image, return directly - if 'image' not in mimetype: - return files if mimetype: # extract filename from url diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 4431259a57543b..eb8921b5266691 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -1,14 +1,13 @@ import json from collections.abc import Generator from copy import deepcopy -from typing import Optional, cast +from typing import TYPE_CHECKING, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.entities.model_entities import ModelStatus from core.entities.provider_entities import QuotaUnit from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.file.file_obj import FileVar from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.llm_entities import LLMUsage @@ -23,8 +22,9 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.llm.entities import ( LLMNodeChatModelMessage, @@ -38,6 +38,10 @@ from models.provider import Provider, ProviderType from models.workflow import WorkflowNodeExecutionStatus +if TYPE_CHECKING: + from core.file.file_obj import FileVar + + class LLMNode(BaseNode): _node_data_cls = LLMNodeData @@ -70,7 +74,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: node_inputs = {} # fetch files - files: list[FileVar] = self._fetch_files(node_data, variable_pool) + files = self._fetch_files(node_data, variable_pool) if files: node_inputs['#files#'] = [file.to_dict() for file in files] @@ -90,7 +94,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: # fetch prompt messages prompt_messages, stop = self._fetch_prompt_messages( node_data=node_data, - query=variable_pool.get_any(['sys', SystemVariable.QUERY.value]) + query=variable_pool.get_any(['sys', SystemVariableKey.QUERY.value]) if node_data.memory else None, query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None, inputs=inputs, @@ -109,7 +113,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: } # handle invoke result - result_text, usage = self._invoke_llm( + result_text, usage, finish_reason = self._invoke_llm( node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, @@ -125,7 +129,8 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: outputs = { 'text': result_text, - 'usage': jsonable_encoder(usage) + 'usage': jsonable_encoder(usage), + 'finish_reason': finish_reason } return NodeRunResult( @@ -163,14 +168,14 @@ def _invoke_llm(self, node_data_model: ModelConfig, ) # handle invoke result - text, usage = self._handle_invoke_result( + text, usage, finish_reason = self._handle_invoke_result( invoke_result=invoke_result ) # deduct quota self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) - return text, usage + return text, usage, finish_reason def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]: """ @@ -182,6 +187,7 @@ def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage prompt_messages = [] full_text = '' usage = None + finish_reason = None for result in invoke_result: text = result.delta.message.content full_text += text @@ -197,12 +203,15 @@ def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage if not usage and result.delta.usage: usage = result.delta.usage + if not finish_reason and result.delta.finish_reason: + finish_reason = result.delta.finish_reason + if not usage: usage = LLMUsage.empty_usage() - return full_text, usage - - def _transform_chat_messages(self, + return full_text, usage, finish_reason + + def _transform_chat_messages(self, messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: """ @@ -249,13 +258,13 @@ def parse_dict(d: dict) -> str: # check if it's a context structure if 'metadata' in d and '_source' in d['metadata'] and 'content' in d: return d['content'] - + # else, parse the dict try: return json.dumps(d, ensure_ascii=False) except Exception: return str(d) - + if isinstance(value, str): value = value elif isinstance(value, list): @@ -321,7 +330,7 @@ def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> return inputs - def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]: + def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]: """ Fetch files :param node_data: node data @@ -331,7 +340,7 @@ def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> l if not node_data.vision.enabled: return [] - files = variable_pool.get_any(['sys', SystemVariable.FILES.value]) + files = variable_pool.get_any(['sys', SystemVariableKey.FILES.value]) if not files: return [] @@ -496,7 +505,7 @@ def _fetch_memory(self, node_data_memory: Optional[MemoryConfig], return None # get conversation id - conversation_id = variable_pool.get_any(['sys', SystemVariable.CONVERSATION_ID.value]) + conversation_id = variable_pool.get_any(['sys', SystemVariableKey.CONVERSATION_ID.value]) if conversation_id is None: return None @@ -520,7 +529,7 @@ def _fetch_prompt_messages(self, node_data: LLMNodeData, query: Optional[str], query_prompt_template: Optional[str], inputs: dict[str, str], - files: list[FileVar], + files: list["FileVar"], context: Optional[str], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) \ @@ -668,10 +677,10 @@ def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) variable_mapping['#context#'] = node_data.context.variable_selector if node_data.vision.enabled: - variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value] + variable_mapping['#files#'] = ['sys', SystemVariableKey.FILES.value] if node_data.memory: - variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value] + variable_mapping['#sys.query#'] = ['sys', SystemVariableKey.QUERY.value] if node_data.prompt_config: enable_jinja = False diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 2e1464efcef1e6..f4057d50f38988 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -63,7 +63,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: ) # handle invoke result - result_text, usage = self._invoke_llm( + result_text, usage, finish_reason = self._invoke_llm( node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, @@ -93,6 +93,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: prompt_messages=prompt_messages ), 'usage': jsonable_encoder(usage), + 'finish_reason': finish_reason } outputs = { 'class_name': category_name diff --git a/api/core/workflow/nodes/start/entities.py b/api/core/workflow/nodes/start/entities.py index 0bd5f203bf72a5..b81ce15bd74c96 100644 --- a/api/core/workflow/nodes/start/entities.py +++ b/api/core/workflow/nodes/start/entities.py @@ -1,3 +1,7 @@ +from collections.abc import Sequence + +from pydantic import Field + from core.app.app_config.entities import VariableEntity from core.workflow.entities.base_node_data_entities import BaseNodeData @@ -6,4 +10,4 @@ class StartNodeData(BaseNodeData): """ Start Node Data """ - variables: list[VariableEntity] = [] + variables: Sequence[VariableEntity] = Field(default_factory=list) diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 661b403d32f8ca..54e66bd6718901 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,7 +1,7 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool +from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.start.entities import StartNodeData from models.workflow import WorkflowNodeExecutionStatus @@ -17,16 +17,16 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: :param variable_pool: variable pool :return: """ - # Get cleaned inputs - cleaned_inputs = dict(variable_pool.user_inputs) + node_inputs = dict(variable_pool.user_inputs) + system_inputs = variable_pool.system_variables - for var in variable_pool.system_variables: - cleaned_inputs['sys.' + var.value] = variable_pool.system_variables[var] + for var in system_inputs: + node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var] return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=cleaned_inputs, - outputs=cleaned_inputs + inputs=node_inputs, + outputs=node_inputs ) @classmethod diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 969b1c241e52a8..ccce9ef3607624 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -2,19 +2,20 @@ from os import path from typing import Any, cast -from core.app.segments import parser +from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.tool_engine import ToolEngine from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.utils.variable_template_parser import VariableTemplateParser -from models.workflow import WorkflowNodeExecutionStatus +from models import WorkflowNodeExecutionStatus class ToolNode(BaseNode): @@ -118,6 +119,7 @@ def _generate_parameters( for parameter_name in node_data.tool_parameters: parameter = tool_parameters_dictionary.get(parameter_name) if not parameter: + result[parameter_name] = None continue if parameter.type == ToolParameter.ToolParameterType.FILE: result[parameter_name] = [ @@ -139,9 +141,9 @@ def _generate_parameters( return result def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]: - # FIXME: ensure this is a ArrayVariable contains FileVariable. - variable = variable_pool.get(['sys', SystemVariable.FILES.value]) - return [file_var.value for file_var in variable.value] if variable else [] + variable = variable_pool.get(['sys', SystemVariableKey.FILES.value]) + assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) + return list(variable.value) if variable else [] def _convert_tool_messages(self, messages: list[ToolInvokeMessage]): """ diff --git a/api/core/workflow/nodes/variable_aggregator/entities.py b/api/core/workflow/nodes/variable_aggregator/entities.py index cea88334b90738..e5de38dc0faa26 100644 --- a/api/core/workflow/nodes/variable_aggregator/entities.py +++ b/api/core/workflow/nodes/variable_aggregator/entities.py @@ -17,7 +17,7 @@ class Group(BaseModel): """ Group. """ - output_type: Literal['string', 'number', 'array', 'object'] + output_type: Literal['string', 'number', 'object', 'array[string]', 'array[number]', 'array[object]'] variables: list[list[str]] group_name: str @@ -30,4 +30,4 @@ class VariableAssignerNodeData(BaseNodeData): type: str = 'variable-assigner' output_type: str variables: list[list[str]] - advanced_settings: Optional[AdvancedSettings] = None \ No newline at end of file + advanced_settings: Optional[AdvancedSettings] = None diff --git a/api/core/workflow/nodes/variable_assigner/__init__.py b/api/core/workflow/nodes/variable_assigner/__init__.py new file mode 100644 index 00000000000000..d791d51523ec0f --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/__init__.py @@ -0,0 +1,8 @@ +from .node import VariableAssignerNode +from .node_data import VariableAssignerData, WriteMode + +__all__ = [ + 'VariableAssignerNode', + 'VariableAssignerData', + 'WriteMode', +] diff --git a/api/core/workflow/nodes/variable_assigner/exc.py b/api/core/workflow/nodes/variable_assigner/exc.py new file mode 100644 index 00000000000000..914be2225642cd --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/exc.py @@ -0,0 +1,2 @@ +class VariableAssignerNodeError(Exception): + pass diff --git a/api/core/workflow/nodes/variable_assigner/node.py b/api/core/workflow/nodes/variable_assigner/node.py new file mode 100644 index 00000000000000..8c2adcabb9dc05 --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/node.py @@ -0,0 +1,92 @@ +from typing import cast + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.segments import SegmentType, Variable, factory +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import BaseNode +from extensions.ext_database import db +from models import ConversationVariable, WorkflowNodeExecutionStatus + +from .exc import VariableAssignerNodeError +from .node_data import VariableAssignerData, WriteMode + + +class VariableAssignerNode(BaseNode): + _node_data_cls: type[BaseNodeData] = VariableAssignerData + _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER + + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + data = cast(VariableAssignerData, self.node_data) + + # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject + original_variable = variable_pool.get(data.assigned_variable_selector) + if not isinstance(original_variable, Variable): + raise VariableAssignerNodeError('assigned variable not found') + + match data.write_mode: + case WriteMode.OVER_WRITE: + income_value = variable_pool.get(data.input_variable_selector) + if not income_value: + raise VariableAssignerNodeError('input value not found') + updated_variable = original_variable.model_copy(update={'value': income_value.value}) + + case WriteMode.APPEND: + income_value = variable_pool.get(data.input_variable_selector) + if not income_value: + raise VariableAssignerNodeError('input value not found') + updated_value = original_variable.value + [income_value.value] + updated_variable = original_variable.model_copy(update={'value': updated_value}) + + case WriteMode.CLEAR: + income_value = get_zero_value(original_variable.value_type) + updated_variable = original_variable.model_copy(update={'value': income_value.to_object()}) + + case _: + raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}') + + # Over write the variable. + variable_pool.add(data.assigned_variable_selector, updated_variable) + + # TODO: Move database operation to the pipeline. + # Update conversation variable. + conversation_id = variable_pool.get(['sys', 'conversation_id']) + if not conversation_id: + raise VariableAssignerNodeError('conversation_id not found') + update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={ + 'value': income_value.to_object(), + }, + ) + + +def update_conversation_variable(conversation_id: str, variable: Variable): + stmt = select(ConversationVariable).where( + ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id + ) + with Session(db.engine) as session: + row = session.scalar(stmt) + if not row: + raise VariableAssignerNodeError('conversation variable not found in the database') + row.data = variable.model_dump_json() + session.commit() + + +def get_zero_value(t: SegmentType): + match t: + case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER: + return factory.build_segment([]) + case SegmentType.OBJECT: + return factory.build_segment({}) + case SegmentType.STRING: + return factory.build_segment('') + case SegmentType.NUMBER: + return factory.build_segment(0) + case _: + raise VariableAssignerNodeError(f'unsupported variable type: {t}') diff --git a/api/core/workflow/nodes/variable_assigner/node_data.py b/api/core/workflow/nodes/variable_assigner/node_data.py new file mode 100644 index 00000000000000..b3652b68024265 --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/node_data.py @@ -0,0 +1,19 @@ +from collections.abc import Sequence +from enum import Enum +from typing import Optional + +from core.workflow.entities.base_node_data_entities import BaseNodeData + + +class WriteMode(str, Enum): + OVER_WRITE = 'over-write' + APPEND = 'append' + CLEAR = 'clear' + + +class VariableAssignerData(BaseNodeData): + title: str = 'Variable Assigner' + desc: Optional[str] = 'Assign a value to a variable' + assigned_variable_selector: Sequence[str] + write_mode: WriteMode + input_variable_selector: Sequence[str] diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index bd2b3eafa7a8b1..3157eedfee5238 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -3,13 +3,13 @@ from collections.abc import Mapping, Sequence from typing import Any, Optional, cast +import contexts from configs import dify_config -from core.app.app_config.entities import FileExtraConfig from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException from core.app.entities.app_invoke_entities import InvokeFrom -from core.file.file_obj import FileTransferMethod, FileType, FileVar +from core.file.file_obj import FileExtraConfig, FileTransferMethod, FileType, FileVar from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState from core.workflow.errors import WorkflowNodeRunFailedError @@ -30,6 +30,7 @@ from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from core.workflow.nodes.tool.tool_node import ToolNode from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode +from core.workflow.nodes.variable_assigner import VariableAssignerNode from extensions.ext_database import db from models.workflow import ( Workflow, @@ -51,7 +52,8 @@ NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode, NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, NodeType.ITERATION: IterationNode, - NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode + NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode, + NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode, } logger = logging.getLogger(__name__) @@ -94,19 +96,18 @@ def run_workflow( user_id: str, user_from: UserFrom, invoke_from: InvokeFrom, - user_inputs: Mapping[str, Any], - system_inputs: Mapping[SystemVariable, Any], callbacks: Sequence[WorkflowCallback], - call_depth: int = 0 + call_depth: int = 0, + variable_pool: VariablePool | None = None, ) -> None: """ :param workflow: Workflow instance :param user_id: user id :param user_from: user from - :param user_inputs: user variables inputs - :param system_inputs: system inputs, like: query, files + :param invoke_from: invoke from :param callbacks: workflow callbacks :param call_depth: call depth + :param variable_pool: variable pool """ # fetch workflow graph graph = workflow.graph_dict @@ -122,18 +123,14 @@ def run_workflow( if not isinstance(graph.get('edges'), list): raise ValueError('edges in workflow graph must be a list') - # init variable pool - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=user_inputs, - environment_variables=workflow.environment_variables, - ) workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH if call_depth > workflow_call_max_depth: raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) # init workflow run state + if not variable_pool: + variable_pool = contexts.workflow_variable_pool.get() workflow_run_state = WorkflowRunState( workflow=workflow, start_at=time.perf_counter(), @@ -403,6 +400,7 @@ def single_step_run_workflow_node(self, workflow: Workflow, system_variables={}, user_inputs={}, environment_variables=workflow.environment_variables, + conversation_variables=workflow.conversation_variables, ) if node_cls is None: @@ -468,6 +466,7 @@ def single_step_run_iteration_workflow_node(self, workflow: Workflow, system_variables={}, user_inputs={}, environment_variables=workflow.environment_variables, + conversation_variables=workflow.conversation_variables, ) # variable selector to variable mapping diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 64e4e719ab9be8..1edc558676747a 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -4,7 +4,7 @@ set -e if [[ "${MIGRATION_ENABLED}" == "true" ]]; then echo "Running migrations" - flask db upgrade + flask upgrade-db fi if [[ "${MODE}" == "worker" ]]; then @@ -20,11 +20,11 @@ if [[ "${MODE}" == "worker" ]]; then CONCURRENCY_OPTION="-c ${CELERY_WORKER_AMOUNT:-1}" fi - exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION --loglevel INFO \ + exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION --loglevel ${LOG_LEVEL} \ -Q ${CELERY_QUEUES:-dataset,generation,mail,ops_trace,app_deletion} elif [[ "${MODE}" == "beat" ]]; then - exec celery -A app.celery beat --loglevel INFO + exec celery -A app.celery beat --loglevel ${LOG_LEVEL} else if [[ "${DEBUG}" == "true" ]]; then exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug diff --git a/api/events/app_event.py b/api/events/app_event.py index 67a5982527f7b6..f2ce71bbbb3632 100644 --- a/api/events/app_event.py +++ b/api/events/app_event.py @@ -1,13 +1,13 @@ from blinker import signal # sender: app -app_was_created = signal('app-was-created') +app_was_created = signal("app-was-created") # sender: app, kwargs: app_model_config -app_model_config_was_updated = signal('app-model-config-was-updated') +app_model_config_was_updated = signal("app-model-config-was-updated") # sender: app, kwargs: published_workflow -app_published_workflow_was_updated = signal('app-published-workflow-was-updated') +app_published_workflow_was_updated = signal("app-published-workflow-was-updated") # sender: app, kwargs: synced_draft_workflow -app_draft_workflow_was_synced = signal('app-draft-workflow-was-synced') +app_draft_workflow_was_synced = signal("app-draft-workflow-was-synced") diff --git a/api/events/dataset_event.py b/api/events/dataset_event.py index d4a2b6f313c13a..750b7424e2347b 100644 --- a/api/events/dataset_event.py +++ b/api/events/dataset_event.py @@ -1,4 +1,4 @@ from blinker import signal # sender: dataset -dataset_was_deleted = signal('dataset-was-deleted') +dataset_was_deleted = signal("dataset-was-deleted") diff --git a/api/events/document_event.py b/api/events/document_event.py index f95326630b2b7a..2c5a416a5e0c91 100644 --- a/api/events/document_event.py +++ b/api/events/document_event.py @@ -1,4 +1,4 @@ from blinker import signal # sender: document -document_was_deleted = signal('document-was-deleted') +document_was_deleted = signal("document-was-deleted") diff --git a/api/events/event_handlers/clean_when_dataset_deleted.py b/api/events/event_handlers/clean_when_dataset_deleted.py index 42f1c70614c49a..7caa2d1cc9f3f2 100644 --- a/api/events/event_handlers/clean_when_dataset_deleted.py +++ b/api/events/event_handlers/clean_when_dataset_deleted.py @@ -5,5 +5,11 @@ @dataset_was_deleted.connect def handle(sender, **kwargs): dataset = sender - clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique, - dataset.index_struct, dataset.collection_binding_id, dataset.doc_form) + clean_dataset_task.delay( + dataset.id, + dataset.tenant_id, + dataset.indexing_technique, + dataset.index_struct, + dataset.collection_binding_id, + dataset.doc_form, + ) diff --git a/api/events/event_handlers/clean_when_document_deleted.py b/api/events/event_handlers/clean_when_document_deleted.py index 24022da15f81ee..00a66f50ad9319 100644 --- a/api/events/event_handlers/clean_when_document_deleted.py +++ b/api/events/event_handlers/clean_when_document_deleted.py @@ -5,7 +5,7 @@ @document_was_deleted.connect def handle(sender, **kwargs): document_id = sender - dataset_id = kwargs.get('dataset_id') - doc_form = kwargs.get('doc_form') - file_id = kwargs.get('file_id') + dataset_id = kwargs.get("dataset_id") + doc_form = kwargs.get("doc_form") + file_id = kwargs.get("file_id") clean_document_task.delay(document_id, dataset_id, doc_form, file_id) diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 68dae5a5537cd7..72a135e73d4ca5 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -14,21 +14,25 @@ @document_index_created.connect def handle(sender, **kwargs): dataset_id = sender - document_ids = kwargs.get('document_ids', None) + document_ids = kwargs.get("document_ids", None) documents = [] start_at = time.perf_counter() for document_id in document_ids: - logging.info(click.style('Start process document: {}'.format(document_id), fg='green')) + logging.info(click.style("Start process document: {}".format(document_id), fg="green")) - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document) + .filter( + Document.id == document_id, + Document.dataset_id == dataset_id, + ) + .first() + ) if not document: - raise NotFound('Document not found') + raise NotFound("Document not found") - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) documents.append(document) db.session.add(document) @@ -38,8 +42,8 @@ def handle(sender, **kwargs): indexing_runner = IndexingRunner() indexing_runner.run(documents) end_at = time.perf_counter() - logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) + logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) except DocumentIsPausedException as ex: - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/events/event_handlers/create_installed_app_when_app_created.py b/api/events/event_handlers/create_installed_app_when_app_created.py index 31084ce0fe8bdc..57412cc4ad0af2 100644 --- a/api/events/event_handlers/create_installed_app_when_app_created.py +++ b/api/events/event_handlers/create_installed_app_when_app_created.py @@ -10,7 +10,7 @@ def handle(sender, **kwargs): installed_app = InstalledApp( tenant_id=app.tenant_id, app_id=app.id, - app_owner_tenant_id=app.tenant_id + app_owner_tenant_id=app.tenant_id, ) db.session.add(installed_app) db.session.commit() diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py index f0eb7159b631e0..1515661b2d45b8 100644 --- a/api/events/event_handlers/create_site_record_when_app_created.py +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -7,15 +7,18 @@ def handle(sender, **kwargs): """Create site record when an app is created.""" app = sender - account = kwargs.get('account') + account = kwargs.get("account") site = Site( app_id=app.id, title=app.name, - icon = app.icon, - icon_background = app.icon_background, + icon_type=app.icon_type, + icon=app.icon, + icon_background=app.icon_background, default_language=account.interface_language, - customize_token_strategy='not_allow', - code=Site.generate_code(16) + customize_token_strategy="not_allow", + code=Site.generate_code(16), + created_by=app.created_by, + updated_by=app.updated_by, ) db.session.add(site) diff --git a/api/events/event_handlers/deduct_quota_when_messaeg_created.py b/api/events/event_handlers/deduct_quota_when_messaeg_created.py index 8cf52bf8f5d0b8..843a2320968ced 100644 --- a/api/events/event_handlers/deduct_quota_when_messaeg_created.py +++ b/api/events/event_handlers/deduct_quota_when_messaeg_created.py @@ -8,7 +8,7 @@ @message_was_created.connect def handle(sender, **kwargs): message = sender - application_generate_entity = kwargs.get('application_generate_entity') + application_generate_entity = kwargs.get("application_generate_entity") if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): return @@ -39,7 +39,7 @@ def handle(sender, **kwargs): elif quota_unit == QuotaUnit.CREDITS: used_quota = 1 - if 'gpt-4' in model_config.model: + if "gpt-4" in model_config.model: used_quota = 20 else: used_quota = 1 @@ -50,6 +50,6 @@ def handle(sender, **kwargs): Provider.provider_name == model_config.provider, Provider.provider_type == ProviderType.SYSTEM.value, Provider.quota_type == system_configuration.current_quota_type.value, - Provider.quota_limit > Provider.quota_used - ).update({'quota_used': Provider.quota_used + used_quota}) + Provider.quota_limit > Provider.quota_used, + ).update({"quota_used": Provider.quota_used + used_quota}) db.session.commit() diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index 1f6da34ee24d56..f96bb5ef74b62e 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -8,8 +8,8 @@ @app_draft_workflow_was_synced.connect def handle(sender, **kwargs): app = sender - for node_data in kwargs.get('synced_draft_workflow').graph_dict.get('nodes', []): - if node_data.get('data', {}).get('type') == NodeType.TOOL.value: + for node_data in kwargs.get("synced_draft_workflow").graph_dict.get("nodes", []): + if node_data.get("data", {}).get("type") == NodeType.TOOL.value: try: tool_entity = ToolEntity(**node_data["data"]) tool_runtime = ToolManager.get_tool_runtime( @@ -23,7 +23,7 @@ def handle(sender, **kwargs): tool_runtime=tool_runtime, provider_name=tool_entity.provider_name, provider_type=tool_entity.provider_type, - identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}' + identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}', ) manager.delete_tool_parameters_cache() except: diff --git a/api/events/event_handlers/document_index_event.py b/api/events/event_handlers/document_index_event.py index 9c4e055debdd9e..3d463fe5b35acf 100644 --- a/api/events/event_handlers/document_index_event.py +++ b/api/events/event_handlers/document_index_event.py @@ -1,4 +1,4 @@ from blinker import signal # sender: document -document_index_created = signal('document-index-created') +document_index_created = signal("document-index-created") diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py index 2b202c53d0b883..59375b1a0b1a81 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py @@ -7,13 +7,11 @@ @app_model_config_was_updated.connect def handle(sender, **kwargs): app = sender - app_model_config = kwargs.get('app_model_config') + app_model_config = kwargs.get("app_model_config") dataset_ids = get_dataset_ids_from_model_config(app_model_config) - app_dataset_joins = db.session.query(AppDatasetJoin).filter( - AppDatasetJoin.app_id == app.id - ).all() + app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() removed_dataset_ids = [] if not app_dataset_joins: @@ -29,16 +27,12 @@ def handle(sender, **kwargs): if removed_dataset_ids: for dataset_id in removed_dataset_ids: db.session.query(AppDatasetJoin).filter( - AppDatasetJoin.app_id == app.id, - AppDatasetJoin.dataset_id == dataset_id + AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id ).delete() if added_dataset_ids: for dataset_id in added_dataset_ids: - app_dataset_join = AppDatasetJoin( - app_id=app.id, - dataset_id=dataset_id - ) + app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id) db.session.add(app_dataset_join) db.session.commit() @@ -51,7 +45,7 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set: agent_mode = app_model_config.agent_mode_dict - tools = agent_mode.get('tools', []) or [] + tools = agent_mode.get("tools", []) or [] for tool in tools: if len(list(tool.keys())) != 1: continue @@ -63,11 +57,11 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set: # get dataset from dataset_configs dataset_configs = app_model_config.dataset_configs_dict - datasets = dataset_configs.get('datasets', {}) or {} - for dataset in datasets.get('datasets', []) or []: + datasets = dataset_configs.get("datasets", {}) or {} + for dataset in datasets.get("datasets", []) or []: keys = list(dataset.keys()) - if len(keys) == 1 and keys[0] == 'dataset': - if dataset['dataset'].get('id'): - dataset_ids.add(dataset['dataset'].get('id')) + if len(keys) == 1 and keys[0] == "dataset": + if dataset["dataset"].get("id"): + dataset_ids.add(dataset["dataset"].get("id")) return dataset_ids diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 996b1e96910b93..333b85ecb2907a 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -11,13 +11,11 @@ @app_published_workflow_was_updated.connect def handle(sender, **kwargs): app = sender - published_workflow = kwargs.get('published_workflow') + published_workflow = kwargs.get("published_workflow") published_workflow = cast(Workflow, published_workflow) dataset_ids = get_dataset_ids_from_workflow(published_workflow) - app_dataset_joins = db.session.query(AppDatasetJoin).filter( - AppDatasetJoin.app_id == app.id - ).all() + app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() removed_dataset_ids = [] if not app_dataset_joins: @@ -33,16 +31,12 @@ def handle(sender, **kwargs): if removed_dataset_ids: for dataset_id in removed_dataset_ids: db.session.query(AppDatasetJoin).filter( - AppDatasetJoin.app_id == app.id, - AppDatasetJoin.dataset_id == dataset_id + AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id ).delete() if added_dataset_ids: for dataset_id in added_dataset_ids: - app_dataset_join = AppDatasetJoin( - app_id=app.id, - dataset_id=dataset_id - ) + app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id) db.session.add(app_dataset_join) db.session.commit() @@ -54,18 +48,19 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set: if not graph: return dataset_ids - nodes = graph.get('nodes', []) + nodes = graph.get("nodes", []) # fetch all knowledge retrieval nodes - knowledge_retrieval_nodes = [node for node in nodes - if node.get('data', {}).get('type') == NodeType.KNOWLEDGE_RETRIEVAL.value] + knowledge_retrieval_nodes = [ + node for node in nodes if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_RETRIEVAL.value + ] if not knowledge_retrieval_nodes: return dataset_ids for node in knowledge_retrieval_nodes: try: - node_data = KnowledgeRetrievalNodeData(**node.get('data', {})) + node_data = KnowledgeRetrievalNodeData(**node.get("data", {})) dataset_ids.update(node_data.dataset_ids) except Exception as e: continue diff --git a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py b/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py index 6188f1a0850ac4..a80572c0debb1a 100644 --- a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py +++ b/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py @@ -9,13 +9,13 @@ @message_was_created.connect def handle(sender, **kwargs): message = sender - application_generate_entity = kwargs.get('application_generate_entity') + application_generate_entity = kwargs.get("application_generate_entity") if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): return db.session.query(Provider).filter( Provider.tenant_id == application_generate_entity.app_config.tenant_id, - Provider.provider_name == application_generate_entity.model_conf.provider - ).update({'last_used': datetime.now(timezone.utc).replace(tzinfo=None)}) + Provider.provider_name == application_generate_entity.model_conf.provider, + ).update({"last_used": datetime.now(timezone.utc).replace(tzinfo=None)}) db.session.commit() diff --git a/api/events/message_event.py b/api/events/message_event.py index 21da83f2496af5..6576c35c453c95 100644 --- a/api/events/message_event.py +++ b/api/events/message_event.py @@ -1,4 +1,4 @@ from blinker import signal # sender: message, kwargs: conversation -message_was_created = signal('message-was-created') +message_was_created = signal("message-was-created") diff --git a/api/events/tenant_event.py b/api/events/tenant_event.py index 942f709917bf61..d99feaac40896d 100644 --- a/api/events/tenant_event.py +++ b/api/events/tenant_event.py @@ -1,7 +1,7 @@ from blinker import signal # sender: tenant -tenant_was_created = signal('tenant-was-created') +tenant_was_created = signal("tenant-was-created") # sender: tenant -tenant_was_updated = signal('tenant-was-updated') +tenant_was_updated = signal("tenant-was-updated") diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index ae9a07534084e7..f5ec7c1759cb9d 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -17,7 +17,7 @@ def __call__(self, *args: object, **kwargs: object) -> object: backend=app.config["CELERY_BACKEND"], task_ignore_result=True, ) - + # Add SSL options to the Celery configuration ssl_options = { "ssl_cert_reqs": None, @@ -35,7 +35,7 @@ def __call__(self, *args: object, **kwargs: object) -> object: celery_app.conf.update( broker_use_ssl=ssl_options, # Add the SSL options to the broker configuration ) - + celery_app.set_default() app.extensions["celery"] = celery_app @@ -45,18 +45,15 @@ def __call__(self, *args: object, **kwargs: object) -> object: ] day = app.config["CELERY_BEAT_SCHEDULER_TIME"] beat_schedule = { - 'clean_embedding_cache_task': { - 'task': 'schedule.clean_embedding_cache_task.clean_embedding_cache_task', - 'schedule': timedelta(days=day), + "clean_embedding_cache_task": { + "task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task", + "schedule": timedelta(days=day), + }, + "clean_unused_datasets_task": { + "task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task", + "schedule": timedelta(days=day), }, - 'clean_unused_datasets_task': { - 'task': 'schedule.clean_unused_datasets_task.clean_unused_datasets_task', - 'schedule': timedelta(days=day), - } } - celery_app.conf.update( - beat_schedule=beat_schedule, - imports=imports - ) + celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) return celery_app diff --git a/api/extensions/ext_compress.py b/api/extensions/ext_compress.py index 1dbaffcfb0dc27..38e67749fcc994 100644 --- a/api/extensions/ext_compress.py +++ b/api/extensions/ext_compress.py @@ -2,15 +2,14 @@ def init_app(app: Flask): - if app.config.get('API_COMPRESSION_ENABLED'): + if app.config.get("API_COMPRESSION_ENABLED"): from flask_compress import Compress - app.config['COMPRESS_MIMETYPES'] = [ - 'application/json', - 'image/svg+xml', - 'text/html', + app.config["COMPRESS_MIMETYPES"] = [ + "application/json", + "image/svg+xml", + "text/html", ] compress = Compress() compress.init_app(app) - diff --git a/api/extensions/ext_database.py b/api/extensions/ext_database.py index c248e173a252c7..f6ffa536343afc 100644 --- a/api/extensions/ext_database.py +++ b/api/extensions/ext_database.py @@ -2,11 +2,11 @@ from sqlalchemy import MetaData POSTGRES_INDEXES_NAMING_CONVENTION = { - 'ix': '%(column_0_label)s_idx', - 'uq': '%(table_name)s_%(column_0_name)s_key', - 'ck': '%(table_name)s_%(constraint_name)s_check', - 'fk': '%(table_name)s_%(column_0_name)s_fkey', - 'pk': '%(table_name)s_pkey', + "ix": "%(column_0_label)s_idx", + "uq": "%(table_name)s_%(column_0_name)s_key", + "ck": "%(table_name)s_%(constraint_name)s_check", + "fk": "%(table_name)s_%(column_0_name)s_fkey", + "pk": "%(table_name)s_pkey", } metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION) diff --git a/api/extensions/ext_mail.py b/api/extensions/ext_mail.py index ec3a5cc112b871..b435294abc23ba 100644 --- a/api/extensions/ext_mail.py +++ b/api/extensions/ext_mail.py @@ -14,67 +14,69 @@ def is_inited(self) -> bool: return self._client is not None def init_app(self, app: Flask): - if app.config.get('MAIL_TYPE'): - if app.config.get('MAIL_DEFAULT_SEND_FROM'): - self._default_send_from = app.config.get('MAIL_DEFAULT_SEND_FROM') - - if app.config.get('MAIL_TYPE') == 'resend': - api_key = app.config.get('RESEND_API_KEY') + if app.config.get("MAIL_TYPE"): + if app.config.get("MAIL_DEFAULT_SEND_FROM"): + self._default_send_from = app.config.get("MAIL_DEFAULT_SEND_FROM") + + if app.config.get("MAIL_TYPE") == "resend": + api_key = app.config.get("RESEND_API_KEY") if not api_key: - raise ValueError('RESEND_API_KEY is not set') + raise ValueError("RESEND_API_KEY is not set") - api_url = app.config.get('RESEND_API_URL') + api_url = app.config.get("RESEND_API_URL") if api_url: resend.api_url = api_url resend.api_key = api_key self._client = resend.Emails - elif app.config.get('MAIL_TYPE') == 'smtp': + elif app.config.get("MAIL_TYPE") == "smtp": from libs.smtp import SMTPClient - if not app.config.get('SMTP_SERVER') or not app.config.get('SMTP_PORT'): - raise ValueError('SMTP_SERVER and SMTP_PORT are required for smtp mail type') - if not app.config.get('SMTP_USE_TLS') and app.config.get('SMTP_OPPORTUNISTIC_TLS'): - raise ValueError('SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS') + + if not app.config.get("SMTP_SERVER") or not app.config.get("SMTP_PORT"): + raise ValueError("SMTP_SERVER and SMTP_PORT are required for smtp mail type") + if not app.config.get("SMTP_USE_TLS") and app.config.get("SMTP_OPPORTUNISTIC_TLS"): + raise ValueError("SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS") self._client = SMTPClient( - server=app.config.get('SMTP_SERVER'), - port=app.config.get('SMTP_PORT'), - username=app.config.get('SMTP_USERNAME'), - password=app.config.get('SMTP_PASSWORD'), - _from=app.config.get('MAIL_DEFAULT_SEND_FROM'), - use_tls=app.config.get('SMTP_USE_TLS'), - opportunistic_tls=app.config.get('SMTP_OPPORTUNISTIC_TLS') + server=app.config.get("SMTP_SERVER"), + port=app.config.get("SMTP_PORT"), + username=app.config.get("SMTP_USERNAME"), + password=app.config.get("SMTP_PASSWORD"), + _from=app.config.get("MAIL_DEFAULT_SEND_FROM"), + use_tls=app.config.get("SMTP_USE_TLS"), + opportunistic_tls=app.config.get("SMTP_OPPORTUNISTIC_TLS"), ) else: - raise ValueError('Unsupported mail type {}'.format(app.config.get('MAIL_TYPE'))) + raise ValueError("Unsupported mail type {}".format(app.config.get("MAIL_TYPE"))) else: - logging.warning('MAIL_TYPE is not set') - + logging.warning("MAIL_TYPE is not set") def send(self, to: str, subject: str, html: str, from_: Optional[str] = None): if not self._client: - raise ValueError('Mail client is not initialized') + raise ValueError("Mail client is not initialized") if not from_ and self._default_send_from: from_ = self._default_send_from if not from_: - raise ValueError('mail from is not set') + raise ValueError("mail from is not set") if not to: - raise ValueError('mail to is not set') + raise ValueError("mail to is not set") if not subject: - raise ValueError('mail subject is not set') + raise ValueError("mail subject is not set") if not html: - raise ValueError('mail html is not set') - - self._client.send({ - "from": from_, - "to": to, - "subject": subject, - "html": html - }) + raise ValueError("mail html is not set") + + self._client.send( + { + "from": from_, + "to": to, + "subject": subject, + "html": html, + } + ) def init_app(app: Flask): diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 23d7768d4d0f5a..d5fb162fd8f2fb 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -6,18 +6,21 @@ def init_app(app): connection_class = Connection - if app.config.get('REDIS_USE_SSL'): + if app.config.get("REDIS_USE_SSL"): connection_class = SSLConnection - redis_client.connection_pool = redis.ConnectionPool(**{ - 'host': app.config.get('REDIS_HOST'), - 'port': app.config.get('REDIS_PORT'), - 'username': app.config.get('REDIS_USERNAME'), - 'password': app.config.get('REDIS_PASSWORD'), - 'db': app.config.get('REDIS_DB'), - 'encoding': 'utf-8', - 'encoding_errors': 'strict', - 'decode_responses': False - }, connection_class=connection_class) + redis_client.connection_pool = redis.ConnectionPool( + **{ + "host": app.config.get("REDIS_HOST"), + "port": app.config.get("REDIS_PORT"), + "username": app.config.get("REDIS_USERNAME"), + "password": app.config.get("REDIS_PASSWORD"), + "db": app.config.get("REDIS_DB"), + "encoding": "utf-8", + "encoding_errors": "strict", + "decode_responses": False, + }, + connection_class=connection_class, + ) - app.extensions['redis'] = redis_client + app.extensions["redis"] = redis_client diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index f05c10bc08926e..3b7b0a37f4e8e3 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -1,3 +1,4 @@ +import openai import sentry_sdk from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.flask import FlaskIntegration @@ -5,16 +6,13 @@ def init_app(app): - if app.config.get('SENTRY_DSN'): + if app.config.get("SENTRY_DSN"): sentry_sdk.init( - dsn=app.config.get('SENTRY_DSN'), - integrations=[ - FlaskIntegration(), - CeleryIntegration() - ], - ignore_errors=[HTTPException, ValueError], - traces_sample_rate=app.config.get('SENTRY_TRACES_SAMPLE_RATE', 1.0), - profiles_sample_rate=app.config.get('SENTRY_PROFILES_SAMPLE_RATE', 1.0), - environment=app.config.get('DEPLOY_ENV'), - release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}" + dsn=app.config.get("SENTRY_DSN"), + integrations=[FlaskIntegration(), CeleryIntegration()], + ignore_errors=[HTTPException, ValueError, openai.APIStatusError], + traces_sample_rate=app.config.get("SENTRY_TRACES_SAMPLE_RATE", 1.0), + profiles_sample_rate=app.config.get("SENTRY_PROFILES_SAMPLE_RATE", 1.0), + environment=app.config.get("DEPLOY_ENV"), + release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}", ) diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index 38db1c6ce103af..e6c4352577fc3f 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -17,31 +17,19 @@ def __init__(self): self.storage_runner = None def init_app(self, app: Flask): - storage_type = app.config.get('STORAGE_TYPE') - if storage_type == 's3': - self.storage_runner = S3Storage( - app=app - ) - elif storage_type == 'azure-blob': - self.storage_runner = AzureStorage( - app=app - ) - elif storage_type == 'aliyun-oss': - self.storage_runner = AliyunStorage( - app=app - ) - elif storage_type == 'google-storage': - self.storage_runner = GoogleStorage( - app=app - ) - elif storage_type == 'tencent-cos': - self.storage_runner = TencentStorage( - app=app - ) - elif storage_type == 'oci-storage': - self.storage_runner = OCIStorage( - app=app - ) + storage_type = app.config.get("STORAGE_TYPE") + if storage_type == "s3": + self.storage_runner = S3Storage(app=app) + elif storage_type == "azure-blob": + self.storage_runner = AzureStorage(app=app) + elif storage_type == "aliyun-oss": + self.storage_runner = AliyunStorage(app=app) + elif storage_type == "google-storage": + self.storage_runner = GoogleStorage(app=app) + elif storage_type == "tencent-cos": + self.storage_runner = TencentStorage(app=app) + elif storage_type == "oci-storage": + self.storage_runner = OCIStorage(app=app) else: self.storage_runner = LocalStorage(app=app) diff --git a/api/extensions/storage/aliyun_storage.py b/api/extensions/storage/aliyun_storage.py index b81a8691f15457..bee237fc17fe86 100644 --- a/api/extensions/storage/aliyun_storage.py +++ b/api/extensions/storage/aliyun_storage.py @@ -8,38 +8,52 @@ class AliyunStorage(BaseStorage): - """Implementation for aliyun storage. - """ + """Implementation for aliyun storage.""" def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('ALIYUN_OSS_BUCKET_NAME') + self.bucket_name = app_config.get("ALIYUN_OSS_BUCKET_NAME") + self.folder = app.config.get("ALIYUN_OSS_PATH") oss_auth_method = aliyun_s3.Auth region = None - if app_config.get('ALIYUN_OSS_AUTH_VERSION') == 'v4': + if app_config.get("ALIYUN_OSS_AUTH_VERSION") == "v4": oss_auth_method = aliyun_s3.AuthV4 - region = app_config.get('ALIYUN_OSS_REGION') - oss_auth = oss_auth_method(app_config.get('ALIYUN_OSS_ACCESS_KEY'), app_config.get('ALIYUN_OSS_SECRET_KEY')) + region = app_config.get("ALIYUN_OSS_REGION") + oss_auth = oss_auth_method(app_config.get("ALIYUN_OSS_ACCESS_KEY"), app_config.get("ALIYUN_OSS_SECRET_KEY")) self.client = aliyun_s3.Bucket( oss_auth, - app_config.get('ALIYUN_OSS_ENDPOINT'), + app_config.get("ALIYUN_OSS_ENDPOINT"), self.bucket_name, connect_timeout=30, region=region, ) def save(self, filename, data): + if not self.folder or self.folder.endswith("/"): + filename = self.folder + filename + else: + filename = self.folder + "/" + filename self.client.put_object(filename, data) def load_once(self, filename: str) -> bytes: + if not self.folder or self.folder.endswith("/"): + filename = self.folder + filename + else: + filename = self.folder + "/" + filename + with closing(self.client.get_object(filename)) as obj: data = obj.read() return data def load_stream(self, filename: str) -> Generator: def generate(filename: str = filename) -> Generator: + if not self.folder or self.folder.endswith("/"): + filename = self.folder + filename + else: + filename = self.folder + "/" + filename + with closing(self.client.get_object(filename)) as obj: while chunk := obj.read(4096): yield chunk @@ -47,10 +61,24 @@ def generate(filename: str = filename) -> Generator: return generate() def download(self, filename, target_filepath): + if not self.folder or self.folder.endswith("/"): + filename = self.folder + filename + else: + filename = self.folder + "/" + filename + self.client.get_object_to_file(filename, target_filepath) def exists(self, filename): + if not self.folder or self.folder.endswith("/"): + filename = self.folder + filename + else: + filename = self.folder + "/" + filename + return self.client.object_exists(filename) def delete(self, filename): + if not self.folder or self.folder.endswith("/"): + filename = self.folder + filename + else: + filename = self.folder + "/" + filename self.client.delete_object(filename) diff --git a/api/extensions/storage/azure_storage.py b/api/extensions/storage/azure_storage.py index af3e7ef84911ff..ca8cbb9188b5c9 100644 --- a/api/extensions/storage/azure_storage.py +++ b/api/extensions/storage/azure_storage.py @@ -9,16 +9,15 @@ class AzureStorage(BaseStorage): - """Implementation for azure storage. - """ + """Implementation for azure storage.""" def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('AZURE_BLOB_CONTAINER_NAME') - self.account_url = app_config.get('AZURE_BLOB_ACCOUNT_URL') - self.account_name = app_config.get('AZURE_BLOB_ACCOUNT_NAME') - self.account_key = app_config.get('AZURE_BLOB_ACCOUNT_KEY') + self.bucket_name = app_config.get("AZURE_BLOB_CONTAINER_NAME") + self.account_url = app_config.get("AZURE_BLOB_ACCOUNT_URL") + self.account_name = app_config.get("AZURE_BLOB_ACCOUNT_NAME") + self.account_key = app_config.get("AZURE_BLOB_ACCOUNT_KEY") def save(self, filename, data): client = self._sync_client() @@ -39,6 +38,7 @@ def generate(filename: str = filename) -> Generator: blob = client.get_blob_client(container=self.bucket_name, blob=filename) blob_data = blob.download_blob() yield from blob_data.chunks() + return generate(filename) def download(self, filename, target_filepath): @@ -62,17 +62,17 @@ def delete(self, filename): blob_container.delete_blob(filename) def _sync_client(self): - cache_key = 'azure_blob_sas_token_{}_{}'.format(self.account_name, self.account_key) + cache_key = "azure_blob_sas_token_{}_{}".format(self.account_name, self.account_key) cache_result = redis_client.get(cache_key) if cache_result is not None: - sas_token = cache_result.decode('utf-8') + sas_token = cache_result.decode("utf-8") else: sas_token = generate_account_sas( account_name=self.account_name, account_key=self.account_key, resource_types=ResourceTypes(service=True, container=True, object=True), permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True), - expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1) + expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1), ) redis_client.set(cache_key, sas_token, ex=3000) return BlobServiceClient(account_url=self.account_url, credential=sas_token) diff --git a/api/extensions/storage/base_storage.py b/api/extensions/storage/base_storage.py index 13d9c3429044c8..c3fe9ec82a5b41 100644 --- a/api/extensions/storage/base_storage.py +++ b/api/extensions/storage/base_storage.py @@ -1,4 +1,5 @@ """Abstract interface for file storage implementations.""" + from abc import ABC, abstractmethod from collections.abc import Generator @@ -6,8 +7,8 @@ class BaseStorage(ABC): - """Interface for file storage. - """ + """Interface for file storage.""" + app = None def __init__(self, app: Flask): diff --git a/api/extensions/storage/google_storage.py b/api/extensions/storage/google_storage.py index ef6cd69039787c..9ed1fcf0b4e118 100644 --- a/api/extensions/storage/google_storage.py +++ b/api/extensions/storage/google_storage.py @@ -11,16 +11,16 @@ class GoogleStorage(BaseStorage): - """Implementation for google storage. - """ + """Implementation for google storage.""" + def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('GOOGLE_STORAGE_BUCKET_NAME') - service_account_json_str = app_config.get('GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64') + self.bucket_name = app_config.get("GOOGLE_STORAGE_BUCKET_NAME") + service_account_json_str = app_config.get("GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64") # if service_account_json_str is empty, use Application Default Credentials if service_account_json_str: - service_account_json = base64.b64decode(service_account_json_str).decode('utf-8') + service_account_json = base64.b64decode(service_account_json_str).decode("utf-8") # convert str to object service_account_obj = json.loads(service_account_json) self.client = GoogleCloudStorage.Client.from_service_account_info(service_account_obj) @@ -43,9 +43,10 @@ def load_stream(self, filename: str) -> Generator: def generate(filename: str = filename) -> Generator: bucket = self.client.get_bucket(self.bucket_name) blob = bucket.get_blob(filename) - with closing(blob.open(mode='rb')) as blob_stream: + with closing(blob.open(mode="rb")) as blob_stream: while chunk := blob_stream.read(4096): yield chunk + return generate() def download(self, filename, target_filepath): @@ -60,4 +61,4 @@ def exists(self, filename): def delete(self, filename): bucket = self.client.get_bucket(self.bucket_name) - bucket.delete_blob(filename) \ No newline at end of file + bucket.delete_blob(filename) diff --git a/api/extensions/storage/local_storage.py b/api/extensions/storage/local_storage.py index 389ef12f82bcd8..46ee4bf80f8e6d 100644 --- a/api/extensions/storage/local_storage.py +++ b/api/extensions/storage/local_storage.py @@ -8,21 +8,20 @@ class LocalStorage(BaseStorage): - """Implementation for local storage. - """ + """Implementation for local storage.""" def __init__(self, app: Flask): super().__init__(app) - folder = self.app.config.get('STORAGE_LOCAL_PATH') + folder = self.app.config.get("STORAGE_LOCAL_PATH") if not os.path.isabs(folder): folder = os.path.join(app.root_path, folder) self.folder = folder def save(self, filename, data): - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename folder = os.path.dirname(filename) os.makedirs(folder, exist_ok=True) @@ -31,10 +30,10 @@ def save(self, filename, data): f.write(data) def load_once(self, filename: str) -> bytes: - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename if not os.path.exists(filename): raise FileNotFoundError("File not found") @@ -46,10 +45,10 @@ def load_once(self, filename: str) -> bytes: def load_stream(self, filename: str) -> Generator: def generate(filename: str = filename) -> Generator: - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename if not os.path.exists(filename): raise FileNotFoundError("File not found") @@ -61,10 +60,10 @@ def generate(filename: str = filename) -> Generator: return generate() def download(self, filename, target_filepath): - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename if not os.path.exists(filename): raise FileNotFoundError("File not found") @@ -72,17 +71,17 @@ def download(self, filename, target_filepath): shutil.copyfile(filename, target_filepath) def exists(self, filename): - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename return os.path.exists(filename) def delete(self, filename): - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename if os.path.exists(filename): os.remove(filename) diff --git a/api/extensions/storage/oci_storage.py b/api/extensions/storage/oci_storage.py index e78d870950339b..e32fa0a0ae78a9 100644 --- a/api/extensions/storage/oci_storage.py +++ b/api/extensions/storage/oci_storage.py @@ -12,14 +12,14 @@ class OCIStorage(BaseStorage): def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('OCI_BUCKET_NAME') + self.bucket_name = app_config.get("OCI_BUCKET_NAME") self.client = boto3.client( - 's3', - aws_secret_access_key=app_config.get('OCI_SECRET_KEY'), - aws_access_key_id=app_config.get('OCI_ACCESS_KEY'), - endpoint_url=app_config.get('OCI_ENDPOINT'), - region_name=app_config.get('OCI_REGION') - ) + "s3", + aws_secret_access_key=app_config.get("OCI_SECRET_KEY"), + aws_access_key_id=app_config.get("OCI_ACCESS_KEY"), + endpoint_url=app_config.get("OCI_ENDPOINT"), + region_name=app_config.get("OCI_REGION"), + ) def save(self, filename, data): self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data) @@ -27,9 +27,9 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: try: with closing(self.client) as client: - data = client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].read() + data = client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() except ClientError as ex: - if ex.response['Error']['Code'] == 'NoSuchKey': + if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") else: raise @@ -40,12 +40,13 @@ def generate(filename: str = filename) -> Generator: try: with closing(self.client) as client: response = client.get_object(Bucket=self.bucket_name, Key=filename) - yield from response['Body'].iter_chunks() + yield from response["Body"].iter_chunks() except ClientError as ex: - if ex.response['Error']['Code'] == 'NoSuchKey': + if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") else: raise + return generate() def download(self, filename, target_filepath): @@ -61,4 +62,4 @@ def exists(self, filename): return False def delete(self, filename): - self.client.delete_object(Bucket=self.bucket_name, Key=filename) \ No newline at end of file + self.client.delete_object(Bucket=self.bucket_name, Key=filename) diff --git a/api/extensions/storage/s3_storage.py b/api/extensions/storage/s3_storage.py index 787596fa791d4a..0858be3af6ad0a 100644 --- a/api/extensions/storage/s3_storage.py +++ b/api/extensions/storage/s3_storage.py @@ -10,24 +10,37 @@ class S3Storage(BaseStorage): - """Implementation for s3 storage. - """ + """Implementation for s3 storage.""" + def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('S3_BUCKET_NAME') - if app_config.get('S3_USE_AWS_MANAGED_IAM'): + self.bucket_name = app_config.get("S3_BUCKET_NAME") + if app_config.get("S3_USE_AWS_MANAGED_IAM"): session = boto3.Session() - self.client = session.client('s3') + self.client = session.client("s3") else: self.client = boto3.client( - 's3', - aws_secret_access_key=app_config.get('S3_SECRET_KEY'), - aws_access_key_id=app_config.get('S3_ACCESS_KEY'), - endpoint_url=app_config.get('S3_ENDPOINT'), - region_name=app_config.get('S3_REGION'), - config=Config(s3={'addressing_style': app_config.get('S3_ADDRESS_STYLE')}) - ) + "s3", + aws_secret_access_key=app_config.get("S3_SECRET_KEY"), + aws_access_key_id=app_config.get("S3_ACCESS_KEY"), + endpoint_url=app_config.get("S3_ENDPOINT"), + region_name=app_config.get("S3_REGION"), + config=Config(s3={"addressing_style": app_config.get("S3_ADDRESS_STYLE")}), + ) + # create bucket + try: + self.client.head_bucket(Bucket=self.bucket_name) + except ClientError as e: + # if bucket not exists, create it + if e.response["Error"]["Code"] == "404": + self.client.create_bucket(Bucket=self.bucket_name) + # if bucket is not accessible, pass, maybe the bucket is existing but not accessible + elif e.response["Error"]["Code"] == "403": + pass + else: + # other error, raise exception + raise def save(self, filename, data): self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data) @@ -35,9 +48,9 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: try: with closing(self.client) as client: - data = client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].read() + data = client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() except ClientError as ex: - if ex.response['Error']['Code'] == 'NoSuchKey': + if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") else: raise @@ -48,12 +61,13 @@ def generate(filename: str = filename) -> Generator: try: with closing(self.client) as client: response = client.get_object(Bucket=self.bucket_name, Key=filename) - yield from response['Body'].iter_chunks() + yield from response["Body"].iter_chunks() except ClientError as ex: - if ex.response['Error']['Code'] == 'NoSuchKey': + if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") else: raise + return generate() def download(self, filename, target_filepath): diff --git a/api/extensions/storage/tencent_storage.py b/api/extensions/storage/tencent_storage.py index e2c1ca55e3c434..1d499cd3bcea1c 100644 --- a/api/extensions/storage/tencent_storage.py +++ b/api/extensions/storage/tencent_storage.py @@ -7,18 +7,17 @@ class TencentStorage(BaseStorage): - """Implementation for tencent cos storage. - """ + """Implementation for tencent cos storage.""" def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('TENCENT_COS_BUCKET_NAME') + self.bucket_name = app_config.get("TENCENT_COS_BUCKET_NAME") config = CosConfig( - Region=app_config.get('TENCENT_COS_REGION'), - SecretId=app_config.get('TENCENT_COS_SECRET_ID'), - SecretKey=app_config.get('TENCENT_COS_SECRET_KEY'), - Scheme=app_config.get('TENCENT_COS_SCHEME'), + Region=app_config.get("TENCENT_COS_REGION"), + SecretId=app_config.get("TENCENT_COS_SECRET_ID"), + SecretKey=app_config.get("TENCENT_COS_SECRET_KEY"), + Scheme=app_config.get("TENCENT_COS_SCHEME"), ) self.client = CosS3Client(config) @@ -26,19 +25,19 @@ def save(self, filename, data): self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename) def load_once(self, filename: str) -> bytes: - data = self.client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].get_raw_stream().read() + data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read() return data def load_stream(self, filename: str) -> Generator: def generate(filename: str = filename) -> Generator: response = self.client.get_object(Bucket=self.bucket_name, Key=filename) - yield from response['Body'].get_stream(chunk_size=4096) + yield from response["Body"].get_stream(chunk_size=4096) return generate() def download(self, filename, target_filepath): response = self.client.get_object(Bucket=self.bucket_name, Key=filename) - response['Body'].get_stream_to_file(target_filepath) + response["Body"].get_stream_to_file(target_filepath) def exists(self, filename): return self.client.object_exists(Bucket=self.bucket_name, Key=filename) diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index c77808447519fc..379dcc6d16fe56 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -5,7 +5,7 @@ annotation_fields = { "id": fields.String, "question": fields.String, - "answer": fields.Raw(attribute='content'), + "answer": fields.Raw(attribute="content"), "hit_count": fields.Integer, "created_at": TimestampField, # 'account': fields.Nested(simple_account_fields, allow_null=True) @@ -21,8 +21,8 @@ "score": fields.Float, "question": fields.String, "created_at": TimestampField, - "match": fields.String(attribute='annotation_question'), - "response": fields.String(attribute='annotation_content') + "match": fields.String(attribute="annotation_question"), + "response": fields.String(attribute="annotation_content"), } annotation_hit_history_list_fields = { diff --git a/api/fields/api_based_extension_fields.py b/api/fields/api_based_extension_fields.py index 749e9900de181d..a85d4a34dbe7b1 100644 --- a/api/fields/api_based_extension_fields.py +++ b/api/fields/api_based_extension_fields.py @@ -8,16 +8,16 @@ def output(self, key, obj): api_key = obj.api_key # If the length of the api_key is less than 8 characters, show the first and last characters if len(api_key) <= 8: - return api_key[0] + '******' + api_key[-1] + return api_key[0] + "******" + api_key[-1] # If the api_key is greater than 8 characters, show the first three and the last three characters else: - return api_key[:3] + '******' + api_key[-3:] + return api_key[:3] + "******" + api_key[-3:] api_based_extension_fields = { - 'id': fields.String, - 'name': fields.String, - 'api_endpoint': fields.String, - 'api_key': HiddenAPIKey, - 'created_at': TimestampField + "id": fields.String, + "name": fields.String, + "api_endpoint": fields.String, + "api_key": HiddenAPIKey, + "created_at": TimestampField, } diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 94d804a919869f..45fcb128ceecb7 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -1,159 +1,187 @@ from flask_restful import fields -from libs.helper import TimestampField +from fields.workflow_fields import workflow_partial_fields +from libs.helper import AppIconUrlField, TimestampField app_detail_kernel_fields = { - 'id': fields.String, - 'name': fields.String, - 'description': fields.String, - 'mode': fields.String(attribute='mode_compatible_with_agent'), - 'icon': fields.String, - 'icon_background': fields.String, + "id": fields.String, + "name": fields.String, + "description": fields.String, + "mode": fields.String(attribute="mode_compatible_with_agent"), + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, } related_app_list = { - 'data': fields.List(fields.Nested(app_detail_kernel_fields)), - 'total': fields.Integer, + "data": fields.List(fields.Nested(app_detail_kernel_fields)), + "total": fields.Integer, } model_config_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw(attribute='suggested_questions_list'), - 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), - 'speech_to_text': fields.Raw(attribute='speech_to_text_dict'), - 'text_to_speech': fields.Raw(attribute='text_to_speech_dict'), - 'retriever_resource': fields.Raw(attribute='retriever_resource_dict'), - 'annotation_reply': fields.Raw(attribute='annotation_reply_dict'), - 'more_like_this': fields.Raw(attribute='more_like_this_dict'), - 'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'), - 'external_data_tools': fields.Raw(attribute='external_data_tools_list'), - 'model': fields.Raw(attribute='model_dict'), - 'user_input_form': fields.Raw(attribute='user_input_form_list'), - 'dataset_query_variable': fields.String, - 'pre_prompt': fields.String, - 'agent_mode': fields.Raw(attribute='agent_mode_dict'), - 'prompt_type': fields.String, - 'chat_prompt_config': fields.Raw(attribute='chat_prompt_config_dict'), - 'completion_prompt_config': fields.Raw(attribute='completion_prompt_config_dict'), - 'dataset_configs': fields.Raw(attribute='dataset_configs_dict'), - 'file_upload': fields.Raw(attribute='file_upload_dict'), - 'created_at': TimestampField + "opening_statement": fields.String, + "suggested_questions": fields.Raw(attribute="suggested_questions_list"), + "suggested_questions_after_answer": fields.Raw(attribute="suggested_questions_after_answer_dict"), + "speech_to_text": fields.Raw(attribute="speech_to_text_dict"), + "text_to_speech": fields.Raw(attribute="text_to_speech_dict"), + "retriever_resource": fields.Raw(attribute="retriever_resource_dict"), + "annotation_reply": fields.Raw(attribute="annotation_reply_dict"), + "more_like_this": fields.Raw(attribute="more_like_this_dict"), + "sensitive_word_avoidance": fields.Raw(attribute="sensitive_word_avoidance_dict"), + "external_data_tools": fields.Raw(attribute="external_data_tools_list"), + "model": fields.Raw(attribute="model_dict"), + "user_input_form": fields.Raw(attribute="user_input_form_list"), + "dataset_query_variable": fields.String, + "pre_prompt": fields.String, + "agent_mode": fields.Raw(attribute="agent_mode_dict"), + "prompt_type": fields.String, + "chat_prompt_config": fields.Raw(attribute="chat_prompt_config_dict"), + "completion_prompt_config": fields.Raw(attribute="completion_prompt_config_dict"), + "dataset_configs": fields.Raw(attribute="dataset_configs_dict"), + "file_upload": fields.Raw(attribute="file_upload_dict"), + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, } app_detail_fields = { - 'id': fields.String, - 'name': fields.String, - 'description': fields.String, - 'mode': fields.String(attribute='mode_compatible_with_agent'), - 'icon': fields.String, - 'icon_background': fields.String, - 'enable_site': fields.Boolean, - 'enable_api': fields.Boolean, - 'model_config': fields.Nested(model_config_fields, attribute='app_model_config', allow_null=True), - 'tracing': fields.Raw, - 'created_at': TimestampField + "id": fields.String, + "name": fields.String, + "description": fields.String, + "mode": fields.String(attribute="mode_compatible_with_agent"), + "icon": fields.String, + "icon_background": fields.String, + "enable_site": fields.Boolean, + "enable_api": fields.Boolean, + "model_config": fields.Nested(model_config_fields, attribute="app_model_config", allow_null=True), + "workflow": fields.Nested(workflow_partial_fields, allow_null=True), + "tracing": fields.Raw, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, } prompt_config_fields = { - 'prompt_template': fields.String, + "prompt_template": fields.String, } model_config_partial_fields = { - 'model': fields.Raw(attribute='model_dict'), - 'pre_prompt': fields.String, + "model": fields.Raw(attribute="model_dict"), + "pre_prompt": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, } -tag_fields = { - 'id': fields.String, - 'name': fields.String, - 'type': fields.String -} +tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String} app_partial_fields = { - 'id': fields.String, - 'name': fields.String, - 'max_active_requests': fields.Raw(), - 'description': fields.String(attribute='desc_or_prompt'), - 'mode': fields.String(attribute='mode_compatible_with_agent'), - 'icon': fields.String, - 'icon_background': fields.String, - 'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config', allow_null=True), - 'created_at': TimestampField, - 'tags': fields.List(fields.Nested(tag_fields)) + "id": fields.String, + "name": fields.String, + "max_active_requests": fields.Raw(), + "description": fields.String(attribute="desc_or_prompt"), + "mode": fields.String(attribute="mode_compatible_with_agent"), + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "model_config": fields.Nested(model_config_partial_fields, attribute="app_model_config", allow_null=True), + "workflow": fields.Nested(workflow_partial_fields, allow_null=True), + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, + "tags": fields.List(fields.Nested(tag_fields)), } app_pagination_fields = { - 'page': fields.Integer, - 'limit': fields.Integer(attribute='per_page'), - 'total': fields.Integer, - 'has_more': fields.Boolean(attribute='has_next'), - 'data': fields.List(fields.Nested(app_partial_fields), attribute='items') + "page": fields.Integer, + "limit": fields.Integer(attribute="per_page"), + "total": fields.Integer, + "has_more": fields.Boolean(attribute="has_next"), + "data": fields.List(fields.Nested(app_partial_fields), attribute="items"), } template_fields = { - 'name': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'description': fields.String, - 'mode': fields.String, - 'model_config': fields.Nested(model_config_fields), + "name": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "description": fields.String, + "mode": fields.String, + "model_config": fields.Nested(model_config_fields), } template_list_fields = { - 'data': fields.List(fields.Nested(template_fields)), + "data": fields.List(fields.Nested(template_fields)), } site_fields = { - 'access_token': fields.String(attribute='code'), - 'code': fields.String, - 'title': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'description': fields.String, - 'default_language': fields.String, - 'chat_color_theme': fields.String, - 'chat_color_theme_inverted': fields.Boolean, - 'customize_domain': fields.String, - 'copyright': fields.String, - 'privacy_policy': fields.String, - 'custom_disclaimer': fields.String, - 'customize_token_strategy': fields.String, - 'prompt_public': fields.Boolean, - 'app_base_url': fields.String, - 'show_workflow_steps': fields.Boolean, + "access_token": fields.String(attribute="code"), + "code": fields.String, + "title": fields.String, + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "description": fields.String, + "default_language": fields.String, + "chat_color_theme": fields.String, + "chat_color_theme_inverted": fields.Boolean, + "customize_domain": fields.String, + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "customize_token_strategy": fields.String, + "prompt_public": fields.Boolean, + "app_base_url": fields.String, + "show_workflow_steps": fields.Boolean, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, } app_detail_fields_with_site = { - 'id': fields.String, - 'name': fields.String, - 'description': fields.String, - 'mode': fields.String(attribute='mode_compatible_with_agent'), - 'icon': fields.String, - 'icon_background': fields.String, - 'enable_site': fields.Boolean, - 'enable_api': fields.Boolean, - 'model_config': fields.Nested(model_config_fields, attribute='app_model_config', allow_null=True), - 'site': fields.Nested(site_fields), - 'api_base_url': fields.String, - 'created_at': TimestampField, - 'deleted_tools': fields.List(fields.String), + "id": fields.String, + "name": fields.String, + "description": fields.String, + "mode": fields.String(attribute="mode_compatible_with_agent"), + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "enable_site": fields.Boolean, + "enable_api": fields.Boolean, + "model_config": fields.Nested(model_config_fields, attribute="app_model_config", allow_null=True), + "workflow": fields.Nested(workflow_partial_fields, allow_null=True), + "site": fields.Nested(site_fields), + "api_base_url": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, + "deleted_tools": fields.List(fields.String), } app_site_fields = { - 'app_id': fields.String, - 'access_token': fields.String(attribute='code'), - 'code': fields.String, - 'title': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'description': fields.String, - 'default_language': fields.String, - 'customize_domain': fields.String, - 'copyright': fields.String, - 'privacy_policy': fields.String, - 'custom_disclaimer': fields.String, - 'customize_token_strategy': fields.String, - 'prompt_public': fields.Boolean, - 'show_workflow_steps': fields.Boolean, + "app_id": fields.String, + "access_token": fields.String(attribute="code"), + "code": fields.String, + "title": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "description": fields.String, + "default_language": fields.String, + "customize_domain": fields.String, + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "customize_token_strategy": fields.String, + "prompt_public": fields.Boolean, + "show_workflow_steps": fields.Boolean, } diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 79ceb026852792..9207314fc21328 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -6,205 +6,205 @@ class MessageTextField(fields.Raw): def format(self, value): - return value[0]['text'] if value else '' + return value[0]["text"] if value else "" feedback_fields = { - 'rating': fields.String, - 'content': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_account': fields.Nested(simple_account_fields, allow_null=True), + "rating": fields.String, + "content": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_account": fields.Nested(simple_account_fields, allow_null=True), } annotation_fields = { - 'id': fields.String, - 'question': fields.String, - 'content': fields.String, - 'account': fields.Nested(simple_account_fields, allow_null=True), - 'created_at': TimestampField + "id": fields.String, + "question": fields.String, + "content": fields.String, + "account": fields.Nested(simple_account_fields, allow_null=True), + "created_at": TimestampField, } annotation_hit_history_fields = { - 'annotation_id': fields.String(attribute='id'), - 'annotation_create_account': fields.Nested(simple_account_fields, allow_null=True), - 'created_at': TimestampField + "annotation_id": fields.String(attribute="id"), + "annotation_create_account": fields.Nested(simple_account_fields, allow_null=True), + "created_at": TimestampField, } message_file_fields = { - 'id': fields.String, - 'type': fields.String, - 'url': fields.String, - 'belongs_to': fields.String(default='user'), + "id": fields.String, + "type": fields.String, + "url": fields.String, + "belongs_to": fields.String(default="user"), } agent_thought_fields = { - 'id': fields.String, - 'chain_id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'thought': fields.String, - 'tool': fields.String, - 'tool_labels': fields.Raw, - 'tool_input': fields.String, - 'created_at': TimestampField, - 'observation': fields.String, - 'files': fields.List(fields.String), + "id": fields.String, + "chain_id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "thought": fields.String, + "tool": fields.String, + "tool_labels": fields.Raw, + "tool_input": fields.String, + "created_at": TimestampField, + "observation": fields.String, + "files": fields.List(fields.String), } message_detail_fields = { - 'id': fields.String, - 'conversation_id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'message': fields.Raw, - 'message_tokens': fields.Integer, - 'answer': fields.String(attribute='re_sign_file_url_answer'), - 'answer_tokens': fields.Integer, - 'provider_response_latency': fields.Float, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_account_id': fields.String, - 'feedbacks': fields.List(fields.Nested(feedback_fields)), - 'workflow_run_id': fields.String, - 'annotation': fields.Nested(annotation_fields, allow_null=True), - 'annotation_hit_history': fields.Nested(annotation_hit_history_fields, allow_null=True), - 'created_at': TimestampField, - 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)), - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'metadata': fields.Raw(attribute='message_metadata_dict'), - 'status': fields.String, - 'error': fields.String, -} - -feedback_stat_fields = { - 'like': fields.Integer, - 'dislike': fields.Integer -} + "id": fields.String, + "conversation_id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "message": fields.Raw, + "message_tokens": fields.Integer, + "answer": fields.String(attribute="re_sign_file_url_answer"), + "answer_tokens": fields.Integer, + "provider_response_latency": fields.Float, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_account_id": fields.String, + "feedbacks": fields.List(fields.Nested(feedback_fields)), + "workflow_run_id": fields.String, + "annotation": fields.Nested(annotation_fields, allow_null=True), + "annotation_hit_history": fields.Nested(annotation_hit_history_fields, allow_null=True), + "created_at": TimestampField, + "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "metadata": fields.Raw(attribute="message_metadata_dict"), + "status": fields.String, + "error": fields.String, +} + +feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer} model_config_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw, - 'model': fields.Raw, - 'user_input_form': fields.Raw, - 'pre_prompt': fields.String, - 'agent_mode': fields.Raw, + "opening_statement": fields.String, + "suggested_questions": fields.Raw, + "model": fields.Raw, + "user_input_form": fields.Raw, + "pre_prompt": fields.String, + "agent_mode": fields.Raw, } simple_configs_fields = { - 'prompt_template': fields.String, + "prompt_template": fields.String, } simple_model_config_fields = { - 'model': fields.Raw(attribute='model_dict'), - 'pre_prompt': fields.String, + "model": fields.Raw(attribute="model_dict"), + "pre_prompt": fields.String, } simple_message_detail_fields = { - 'inputs': fields.Raw, - 'query': fields.String, - 'message': MessageTextField, - 'answer': fields.String, + "inputs": fields.Raw, + "query": fields.String, + "message": MessageTextField, + "answer": fields.String, } conversation_fields = { - 'id': fields.String, - 'status': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_end_user_session_id': fields.String(), - 'from_account_id': fields.String, - 'read_at': TimestampField, - 'created_at': TimestampField, - 'annotation': fields.Nested(annotation_fields, allow_null=True), - 'model_config': fields.Nested(simple_model_config_fields), - 'user_feedback_stats': fields.Nested(feedback_stat_fields), - 'admin_feedback_stats': fields.Nested(feedback_stat_fields), - 'message': fields.Nested(simple_message_detail_fields, attribute='first_message') + "id": fields.String, + "status": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_end_user_session_id": fields.String(), + "from_account_id": fields.String, + "from_account_name": fields.String, + "read_at": TimestampField, + "created_at": TimestampField, + "annotation": fields.Nested(annotation_fields, allow_null=True), + "model_config": fields.Nested(simple_model_config_fields), + "user_feedback_stats": fields.Nested(feedback_stat_fields), + "admin_feedback_stats": fields.Nested(feedback_stat_fields), + "message": fields.Nested(simple_message_detail_fields, attribute="first_message"), } conversation_pagination_fields = { - 'page': fields.Integer, - 'limit': fields.Integer(attribute='per_page'), - 'total': fields.Integer, - 'has_more': fields.Boolean(attribute='has_next'), - 'data': fields.List(fields.Nested(conversation_fields), attribute='items') + "page": fields.Integer, + "limit": fields.Integer(attribute="per_page"), + "total": fields.Integer, + "has_more": fields.Boolean(attribute="has_next"), + "data": fields.List(fields.Nested(conversation_fields), attribute="items"), } conversation_message_detail_fields = { - 'id': fields.String, - 'status': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_account_id': fields.String, - 'created_at': TimestampField, - 'model_config': fields.Nested(model_config_fields), - 'message': fields.Nested(message_detail_fields, attribute='first_message'), + "id": fields.String, + "status": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_account_id": fields.String, + "created_at": TimestampField, + "model_config": fields.Nested(model_config_fields), + "message": fields.Nested(message_detail_fields, attribute="first_message"), } conversation_with_summary_fields = { - 'id': fields.String, - 'status': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_end_user_session_id': fields.String, - 'from_account_id': fields.String, - 'name': fields.String, - 'summary': fields.String(attribute='summary_or_query'), - 'read_at': TimestampField, - 'created_at': TimestampField, - 'annotated': fields.Boolean, - 'model_config': fields.Nested(simple_model_config_fields), - 'message_count': fields.Integer, - 'user_feedback_stats': fields.Nested(feedback_stat_fields), - 'admin_feedback_stats': fields.Nested(feedback_stat_fields) + "id": fields.String, + "status": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_end_user_session_id": fields.String, + "from_account_id": fields.String, + "from_account_name": fields.String, + "name": fields.String, + "summary": fields.String(attribute="summary_or_query"), + "read_at": TimestampField, + "created_at": TimestampField, + "updated_at": TimestampField, + "annotated": fields.Boolean, + "model_config": fields.Nested(simple_model_config_fields), + "message_count": fields.Integer, + "user_feedback_stats": fields.Nested(feedback_stat_fields), + "admin_feedback_stats": fields.Nested(feedback_stat_fields), } conversation_with_summary_pagination_fields = { - 'page': fields.Integer, - 'limit': fields.Integer(attribute='per_page'), - 'total': fields.Integer, - 'has_more': fields.Boolean(attribute='has_next'), - 'data': fields.List(fields.Nested(conversation_with_summary_fields), attribute='items') + "page": fields.Integer, + "limit": fields.Integer(attribute="per_page"), + "total": fields.Integer, + "has_more": fields.Boolean(attribute="has_next"), + "data": fields.List(fields.Nested(conversation_with_summary_fields), attribute="items"), } conversation_detail_fields = { - 'id': fields.String, - 'status': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_account_id': fields.String, - 'created_at': TimestampField, - 'annotated': fields.Boolean, - 'introduction': fields.String, - 'model_config': fields.Nested(model_config_fields), - 'message_count': fields.Integer, - 'user_feedback_stats': fields.Nested(feedback_stat_fields), - 'admin_feedback_stats': fields.Nested(feedback_stat_fields) + "id": fields.String, + "status": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_account_id": fields.String, + "created_at": TimestampField, + "annotated": fields.Boolean, + "introduction": fields.String, + "model_config": fields.Nested(model_config_fields), + "message_count": fields.Integer, + "user_feedback_stats": fields.Nested(feedback_stat_fields), + "admin_feedback_stats": fields.Nested(feedback_stat_fields), } simple_conversation_fields = { - 'id': fields.String, - 'name': fields.String, - 'inputs': fields.Raw, - 'status': fields.String, - 'introduction': fields.String, - 'created_at': TimestampField + "id": fields.String, + "name": fields.String, + "inputs": fields.Raw, + "status": fields.String, + "introduction": fields.String, + "created_at": TimestampField, } conversation_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(simple_conversation_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(simple_conversation_fields)), } conversation_with_model_config_fields = { **simple_conversation_fields, - 'model_config': fields.Raw, + "model_config": fields.Raw, } conversation_with_model_config_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(conversation_with_model_config_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(conversation_with_model_config_fields)), } diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py new file mode 100644 index 00000000000000..983e50e73ceb9f --- /dev/null +++ b/api/fields/conversation_variable_fields.py @@ -0,0 +1,21 @@ +from flask_restful import fields + +from libs.helper import TimestampField + +conversation_variable_fields = { + "id": fields.String, + "name": fields.String, + "value_type": fields.String(attribute="value_type.value"), + "value": fields.String, + "description": fields.String, + "created_at": TimestampField, + "updated_at": TimestampField, +} + +paginated_conversation_variable_fields = { + "page": fields.Integer, + "limit": fields.Integer, + "total": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(conversation_variable_fields), attribute="data"), +} diff --git a/api/fields/data_source_fields.py b/api/fields/data_source_fields.py index 6f3c920c85b60f..071071376fe6c8 100644 --- a/api/fields/data_source_fields.py +++ b/api/fields/data_source_fields.py @@ -2,64 +2,56 @@ from libs.helper import TimestampField -integrate_icon_fields = { - 'type': fields.String, - 'url': fields.String, - 'emoji': fields.String -} +integrate_icon_fields = {"type": fields.String, "url": fields.String, "emoji": fields.String} integrate_page_fields = { - 'page_name': fields.String, - 'page_id': fields.String, - 'page_icon': fields.Nested(integrate_icon_fields, allow_null=True), - 'is_bound': fields.Boolean, - 'parent_id': fields.String, - 'type': fields.String + "page_name": fields.String, + "page_id": fields.String, + "page_icon": fields.Nested(integrate_icon_fields, allow_null=True), + "is_bound": fields.Boolean, + "parent_id": fields.String, + "type": fields.String, } integrate_workspace_fields = { - 'workspace_name': fields.String, - 'workspace_id': fields.String, - 'workspace_icon': fields.String, - 'pages': fields.List(fields.Nested(integrate_page_fields)) + "workspace_name": fields.String, + "workspace_id": fields.String, + "workspace_icon": fields.String, + "pages": fields.List(fields.Nested(integrate_page_fields)), } integrate_notion_info_list_fields = { - 'notion_info': fields.List(fields.Nested(integrate_workspace_fields)), + "notion_info": fields.List(fields.Nested(integrate_workspace_fields)), } -integrate_icon_fields = { - 'type': fields.String, - 'url': fields.String, - 'emoji': fields.String -} +integrate_icon_fields = {"type": fields.String, "url": fields.String, "emoji": fields.String} integrate_page_fields = { - 'page_name': fields.String, - 'page_id': fields.String, - 'page_icon': fields.Nested(integrate_icon_fields, allow_null=True), - 'parent_id': fields.String, - 'type': fields.String + "page_name": fields.String, + "page_id": fields.String, + "page_icon": fields.Nested(integrate_icon_fields, allow_null=True), + "parent_id": fields.String, + "type": fields.String, } integrate_workspace_fields = { - 'workspace_name': fields.String, - 'workspace_id': fields.String, - 'workspace_icon': fields.String, - 'pages': fields.List(fields.Nested(integrate_page_fields)), - 'total': fields.Integer + "workspace_name": fields.String, + "workspace_id": fields.String, + "workspace_icon": fields.String, + "pages": fields.List(fields.Nested(integrate_page_fields)), + "total": fields.Integer, } integrate_fields = { - 'id': fields.String, - 'provider': fields.String, - 'created_at': TimestampField, - 'is_bound': fields.Boolean, - 'disabled': fields.Boolean, - 'link': fields.String, - 'source_info': fields.Nested(integrate_workspace_fields) + "id": fields.String, + "provider": fields.String, + "created_at": TimestampField, + "is_bound": fields.Boolean, + "disabled": fields.Boolean, + "link": fields.String, + "source_info": fields.Nested(integrate_workspace_fields), } integrate_list_fields = { - 'data': fields.List(fields.Nested(integrate_fields)), -} \ No newline at end of file + "data": fields.List(fields.Nested(integrate_fields)), +} diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index a9f79b5c678e7c..9cf8da7acdc984 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -3,73 +3,64 @@ from libs.helper import TimestampField dataset_fields = { - 'id': fields.String, - 'name': fields.String, - 'description': fields.String, - 'permission': fields.String, - 'data_source_type': fields.String, - 'indexing_technique': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, + "id": fields.String, + "name": fields.String, + "description": fields.String, + "permission": fields.String, + "data_source_type": fields.String, + "indexing_technique": fields.String, + "created_by": fields.String, + "created_at": TimestampField, } -reranking_model_fields = { - 'reranking_provider_name': fields.String, - 'reranking_model_name': fields.String -} +reranking_model_fields = {"reranking_provider_name": fields.String, "reranking_model_name": fields.String} -keyword_setting_fields = { - 'keyword_weight': fields.Float -} +keyword_setting_fields = {"keyword_weight": fields.Float} vector_setting_fields = { - 'vector_weight': fields.Float, - 'embedding_model_name': fields.String, - 'embedding_provider_name': fields.String, + "vector_weight": fields.Float, + "embedding_model_name": fields.String, + "embedding_provider_name": fields.String, } weighted_score_fields = { - 'keyword_setting': fields.Nested(keyword_setting_fields), - 'vector_setting': fields.Nested(vector_setting_fields), + "keyword_setting": fields.Nested(keyword_setting_fields), + "vector_setting": fields.Nested(vector_setting_fields), } dataset_retrieval_model_fields = { - 'search_method': fields.String, - 'reranking_enable': fields.Boolean, - 'reranking_mode': fields.String, - 'reranking_model': fields.Nested(reranking_model_fields), - 'weights': fields.Nested(weighted_score_fields, allow_null=True), - 'top_k': fields.Integer, - 'score_threshold_enabled': fields.Boolean, - 'score_threshold': fields.Float + "search_method": fields.String, + "reranking_enable": fields.Boolean, + "reranking_mode": fields.String, + "reranking_model": fields.Nested(reranking_model_fields), + "weights": fields.Nested(weighted_score_fields, allow_null=True), + "top_k": fields.Integer, + "score_threshold_enabled": fields.Boolean, + "score_threshold": fields.Float, } -tag_fields = { - 'id': fields.String, - 'name': fields.String, - 'type': fields.String -} +tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String} dataset_detail_fields = { - 'id': fields.String, - 'name': fields.String, - 'description': fields.String, - 'provider': fields.String, - 'permission': fields.String, - 'data_source_type': fields.String, - 'indexing_technique': fields.String, - 'app_count': fields.Integer, - 'document_count': fields.Integer, - 'word_count': fields.Integer, - 'created_by': fields.String, - 'created_at': TimestampField, - 'updated_by': fields.String, - 'updated_at': TimestampField, - 'embedding_model': fields.String, - 'embedding_model_provider': fields.String, - 'embedding_available': fields.Boolean, - 'retrieval_model_dict': fields.Nested(dataset_retrieval_model_fields), - 'tags': fields.List(fields.Nested(tag_fields)) + "id": fields.String, + "name": fields.String, + "description": fields.String, + "provider": fields.String, + "permission": fields.String, + "data_source_type": fields.String, + "indexing_technique": fields.String, + "app_count": fields.Integer, + "document_count": fields.Integer, + "word_count": fields.Integer, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, + "embedding_model": fields.String, + "embedding_model_provider": fields.String, + "embedding_available": fields.Boolean, + "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields), + "tags": fields.List(fields.Nested(tag_fields)), } dataset_query_detail_fields = { @@ -79,7 +70,5 @@ "source_app_id": fields.String, "created_by_role": fields.String, "created_by": fields.String, - "created_at": TimestampField + "created_at": TimestampField, } - - diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py index e8215255b35d5b..a83ec7bc97adee 100644 --- a/api/fields/document_fields.py +++ b/api/fields/document_fields.py @@ -4,75 +4,73 @@ from libs.helper import TimestampField document_fields = { - 'id': fields.String, - 'position': fields.Integer, - 'data_source_type': fields.String, - 'data_source_info': fields.Raw(attribute='data_source_info_dict'), - 'data_source_detail_dict': fields.Raw(attribute='data_source_detail_dict'), - 'dataset_process_rule_id': fields.String, - 'name': fields.String, - 'created_from': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, - 'tokens': fields.Integer, - 'indexing_status': fields.String, - 'error': fields.String, - 'enabled': fields.Boolean, - 'disabled_at': TimestampField, - 'disabled_by': fields.String, - 'archived': fields.Boolean, - 'display_status': fields.String, - 'word_count': fields.Integer, - 'hit_count': fields.Integer, - 'doc_form': fields.String, + "id": fields.String, + "position": fields.Integer, + "data_source_type": fields.String, + "data_source_info": fields.Raw(attribute="data_source_info_dict"), + "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"), + "dataset_process_rule_id": fields.String, + "name": fields.String, + "created_from": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "tokens": fields.Integer, + "indexing_status": fields.String, + "error": fields.String, + "enabled": fields.Boolean, + "disabled_at": TimestampField, + "disabled_by": fields.String, + "archived": fields.Boolean, + "display_status": fields.String, + "word_count": fields.Integer, + "hit_count": fields.Integer, + "doc_form": fields.String, } document_with_segments_fields = { - 'id': fields.String, - 'position': fields.Integer, - 'data_source_type': fields.String, - 'data_source_info': fields.Raw(attribute='data_source_info_dict'), - 'data_source_detail_dict': fields.Raw(attribute='data_source_detail_dict'), - 'dataset_process_rule_id': fields.String, - 'name': fields.String, - 'created_from': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, - 'tokens': fields.Integer, - 'indexing_status': fields.String, - 'error': fields.String, - 'enabled': fields.Boolean, - 'disabled_at': TimestampField, - 'disabled_by': fields.String, - 'archived': fields.Boolean, - 'display_status': fields.String, - 'word_count': fields.Integer, - 'hit_count': fields.Integer, - 'completed_segments': fields.Integer, - 'total_segments': fields.Integer + "id": fields.String, + "position": fields.Integer, + "data_source_type": fields.String, + "data_source_info": fields.Raw(attribute="data_source_info_dict"), + "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"), + "dataset_process_rule_id": fields.String, + "name": fields.String, + "created_from": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "tokens": fields.Integer, + "indexing_status": fields.String, + "error": fields.String, + "enabled": fields.Boolean, + "disabled_at": TimestampField, + "disabled_by": fields.String, + "archived": fields.Boolean, + "display_status": fields.String, + "word_count": fields.Integer, + "hit_count": fields.Integer, + "completed_segments": fields.Integer, + "total_segments": fields.Integer, } dataset_and_document_fields = { - 'dataset': fields.Nested(dataset_fields), - 'documents': fields.List(fields.Nested(document_fields)), - 'batch': fields.String + "dataset": fields.Nested(dataset_fields), + "documents": fields.List(fields.Nested(document_fields)), + "batch": fields.String, } document_status_fields = { - 'id': fields.String, - 'indexing_status': fields.String, - 'processing_started_at': TimestampField, - 'parsing_completed_at': TimestampField, - 'cleaning_completed_at': TimestampField, - 'splitting_completed_at': TimestampField, - 'completed_at': TimestampField, - 'paused_at': TimestampField, - 'error': fields.String, - 'stopped_at': TimestampField, - 'completed_segments': fields.Integer, - 'total_segments': fields.Integer, + "id": fields.String, + "indexing_status": fields.String, + "processing_started_at": TimestampField, + "parsing_completed_at": TimestampField, + "cleaning_completed_at": TimestampField, + "splitting_completed_at": TimestampField, + "completed_at": TimestampField, + "paused_at": TimestampField, + "error": fields.String, + "stopped_at": TimestampField, + "completed_segments": fields.Integer, + "total_segments": fields.Integer, } -document_status_fields_list = { - 'data': fields.List(fields.Nested(document_status_fields)) -} \ No newline at end of file +document_status_fields_list = {"data": fields.List(fields.Nested(document_status_fields))} diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py index ee630c12c2e9aa..99e529f9d1c076 100644 --- a/api/fields/end_user_fields.py +++ b/api/fields/end_user_fields.py @@ -1,8 +1,8 @@ from flask_restful import fields simple_end_user_fields = { - 'id': fields.String, - 'type': fields.String, - 'is_anonymous': fields.Boolean, - 'session_id': fields.String, + "id": fields.String, + "type": fields.String, + "is_anonymous": fields.Boolean, + "session_id": fields.String, } diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index 2ef379dabc0d08..e5a03ce77ed5f0 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -3,17 +3,17 @@ from libs.helper import TimestampField upload_config_fields = { - 'file_size_limit': fields.Integer, - 'batch_count_limit': fields.Integer, - 'image_file_size_limit': fields.Integer, + "file_size_limit": fields.Integer, + "batch_count_limit": fields.Integer, + "image_file_size_limit": fields.Integer, } file_fields = { - 'id': fields.String, - 'name': fields.String, - 'size': fields.Integer, - 'extension': fields.String, - 'mime_type': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, -} \ No newline at end of file + "id": fields.String, + "name": fields.String, + "size": fields.Integer, + "extension": fields.String, + "mime_type": fields.String, + "created_by": fields.String, + "created_at": TimestampField, +} diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index 541e56a378dae4..f36e80f8d493d5 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -3,39 +3,39 @@ from libs.helper import TimestampField document_fields = { - 'id': fields.String, - 'data_source_type': fields.String, - 'name': fields.String, - 'doc_type': fields.String, + "id": fields.String, + "data_source_type": fields.String, + "name": fields.String, + "doc_type": fields.String, } segment_fields = { - 'id': fields.String, - 'position': fields.Integer, - 'document_id': fields.String, - 'content': fields.String, - 'answer': fields.String, - 'word_count': fields.Integer, - 'tokens': fields.Integer, - 'keywords': fields.List(fields.String), - 'index_node_id': fields.String, - 'index_node_hash': fields.String, - 'hit_count': fields.Integer, - 'enabled': fields.Boolean, - 'disabled_at': TimestampField, - 'disabled_by': fields.String, - 'status': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, - 'indexing_at': TimestampField, - 'completed_at': TimestampField, - 'error': fields.String, - 'stopped_at': TimestampField, - 'document': fields.Nested(document_fields), + "id": fields.String, + "position": fields.Integer, + "document_id": fields.String, + "content": fields.String, + "answer": fields.String, + "word_count": fields.Integer, + "tokens": fields.Integer, + "keywords": fields.List(fields.String), + "index_node_id": fields.String, + "index_node_hash": fields.String, + "hit_count": fields.Integer, + "enabled": fields.Boolean, + "disabled_at": TimestampField, + "disabled_by": fields.String, + "status": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "indexing_at": TimestampField, + "completed_at": TimestampField, + "error": fields.String, + "stopped_at": TimestampField, + "document": fields.Nested(document_fields), } hit_testing_record_fields = { - 'segment': fields.Nested(segment_fields), - 'score': fields.Float, - 'tsne_position': fields.Raw -} \ No newline at end of file + "segment": fields.Nested(segment_fields), + "score": fields.Float, + "tsne_position": fields.Raw, +} diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py index 35cc5a64755eca..9afc1b1a4ad79d 100644 --- a/api/fields/installed_app_fields.py +++ b/api/fields/installed_app_fields.py @@ -1,25 +1,25 @@ from flask_restful import fields -from libs.helper import TimestampField +from libs.helper import AppIconUrlField, TimestampField app_fields = { - 'id': fields.String, - 'name': fields.String, - 'mode': fields.String, - 'icon': fields.String, - 'icon_background': fields.String + "id": fields.String, + "name": fields.String, + "mode": fields.String, + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, } installed_app_fields = { - 'id': fields.String, - 'app': fields.Nested(app_fields), - 'app_owner_tenant_id': fields.String, - 'is_pinned': fields.Boolean, - 'last_used_at': TimestampField, - 'editable': fields.Boolean, - 'uninstallable': fields.Boolean + "id": fields.String, + "app": fields.Nested(app_fields), + "app_owner_tenant_id": fields.String, + "is_pinned": fields.Boolean, + "last_used_at": TimestampField, + "editable": fields.Boolean, + "uninstallable": fields.Boolean, } -installed_app_list_fields = { - 'installed_apps': fields.List(fields.Nested(installed_app_fields)) -} \ No newline at end of file +installed_app_list_fields = {"installed_apps": fields.List(fields.Nested(installed_app_fields))} diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index d061b59c347022..1cf8e408d13d32 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -2,38 +2,32 @@ from libs.helper import TimestampField -simple_account_fields = { - 'id': fields.String, - 'name': fields.String, - 'email': fields.String -} +simple_account_fields = {"id": fields.String, "name": fields.String, "email": fields.String} account_fields = { - 'id': fields.String, - 'name': fields.String, - 'avatar': fields.String, - 'email': fields.String, - 'is_password_set': fields.Boolean, - 'interface_language': fields.String, - 'interface_theme': fields.String, - 'timezone': fields.String, - 'last_login_at': TimestampField, - 'last_login_ip': fields.String, - 'created_at': TimestampField + "id": fields.String, + "name": fields.String, + "avatar": fields.String, + "email": fields.String, + "is_password_set": fields.Boolean, + "interface_language": fields.String, + "interface_theme": fields.String, + "timezone": fields.String, + "last_login_at": TimestampField, + "last_login_ip": fields.String, + "created_at": TimestampField, } account_with_role_fields = { - 'id': fields.String, - 'name': fields.String, - 'avatar': fields.String, - 'email': fields.String, - 'last_login_at': TimestampField, - 'last_active_at': TimestampField, - 'created_at': TimestampField, - 'role': fields.String, - 'status': fields.String, + "id": fields.String, + "name": fields.String, + "avatar": fields.String, + "email": fields.String, + "last_login_at": TimestampField, + "last_active_at": TimestampField, + "created_at": TimestampField, + "role": fields.String, + "status": fields.String, } -account_with_role_list_fields = { - 'accounts': fields.List(fields.Nested(account_with_role_fields)) -} +account_with_role_list_fields = {"accounts": fields.List(fields.Nested(account_with_role_fields))} diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 31168435892427..3d2df87afb9b19 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -3,83 +3,79 @@ from fields.conversation_fields import message_file_fields from libs.helper import TimestampField -feedback_fields = { - 'rating': fields.String -} +feedback_fields = {"rating": fields.String} retriever_resource_fields = { - 'id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'dataset_id': fields.String, - 'dataset_name': fields.String, - 'document_id': fields.String, - 'document_name': fields.String, - 'data_source_type': fields.String, - 'segment_id': fields.String, - 'score': fields.Float, - 'hit_count': fields.Integer, - 'word_count': fields.Integer, - 'segment_position': fields.Integer, - 'index_node_hash': fields.String, - 'content': fields.String, - 'created_at': TimestampField + "id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "dataset_id": fields.String, + "dataset_name": fields.String, + "document_id": fields.String, + "document_name": fields.String, + "data_source_type": fields.String, + "segment_id": fields.String, + "score": fields.Float, + "hit_count": fields.Integer, + "word_count": fields.Integer, + "segment_position": fields.Integer, + "index_node_hash": fields.String, + "content": fields.String, + "created_at": TimestampField, } -feedback_fields = { - 'rating': fields.String -} +feedback_fields = {"rating": fields.String} agent_thought_fields = { - 'id': fields.String, - 'chain_id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'thought': fields.String, - 'tool': fields.String, - 'tool_labels': fields.Raw, - 'tool_input': fields.String, - 'created_at': TimestampField, - 'observation': fields.String, - 'files': fields.List(fields.String) + "id": fields.String, + "chain_id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "thought": fields.String, + "tool": fields.String, + "tool_labels": fields.Raw, + "tool_input": fields.String, + "created_at": TimestampField, + "observation": fields.String, + "files": fields.List(fields.String), } retriever_resource_fields = { - 'id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'dataset_id': fields.String, - 'dataset_name': fields.String, - 'document_id': fields.String, - 'document_name': fields.String, - 'data_source_type': fields.String, - 'segment_id': fields.String, - 'score': fields.Float, - 'hit_count': fields.Integer, - 'word_count': fields.Integer, - 'segment_position': fields.Integer, - 'index_node_hash': fields.String, - 'content': fields.String, - 'created_at': TimestampField + "id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "dataset_id": fields.String, + "dataset_name": fields.String, + "document_id": fields.String, + "document_name": fields.String, + "data_source_type": fields.String, + "segment_id": fields.String, + "score": fields.Float, + "hit_count": fields.Integer, + "word_count": fields.Integer, + "segment_position": fields.Integer, + "index_node_hash": fields.String, + "content": fields.String, + "created_at": TimestampField, } message_fields = { - 'id': fields.String, - 'conversation_id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String(attribute='re_sign_file_url_answer'), - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), - 'created_at': TimestampField, - 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)), - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'status': fields.String, - 'error': fields.String, + "id": fields.String, + "conversation_id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "answer": fields.String(attribute="re_sign_file_url_answer"), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), + "created_at": TimestampField, + "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "status": fields.String, + "error": fields.String, } message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_fields)), } diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py index e41d1a53dd0234..2dd4cb45be409b 100644 --- a/api/fields/segment_fields.py +++ b/api/fields/segment_fields.py @@ -3,31 +3,31 @@ from libs.helper import TimestampField segment_fields = { - 'id': fields.String, - 'position': fields.Integer, - 'document_id': fields.String, - 'content': fields.String, - 'answer': fields.String, - 'word_count': fields.Integer, - 'tokens': fields.Integer, - 'keywords': fields.List(fields.String), - 'index_node_id': fields.String, - 'index_node_hash': fields.String, - 'hit_count': fields.Integer, - 'enabled': fields.Boolean, - 'disabled_at': TimestampField, - 'disabled_by': fields.String, - 'status': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, - 'indexing_at': TimestampField, - 'completed_at': TimestampField, - 'error': fields.String, - 'stopped_at': TimestampField + "id": fields.String, + "position": fields.Integer, + "document_id": fields.String, + "content": fields.String, + "answer": fields.String, + "word_count": fields.Integer, + "tokens": fields.Integer, + "keywords": fields.List(fields.String), + "index_node_id": fields.String, + "index_node_hash": fields.String, + "hit_count": fields.Integer, + "enabled": fields.Boolean, + "disabled_at": TimestampField, + "disabled_by": fields.String, + "status": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "indexing_at": TimestampField, + "completed_at": TimestampField, + "error": fields.String, + "stopped_at": TimestampField, } segment_list_response = { - 'data': fields.List(fields.Nested(segment_fields)), - 'has_more': fields.Boolean, - 'limit': fields.Integer + "data": fields.List(fields.Nested(segment_fields)), + "has_more": fields.Boolean, + "limit": fields.Integer, } diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py index f7e030b738e537..9af4fc57dd061c 100644 --- a/api/fields/tag_fields.py +++ b/api/fields/tag_fields.py @@ -1,8 +1,3 @@ from flask_restful import fields -tag_fields = { - 'id': fields.String, - 'name': fields.String, - 'type': fields.String, - 'binding_count': fields.String -} \ No newline at end of file +tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String, "binding_count": fields.String} diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index e230c159fba59a..a53b54624915c2 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -7,18 +7,18 @@ workflow_app_log_partial_fields = { "id": fields.String, - "workflow_run": fields.Nested(workflow_run_for_log_fields, attribute='workflow_run', allow_null=True), + "workflow_run": fields.Nested(workflow_run_for_log_fields, attribute="workflow_run", allow_null=True), "created_from": fields.String, "created_by_role": fields.String, - "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), - "created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True), - "created_at": TimestampField + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), + "created_at": TimestampField, } workflow_app_log_pagination_fields = { - 'page': fields.Integer, - 'limit': fields.Integer(attribute='per_page'), - 'total': fields.Integer, - 'has_more': fields.Boolean(attribute='has_next'), - 'data': fields.List(fields.Nested(workflow_app_log_partial_fields), attribute='items') + "page": fields.Integer, + "limit": fields.Integer(attribute="per_page"), + "total": fields.Integer, + "has_more": fields.Boolean(attribute="has_next"), + "data": fields.List(fields.Nested(workflow_app_log_partial_fields), attribute="items"), } diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index ff33a97ff2a9ab..2adef63ada6fb9 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -13,41 +13,51 @@ def format(self, value): # Mask secret variables values in environment_variables if isinstance(value, SecretVariable): return { - 'id': value.id, - 'name': value.name, - 'value': encrypter.obfuscated_token(value.value), - 'value_type': value.value_type.value, + "id": value.id, + "name": value.name, + "value": encrypter.obfuscated_token(value.value), + "value_type": value.value_type.value, } if isinstance(value, Variable): return { - 'id': value.id, - 'name': value.name, - 'value': value.value, - 'value_type': value.value_type.value, + "id": value.id, + "name": value.name, + "value": value.value, + "value_type": value.value_type.value, } if isinstance(value, dict): - value_type = value.get('value_type') + value_type = value.get("value_type") if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES: - raise ValueError(f'Unsupported environment variable value type: {value_type}') + raise ValueError(f"Unsupported environment variable value type: {value_type}") return value -environment_variable_fields = { - 'id': fields.String, - 'name': fields.String, - 'value': fields.Raw, - 'value_type': fields.String(attribute='value_type.value'), +conversation_variable_fields = { + "id": fields.String, + "name": fields.String, + "value_type": fields.String(attribute="value_type.value"), + "value": fields.Raw, + "description": fields.String, } workflow_fields = { - 'id': fields.String, - 'graph': fields.Raw(attribute='graph_dict'), - 'features': fields.Raw(attribute='features_dict'), - 'hash': fields.String(attribute='unique_hash'), - 'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'), - 'created_at': TimestampField, - 'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True), - 'updated_at': TimestampField, - 'tool_published': fields.Boolean, - 'environment_variables': fields.List(EnvironmentVariableField()), + "id": fields.String, + "graph": fields.Raw(attribute="graph_dict"), + "features": fields.Raw(attribute="features_dict"), + "hash": fields.String(attribute="unique_hash"), + "created_by": fields.Nested(simple_account_fields, attribute="created_by_account"), + "created_at": TimestampField, + "updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True), + "updated_at": TimestampField, + "tool_published": fields.Boolean, + "environment_variables": fields.List(EnvironmentVariableField()), + "conversation_variables": fields.List(fields.Nested(conversation_variable_fields)), +} + +workflow_partial_fields = { + "id": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, } diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 3e798473cd0481..1413adf7196879 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -13,7 +13,7 @@ "total_tokens": fields.Integer, "total_steps": fields.Integer, "created_at": TimestampField, - "finished_at": TimestampField + "finished_at": TimestampField, } workflow_run_for_list_fields = { @@ -24,9 +24,9 @@ "elapsed_time": fields.Float, "total_tokens": fields.Integer, "total_steps": fields.Integer, - "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), "created_at": TimestampField, - "finished_at": TimestampField + "finished_at": TimestampField, } advanced_chat_workflow_run_for_list_fields = { @@ -39,40 +39,40 @@ "elapsed_time": fields.Float, "total_tokens": fields.Integer, "total_steps": fields.Integer, - "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), "created_at": TimestampField, - "finished_at": TimestampField + "finished_at": TimestampField, } advanced_chat_workflow_run_pagination_fields = { - 'limit': fields.Integer(attribute='limit'), - 'has_more': fields.Boolean(attribute='has_more'), - 'data': fields.List(fields.Nested(advanced_chat_workflow_run_for_list_fields), attribute='data') + "limit": fields.Integer(attribute="limit"), + "has_more": fields.Boolean(attribute="has_more"), + "data": fields.List(fields.Nested(advanced_chat_workflow_run_for_list_fields), attribute="data"), } workflow_run_pagination_fields = { - 'limit': fields.Integer(attribute='limit'), - 'has_more': fields.Boolean(attribute='has_more'), - 'data': fields.List(fields.Nested(workflow_run_for_list_fields), attribute='data') + "limit": fields.Integer(attribute="limit"), + "has_more": fields.Boolean(attribute="has_more"), + "data": fields.List(fields.Nested(workflow_run_for_list_fields), attribute="data"), } workflow_run_detail_fields = { "id": fields.String, "sequence_number": fields.Integer, "version": fields.String, - "graph": fields.Raw(attribute='graph_dict'), - "inputs": fields.Raw(attribute='inputs_dict'), + "graph": fields.Raw(attribute="graph_dict"), + "inputs": fields.Raw(attribute="inputs_dict"), "status": fields.String, - "outputs": fields.Raw(attribute='outputs_dict'), + "outputs": fields.Raw(attribute="outputs_dict"), "error": fields.String, "elapsed_time": fields.Float, "total_tokens": fields.Integer, "total_steps": fields.Integer, "created_by_role": fields.String, - "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), - "created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True), + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), "created_at": TimestampField, - "finished_at": TimestampField + "finished_at": TimestampField, } workflow_run_node_execution_fields = { @@ -82,21 +82,21 @@ "node_id": fields.String, "node_type": fields.String, "title": fields.String, - "inputs": fields.Raw(attribute='inputs_dict'), - "process_data": fields.Raw(attribute='process_data_dict'), - "outputs": fields.Raw(attribute='outputs_dict'), + "inputs": fields.Raw(attribute="inputs_dict"), + "process_data": fields.Raw(attribute="process_data_dict"), + "outputs": fields.Raw(attribute="outputs_dict"), "status": fields.String, "error": fields.String, "elapsed_time": fields.Float, - "execution_metadata": fields.Raw(attribute='execution_metadata_dict'), + "execution_metadata": fields.Raw(attribute="execution_metadata_dict"), "extras": fields.Raw, "created_at": TimestampField, "created_by_role": fields.String, - "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), - "created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True), - "finished_at": TimestampField + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), + "finished_at": TimestampField, } workflow_run_node_execution_list_fields = { - 'data': fields.List(fields.Nested(workflow_run_node_execution_fields)), + "data": fields.List(fields.Nested(workflow_run_node_execution_fields)), } diff --git a/api/libs/bearer_data_source.py b/api/libs/bearer_data_source.py deleted file mode 100644 index 04de1fb6daefbd..00000000000000 --- a/api/libs/bearer_data_source.py +++ /dev/null @@ -1,64 +0,0 @@ -# [REVIEW] Implement if Needed? Do we need a new type of data source -from abc import abstractmethod - -import requests -from api.models.source import DataSourceBearerBinding -from flask_login import current_user - -from extensions.ext_database import db - - -class BearerDataSource: - def __init__(self, api_key: str, api_base_url: str): - self.api_key = api_key - self.api_base_url = api_base_url - - @abstractmethod - def validate_bearer_data_source(self): - """ - Validate the data source - """ - - -class FireCrawlDataSource(BearerDataSource): - def validate_bearer_data_source(self): - TEST_CRAWL_SITE_URL = "https://www.google.com" - FIRECRAWL_API_VERSION = "v0" - - test_api_endpoint = self.api_base_url.rstrip('/') + f"/{FIRECRAWL_API_VERSION}/scrape" - - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - } - - data = { - "url": TEST_CRAWL_SITE_URL, - } - - response = requests.get(test_api_endpoint, headers=headers, json=data) - - return response.json().get("status") == "success" - - def save_credentials(self): - # save data source binding - data_source_binding = DataSourceBearerBinding.query.filter( - db.and_( - DataSourceBearerBinding.tenant_id == current_user.current_tenant_id, - DataSourceBearerBinding.provider == 'firecrawl', - DataSourceBearerBinding.endpoint_url == self.api_base_url, - DataSourceBearerBinding.bearer_key == self.api_key - ) - ).first() - if data_source_binding: - data_source_binding.disabled = False - db.session.commit() - else: - new_data_source_binding = DataSourceBearerBinding( - tenant_id=current_user.current_tenant_id, - provider='firecrawl', - endpoint_url=self.api_base_url, - bearer_key=self.api_key - ) - db.session.add(new_data_source_binding) - db.session.commit() diff --git a/api/libs/exception.py b/api/libs/exception.py index 567062f064fa97..5970269ecdbed5 100644 --- a/api/libs/exception.py +++ b/api/libs/exception.py @@ -4,7 +4,7 @@ class BaseHTTPException(HTTPException): - error_code: str = 'unknown' + error_code: str = "unknown" data: Optional[dict] = None def __init__(self, description=None, response=None): @@ -14,4 +14,4 @@ def __init__(self, description=None, response=None): "code": self.error_code, "message": self.description, "status": self.code, - } \ No newline at end of file + } diff --git a/api/libs/external_api.py b/api/libs/external_api.py index 677ff0fc5b6aab..179617ac0a6588 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -10,7 +10,6 @@ class ExternalApi(Api): - def handle_error(self, e): """Error handler for the API transforms a raised exception into a Flask response, with the appropriate HTTP status code and body. @@ -29,54 +28,57 @@ def handle_error(self, e): status_code = e.code default_data = { - 'code': re.sub(r'(? self.max_length: - error = ('Invalid {arg}: {val}. {arg} cannot exceed length {length}' - .format(arg=self.argument, val=value, length=self.max_length)) + error = "Invalid {arg}: {val}. {arg} cannot exceed length {length}".format( + arg=self.argument, val=value, length=self.max_length + ) raise ValueError(error) return value class float_range: - """ Restrict input to an float in a range (inclusive) """ - def __init__(self, low, high, argument='argument'): + """Restrict input to an float in a range (inclusive)""" + + def __init__(self, low, high, argument="argument"): self.low = low self.high = high self.argument = argument @@ -99,15 +113,16 @@ def __init__(self, low, high, argument='argument'): def __call__(self, value): value = _get_float(value) if value < self.low or value > self.high: - error = ('Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}' - .format(arg=self.argument, val=value, lo=self.low, hi=self.high)) + error = "Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}".format( + arg=self.argument, val=value, lo=self.low, hi=self.high + ) raise ValueError(error) return value class datetime_string: - def __init__(self, format, argument='argument'): + def __init__(self, format, argument="argument"): self.format = format self.argument = argument @@ -115,8 +130,9 @@ def __call__(self, value): try: datetime.strptime(value, self.format) except ValueError: - error = ('Invalid {arg}: {val}. {arg} must be conform to the format {format}' - .format(arg=self.argument, val=value, format=self.format)) + error = "Invalid {arg}: {val}. {arg} must be conform to the format {format}".format( + arg=self.argument, val=value, format=self.format + ) raise ValueError(error) return value @@ -126,14 +142,14 @@ def _get_float(value): try: return float(value) except (TypeError, ValueError): - raise ValueError('{} is not a valid float'.format(value)) + raise ValueError("{} is not a valid float".format(value)) + def timezone(timezone_string): if timezone_string and timezone_string in available_timezones(): return timezone_string - error = ('{timezone_string} is not a valid timezone.' - .format(timezone_string=timezone_string)) + error = "{timezone_string} is not a valid timezone.".format(timezone_string=timezone_string) raise ValueError(error) @@ -147,8 +163,8 @@ def generate_string(n): def get_remote_ip(request) -> str: - if request.headers.get('CF-Connecting-IP'): - return request.headers.get('Cf-Connecting-Ip') + if request.headers.get("CF-Connecting-IP"): + return request.headers.get("Cf-Connecting-Ip") elif request.headers.getlist("X-Forwarded-For"): return request.headers.getlist("X-Forwarded-For")[0] else: @@ -156,54 +172,45 @@ def get_remote_ip(request) -> str: def generate_text_hash(text: str) -> str: - hash_text = str(text) + 'None' + hash_text = str(text) + "None" return sha256(hash_text.encode()).hexdigest() def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Response: if isinstance(response, dict): - return Response(response=json.dumps(response), status=200, mimetype='application/json') + return Response(response=json.dumps(response), status=200, mimetype="application/json") else: + def generate() -> Generator: yield from response - return Response(stream_with_context(generate()), status=200, - mimetype='text/event-stream') + return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream") class TokenManager: - @classmethod def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str: old_token = cls._get_current_token_for_account(account.id, token_type) if old_token: if isinstance(old_token, bytes): - old_token = old_token.decode('utf-8') + old_token = old_token.decode("utf-8") cls.revoke_token(old_token, token_type) token = str(uuid.uuid4()) - token_data = { - 'account_id': account.id, - 'email': account.email, - 'token_type': token_type - } + token_data = {"account_id": account.id, "email": account.email, "token_type": token_type} if additional_data: token_data.update(additional_data) - expiry_hours = current_app.config[f'{token_type.upper()}_TOKEN_EXPIRY_HOURS'] + expiry_hours = current_app.config[f"{token_type.upper()}_TOKEN_EXPIRY_HOURS"] token_key = cls._get_token_key(token, token_type) - redis_client.setex( - token_key, - expiry_hours * 60 * 60, - json.dumps(token_data) - ) + redis_client.setex(token_key, expiry_hours * 60 * 60, json.dumps(token_data)) cls._set_current_token_for_account(account.id, token, token_type, expiry_hours) return token @classmethod def _get_token_key(cls, token: str, token_type: str) -> str: - return f'{token_type}:token:{token}' + return f"{token_type}:token:{token}" @classmethod def revoke_token(cls, token: str, token_type: str): @@ -233,7 +240,7 @@ def _set_current_token_for_account(cls, account_id: str, token: str, token_type: @classmethod def _get_account_token_key(cls, account_id: str, token_type: str) -> str: - return f'{token_type}:account:{account_id}' + return f"{token_type}:account:{account_id}" class RateLimiter: @@ -250,7 +257,7 @@ def is_rate_limited(self, email: str) -> bool: current_time = int(time.time()) window_start_time = current_time - self.time_window - redis_client.zremrangebyscore(key, '-inf', window_start_time) + redis_client.zremrangebyscore(key, "-inf", window_start_time) attempts = redis_client.zcard(key) if attempts and int(attempts) >= self.max_attempts: diff --git a/api/libs/infinite_scroll_pagination.py b/api/libs/infinite_scroll_pagination.py index a1cb7b78fc008c..133ccb188338e2 100644 --- a/api/libs/infinite_scroll_pagination.py +++ b/api/libs/infinite_scroll_pagination.py @@ -1,4 +1,3 @@ - class InfiniteScrollPagination: def __init__(self, data, limit, has_more): self.data = data diff --git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py index 2cf023a39922be..41d69058992f4b 100644 --- a/api/libs/json_in_md_parser.py +++ b/api/libs/json_in_md_parser.py @@ -10,13 +10,13 @@ def parse_json_markdown(json_string: str) -> dict: end_index = json_string.find("```", start_index + len("```json")) if start_index != -1 and end_index != -1: - extracted_content = json_string[start_index + len("```json"):end_index].strip() + extracted_content = json_string[start_index + len("```json") : end_index].strip() # Parse the JSON string into a Python dictionary parsed = json.loads(extracted_content) elif start_index != -1 and end_index == -1 and json_string.endswith("``"): end_index = json_string.find("``", start_index + len("```json")) - extracted_content = json_string[start_index + len("```json"):end_index].strip() + extracted_content = json_string[start_index + len("```json") : end_index].strip() # Parse the JSON string into a Python dictionary parsed = json.loads(extracted_content) @@ -37,7 +37,6 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict: for key in expected_keys: if key not in json_obj: raise OutputParserException( - f"Got invalid return object. Expected key `{key}` " - f"to be present, but got {json_obj}" + f"Got invalid return object. Expected key `{key}` " f"to be present, but got {json_obj}" ) return json_obj diff --git a/api/libs/login.py b/api/libs/login.py index 14085fe6035305..7f05eb8404a0ba 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -51,27 +51,29 @@ def post(): @wraps(func) def decorated_view(*args, **kwargs): - auth_header = request.headers.get('Authorization') - admin_api_key_enable = os.getenv('ADMIN_API_KEY_ENABLE', default='False') - if admin_api_key_enable.lower() == 'true': + auth_header = request.headers.get("Authorization") + admin_api_key_enable = os.getenv("ADMIN_API_KEY_ENABLE", default="False") + if admin_api_key_enable.lower() == "true": if auth_header: - if ' ' not in auth_header: - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != 'bearer': - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') - admin_api_key = os.getenv('ADMIN_API_KEY') + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") + admin_api_key = os.getenv("ADMIN_API_KEY") if admin_api_key: - if os.getenv('ADMIN_API_KEY') == auth_token: - workspace_id = request.headers.get('X-WORKSPACE-ID') + if os.getenv("ADMIN_API_KEY") == auth_token: + workspace_id = request.headers.get("X-WORKSPACE-ID") if workspace_id: - tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \ - .filter(Tenant.id == workspace_id) \ - .filter(TenantAccountJoin.tenant_id == Tenant.id) \ - .filter(TenantAccountJoin.role == 'owner') \ + tenant_account_join = ( + db.session.query(Tenant, TenantAccountJoin) + .filter(Tenant.id == workspace_id) + .filter(TenantAccountJoin.tenant_id == Tenant.id) + .filter(TenantAccountJoin.role == "owner") .one_or_none() + ) if tenant_account_join: tenant, ta = tenant_account_join account = Account.query.filter_by(id=ta.account_id).first() diff --git a/api/libs/oauth.py b/api/libs/oauth.py index dacdee0bc1666f..d8ce1a1e6633e6 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -35,31 +35,31 @@ def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: class GitHubOAuth(OAuth): - _AUTH_URL = 'https://github.com/login/oauth/authorize' - _TOKEN_URL = 'https://github.com/login/oauth/access_token' - _USER_INFO_URL = 'https://api.github.com/user' - _EMAIL_INFO_URL = 'https://api.github.com/user/emails' + _AUTH_URL = "https://github.com/login/oauth/authorize" + _TOKEN_URL = "https://github.com/login/oauth/access_token" + _USER_INFO_URL = "https://api.github.com/user" + _EMAIL_INFO_URL = "https://api.github.com/user/emails" def get_authorization_url(self): params = { - 'client_id': self.client_id, - 'redirect_uri': self.redirect_uri, - 'scope': 'user:email' # Request only basic user information + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "scope": "user:email", # Request only basic user information } return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" def get_access_token(self, code: str): data = { - 'client_id': self.client_id, - 'client_secret': self.client_secret, - 'code': code, - 'redirect_uri': self.redirect_uri + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "redirect_uri": self.redirect_uri, } - headers = {'Accept': 'application/json'} + headers = {"Accept": "application/json"} response = requests.post(self._TOKEN_URL, data=data, headers=headers) response_json = response.json() - access_token = response_json.get('access_token') + access_token = response_json.get("access_token") if not access_token: raise ValueError(f"Error in GitHub OAuth: {response_json}") @@ -67,55 +67,51 @@ def get_access_token(self, code: str): return access_token def get_raw_user_info(self, token: str): - headers = {'Authorization': f"token {token}"} + headers = {"Authorization": f"token {token}"} response = requests.get(self._USER_INFO_URL, headers=headers) response.raise_for_status() user_info = response.json() email_response = requests.get(self._EMAIL_INFO_URL, headers=headers) email_info = email_response.json() - primary_email = next((email for email in email_info if email['primary'] == True), None) + primary_email = next((email for email in email_info if email["primary"] == True), None) - return {**user_info, 'email': primary_email['email']} + return {**user_info, "email": primary_email["email"]} def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: - email = raw_info.get('email') + email = raw_info.get("email") if not email: email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com" - return OAuthUserInfo( - id=str(raw_info['id']), - name=raw_info['name'], - email=email - ) + return OAuthUserInfo(id=str(raw_info["id"]), name=raw_info["name"], email=email) class GoogleOAuth(OAuth): - _AUTH_URL = 'https://accounts.google.com/o/oauth2/v2/auth' - _TOKEN_URL = 'https://oauth2.googleapis.com/token' - _USER_INFO_URL = 'https://www.googleapis.com/oauth2/v3/userinfo' + _AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth" + _TOKEN_URL = "https://oauth2.googleapis.com/token" + _USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo" def get_authorization_url(self): params = { - 'client_id': self.client_id, - 'response_type': 'code', - 'redirect_uri': self.redirect_uri, - 'scope': 'openid email' + "client_id": self.client_id, + "response_type": "code", + "redirect_uri": self.redirect_uri, + "scope": "openid email", } return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" def get_access_token(self, code: str): data = { - 'client_id': self.client_id, - 'client_secret': self.client_secret, - 'code': code, - 'grant_type': 'authorization_code', - 'redirect_uri': self.redirect_uri + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": self.redirect_uri, } - headers = {'Accept': 'application/json'} + headers = {"Accept": "application/json"} response = requests.post(self._TOKEN_URL, data=data, headers=headers) response_json = response.json() - access_token = response_json.get('access_token') + access_token = response_json.get("access_token") if not access_token: raise ValueError(f"Error in Google OAuth: {response_json}") @@ -123,16 +119,10 @@ def get_access_token(self, code: str): return access_token def get_raw_user_info(self, token: str): - headers = {'Authorization': f"Bearer {token}"} + headers = {"Authorization": f"Bearer {token}"} response = requests.get(self._USER_INFO_URL, headers=headers) response.raise_for_status() return response.json() def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: - return OAuthUserInfo( - id=str(raw_info['sub']), - name=None, - email=raw_info['email'] - ) - - + return OAuthUserInfo(id=str(raw_info["sub"]), name=None, email=raw_info["email"]) diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index a5c7814a543bdc..6da1a6d39bfd57 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -21,53 +21,49 @@ def get_access_token(self, code: str): class NotionOAuth(OAuthDataSource): - _AUTH_URL = 'https://api.notion.com/v1/oauth/authorize' - _TOKEN_URL = 'https://api.notion.com/v1/oauth/token' + _AUTH_URL = "https://api.notion.com/v1/oauth/authorize" + _TOKEN_URL = "https://api.notion.com/v1/oauth/token" _NOTION_PAGE_SEARCH = "https://api.notion.com/v1/search" _NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks" _NOTION_BOT_USER = "https://api.notion.com/v1/users/me" def get_authorization_url(self): params = { - 'client_id': self.client_id, - 'response_type': 'code', - 'redirect_uri': self.redirect_uri, - 'owner': 'user' + "client_id": self.client_id, + "response_type": "code", + "redirect_uri": self.redirect_uri, + "owner": "user", } return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" def get_access_token(self, code: str): - data = { - 'code': code, - 'grant_type': 'authorization_code', - 'redirect_uri': self.redirect_uri - } - headers = {'Accept': 'application/json'} + data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri} + headers = {"Accept": "application/json"} auth = (self.client_id, self.client_secret) response = requests.post(self._TOKEN_URL, data=data, auth=auth, headers=headers) response_json = response.json() - access_token = response_json.get('access_token') + access_token = response_json.get("access_token") if not access_token: raise ValueError(f"Error in Notion OAuth: {response_json}") - workspace_name = response_json.get('workspace_name') - workspace_icon = response_json.get('workspace_icon') - workspace_id = response_json.get('workspace_id') + workspace_name = response_json.get("workspace_name") + workspace_icon = response_json.get("workspace_icon") + workspace_id = response_json.get("workspace_id") # get all authorized pages pages = self.get_authorized_pages(access_token) source_info = { - 'workspace_name': workspace_name, - 'workspace_icon': workspace_icon, - 'workspace_id': workspace_id, - 'pages': pages, - 'total': len(pages) + "workspace_name": workspace_name, + "workspace_icon": workspace_icon, + "workspace_id": workspace_id, + "pages": pages, + "total": len(pages), } # save data source binding data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == 'notion', - DataSourceOauthBinding.access_token == access_token + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.access_token == access_token, ) ).first() if data_source_binding: @@ -79,7 +75,7 @@ def get_access_token(self, code: str): tenant_id=current_user.current_tenant_id, access_token=access_token, source_info=source_info, - provider='notion' + provider="notion", ) db.session.add(new_data_source_binding) db.session.commit() @@ -91,18 +87,18 @@ def save_internal_access_token(self, access_token: str): # get all authorized pages pages = self.get_authorized_pages(access_token) source_info = { - 'workspace_name': workspace_name, - 'workspace_icon': workspace_icon, - 'workspace_id': workspace_id, - 'pages': pages, - 'total': len(pages) + "workspace_name": workspace_name, + "workspace_icon": workspace_icon, + "workspace_id": workspace_id, + "pages": pages, + "total": len(pages), } # save data source binding data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == 'notion', - DataSourceOauthBinding.access_token == access_token + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.access_token == access_token, ) ).first() if data_source_binding: @@ -114,7 +110,7 @@ def save_internal_access_token(self, access_token: str): tenant_id=current_user.current_tenant_id, access_token=access_token, source_info=source_info, - provider='notion' + provider="notion", ) db.session.add(new_data_source_binding) db.session.commit() @@ -124,9 +120,9 @@ def sync_data_source(self, binding_id: str): data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.id == binding_id, - DataSourceOauthBinding.disabled == False + DataSourceOauthBinding.disabled == False, ) ).first() if data_source_binding: @@ -134,17 +130,17 @@ def sync_data_source(self, binding_id: str): pages = self.get_authorized_pages(data_source_binding.access_token) source_info = data_source_binding.source_info new_source_info = { - 'workspace_name': source_info['workspace_name'], - 'workspace_icon': source_info['workspace_icon'], - 'workspace_id': source_info['workspace_id'], - 'pages': pages, - 'total': len(pages) + "workspace_name": source_info["workspace_name"], + "workspace_icon": source_info["workspace_icon"], + "workspace_id": source_info["workspace_id"], + "pages": pages, + "total": len(pages), } data_source_binding.source_info = new_source_info data_source_binding.disabled = False db.session.commit() else: - raise ValueError('Data source binding not found') + raise ValueError("Data source binding not found") def get_authorized_pages(self, access_token: str): pages = [] @@ -152,143 +148,121 @@ def get_authorized_pages(self, access_token: str): database_results = self.notion_database_search(access_token) # get page detail for page_result in page_results: - page_id = page_result['id'] - page_name = 'Untitled' - for key in ['Name', 'title', 'Title', 'Page']: - if key in page_result['properties']: - if len(page_result['properties'][key].get('title', [])) > 0: - page_name = page_result['properties'][key]['title'][0]['plain_text'] - break - page_icon = page_result['icon'] + page_id = page_result["id"] + page_name = "Untitled" + for key in page_result["properties"]: + if "title" in page_result["properties"][key] and page_result["properties"][key]["title"]: + title_list = page_result["properties"][key]["title"] + if len(title_list) > 0 and "plain_text" in title_list[0]: + page_name = title_list[0]["plain_text"] + page_icon = page_result["icon"] if page_icon: - icon_type = page_icon['type'] - if icon_type == 'external' or icon_type == 'file': - url = page_icon[icon_type]['url'] - icon = { - 'type': 'url', - 'url': url if url.startswith('http') else f'https://www.notion.so{url}' - } + icon_type = page_icon["type"] + if icon_type == "external" or icon_type == "file": + url = page_icon[icon_type]["url"] + icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"} else: - icon = { - 'type': 'emoji', - 'emoji': page_icon[icon_type] - } + icon = {"type": "emoji", "emoji": page_icon[icon_type]} else: icon = None - parent = page_result['parent'] - parent_type = parent['type'] - if parent_type == 'block_id': + parent = page_result["parent"] + parent_type = parent["type"] + if parent_type == "block_id": parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type]) - elif parent_type == 'workspace': - parent_id = 'root' + elif parent_type == "workspace": + parent_id = "root" else: parent_id = parent[parent_type] page = { - 'page_id': page_id, - 'page_name': page_name, - 'page_icon': icon, - 'parent_id': parent_id, - 'type': 'page' + "page_id": page_id, + "page_name": page_name, + "page_icon": icon, + "parent_id": parent_id, + "type": "page", } pages.append(page) # get database detail for database_result in database_results: - page_id = database_result['id'] - if len(database_result['title']) > 0: - page_name = database_result['title'][0]['plain_text'] + page_id = database_result["id"] + if len(database_result["title"]) > 0: + page_name = database_result["title"][0]["plain_text"] else: - page_name = 'Untitled' - page_icon = database_result['icon'] + page_name = "Untitled" + page_icon = database_result["icon"] if page_icon: - icon_type = page_icon['type'] - if icon_type == 'external' or icon_type == 'file': - url = page_icon[icon_type]['url'] - icon = { - 'type': 'url', - 'url': url if url.startswith('http') else f'https://www.notion.so{url}' - } + icon_type = page_icon["type"] + if icon_type == "external" or icon_type == "file": + url = page_icon[icon_type]["url"] + icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"} else: - icon = { - 'type': icon_type, - icon_type: page_icon[icon_type] - } + icon = {"type": icon_type, icon_type: page_icon[icon_type]} else: icon = None - parent = database_result['parent'] - parent_type = parent['type'] - if parent_type == 'block_id': + parent = database_result["parent"] + parent_type = parent["type"] + if parent_type == "block_id": parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type]) - elif parent_type == 'workspace': - parent_id = 'root' + elif parent_type == "workspace": + parent_id = "root" else: parent_id = parent[parent_type] page = { - 'page_id': page_id, - 'page_name': page_name, - 'page_icon': icon, - 'parent_id': parent_id, - 'type': 'database' + "page_id": page_id, + "page_name": page_name, + "page_icon": icon, + "parent_id": parent_id, + "type": "database", } pages.append(page) return pages def notion_page_search(self, access_token: str): - data = { - 'filter': { - "value": "page", - "property": "object" - } - } + data = {"filter": {"value": "page", "property": "object"}} headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", - 'Notion-Version': '2022-06-28', + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", + "Notion-Version": "2022-06-28", } response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) response_json = response.json() - results = response_json.get('results', []) + results = response_json.get("results", []) return results def notion_block_parent_page_id(self, access_token: str, block_id: str): headers = { - 'Authorization': f"Bearer {access_token}", - 'Notion-Version': '2022-06-28', + "Authorization": f"Bearer {access_token}", + "Notion-Version": "2022-06-28", } - response = requests.get(url=f'{self._NOTION_BLOCK_SEARCH}/{block_id}', headers=headers) + response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers) response_json = response.json() - parent = response_json['parent'] - parent_type = parent['type'] - if parent_type == 'block_id': + parent = response_json["parent"] + parent_type = parent["type"] + if parent_type == "block_id": return self.notion_block_parent_page_id(access_token, parent[parent_type]) return parent[parent_type] def notion_workspace_name(self, access_token: str): headers = { - 'Authorization': f"Bearer {access_token}", - 'Notion-Version': '2022-06-28', + "Authorization": f"Bearer {access_token}", + "Notion-Version": "2022-06-28", } response = requests.get(url=self._NOTION_BOT_USER, headers=headers) response_json = response.json() - if 'object' in response_json and response_json['object'] == 'user': - user_type = response_json['type'] + if "object" in response_json and response_json["object"] == "user": + user_type = response_json["type"] user_info = response_json[user_type] - if 'workspace_name' in user_info: - return user_info['workspace_name'] - return 'workspace' + if "workspace_name" in user_info: + return user_info["workspace_name"] + return "workspace" def notion_database_search(self, access_token: str): - data = { - 'filter': { - "value": "database", - "property": "object" - } - } + data = {"filter": {"value": "database", "property": "object"}} headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", - 'Notion-Version': '2022-06-28', + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", + "Notion-Version": "2022-06-28", } response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) response_json = response.json() - results = response_json.get('results', []) + results = response_json.get("results", []) return results diff --git a/api/libs/passport.py b/api/libs/passport.py index 34bdc55997cc0f..8df4f529bc3898 100644 --- a/api/libs/passport.py +++ b/api/libs/passport.py @@ -9,14 +9,14 @@ def __init__(self): self.sk = dify_config.SECRET_KEY def issue(self, payload): - return jwt.encode(payload, self.sk, algorithm='HS256') + return jwt.encode(payload, self.sk, algorithm="HS256") def verify(self, token): try: - return jwt.decode(token, self.sk, algorithms=['HS256']) + return jwt.decode(token, self.sk, algorithms=["HS256"]) except jwt.exceptions.InvalidSignatureError: - raise Unauthorized('Invalid token signature.') + raise Unauthorized("Invalid token signature.") except jwt.exceptions.DecodeError: - raise Unauthorized('Invalid token.') + raise Unauthorized("Invalid token.") except jwt.exceptions.ExpiredSignatureError: - raise Unauthorized('Token has expired.') + raise Unauthorized("Token has expired.") diff --git a/api/libs/password.py b/api/libs/password.py index cdd1d69dbf81e8..cfcc0db22dded6 100644 --- a/api/libs/password.py +++ b/api/libs/password.py @@ -5,6 +5,7 @@ password_pattern = r"^(?=.*[a-zA-Z])(?=.*\d).{8,}$" + def valid_password(password): # Define a regex pattern for password rules pattern = password_pattern @@ -12,11 +13,11 @@ def valid_password(password): if re.match(pattern, password) is not None: return password - raise ValueError('Not a valid password.') + raise ValueError("Not a valid password.") def hash_password(password_str, salt_byte): - dk = hashlib.pbkdf2_hmac('sha256', password_str.encode('utf-8'), salt_byte, 10000) + dk = hashlib.pbkdf2_hmac("sha256", password_str.encode("utf-8"), salt_byte, 10000) return binascii.hexlify(dk) diff --git a/api/libs/rsa.py b/api/libs/rsa.py index 9f29c58812526a..a578bf3e5617f6 100644 --- a/api/libs/rsa.py +++ b/api/libs/rsa.py @@ -48,7 +48,7 @@ def encrypt(text, public_key): def get_decrypt_decoding(tenant_id): filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem" - cache_key = 'tenant_privkey:{hash}'.format(hash=hashlib.sha3_256(filepath.encode()).hexdigest()) + cache_key = "tenant_privkey:{hash}".format(hash=hashlib.sha3_256(filepath.encode()).hexdigest()) private_key = redis_client.get(cache_key) if not private_key: try: @@ -66,12 +66,12 @@ def get_decrypt_decoding(tenant_id): def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa): if encrypted_text.startswith(prefix_hybrid): - encrypted_text = encrypted_text[len(prefix_hybrid):] + encrypted_text = encrypted_text[len(prefix_hybrid) :] - enc_aes_key = encrypted_text[:rsa_key.size_in_bytes()] - nonce = encrypted_text[rsa_key.size_in_bytes():rsa_key.size_in_bytes() + 16] - tag = encrypted_text[rsa_key.size_in_bytes() + 16:rsa_key.size_in_bytes() + 32] - ciphertext = encrypted_text[rsa_key.size_in_bytes() + 32:] + enc_aes_key = encrypted_text[: rsa_key.size_in_bytes()] + nonce = encrypted_text[rsa_key.size_in_bytes() : rsa_key.size_in_bytes() + 16] + tag = encrypted_text[rsa_key.size_in_bytes() + 16 : rsa_key.size_in_bytes() + 32] + ciphertext = encrypted_text[rsa_key.size_in_bytes() + 32 :] aes_key = cipher_rsa.decrypt(enc_aes_key) diff --git a/api/libs/smtp.py b/api/libs/smtp.py index bf3a1a92e9eb52..bd7de7dd689a7a 100644 --- a/api/libs/smtp.py +++ b/api/libs/smtp.py @@ -5,7 +5,9 @@ class SMTPClient: - def __init__(self, server: str, port: int, username: str, password: str, _from: str, use_tls=False, opportunistic_tls=False): + def __init__( + self, server: str, port: int, username: str, password: str, _from: str, use_tls=False, opportunistic_tls=False + ): self.server = server self.port = port self._from = _from @@ -25,17 +27,17 @@ def send(self, mail: dict): smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10) else: smtp = smtplib.SMTP(self.server, self.port, timeout=10) - + if self.username and self.password: smtp.login(self.username, self.password) msg = MIMEMultipart() - msg['Subject'] = mail['subject'] - msg['From'] = self._from - msg['To'] = mail['to'] - msg.attach(MIMEText(mail['html'], 'html')) + msg["Subject"] = mail["subject"] + msg["From"] = self._from + msg["To"] = mail["to"] + msg.attach(MIMEText(mail["html"], "html")) - smtp.sendmail(self._from, mail['to'], msg.as_string()) + smtp.sendmail(self._from, mail["to"], msg.as_string()) except smtplib.SMTPException as e: logging.error(f"SMTP error occurred: {str(e)}") raise diff --git a/api/migrations/versions/2024_08_09_0801-1787fbae959a_update_tools_original_url_length.py b/api/migrations/versions/2024_08_09_0801-1787fbae959a_update_tools_original_url_length.py new file mode 100644 index 00000000000000..db966252f1a63c --- /dev/null +++ b/api/migrations/versions/2024_08_09_0801-1787fbae959a_update_tools_original_url_length.py @@ -0,0 +1,39 @@ +"""update tools original_url length + +Revision ID: 1787fbae959a +Revises: eeb2e349e6ac +Create Date: 2024-08-09 08:01:12.817620 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '1787fbae959a' +down_revision = 'eeb2e349e6ac' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.alter_column('original_url', + existing_type=sa.VARCHAR(length=255), + type_=sa.String(length=2048), + existing_nullable=True) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.alter_column('original_url', + existing_type=sa.String(length=2048), + type_=sa.VARCHAR(length=255), + existing_nullable=True) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_08_13_0633-63a83fcf12ba_support_conversation_variables.py b/api/migrations/versions/2024_08_13_0633-63a83fcf12ba_support_conversation_variables.py new file mode 100644 index 00000000000000..16e1efd4efd4ed --- /dev/null +++ b/api/migrations/versions/2024_08_13_0633-63a83fcf12ba_support_conversation_variables.py @@ -0,0 +1,51 @@ +"""support conversation variables + +Revision ID: 63a83fcf12ba +Revises: 1787fbae959a +Create Date: 2024-08-13 06:33:07.950379 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '63a83fcf12ba' +down_revision = '1787fbae959a' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('workflow__conversation_variables', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('conversation_id', models.types.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('data', sa.Text(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', 'conversation_id', name=op.f('workflow__conversation_variables_pkey')) + ) + with op.batch_alter_table('workflow__conversation_variables', schema=None) as batch_op: + batch_op.create_index(batch_op.f('workflow__conversation_variables_app_id_idx'), ['app_id'], unique=False) + batch_op.create_index(batch_op.f('workflow__conversation_variables_created_at_idx'), ['created_at'], unique=False) + + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.add_column(sa.Column('conversation_variables', sa.Text(), server_default='{}', nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.drop_column('conversation_variables') + + with op.batch_alter_table('workflow__conversation_variables', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('workflow__conversation_variables_created_at_idx')) + batch_op.drop_index(batch_op.f('workflow__conversation_variables_app_id_idx')) + + op.drop_table('workflow__conversation_variables') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_08_14_1354-8782057ff0dc_add_conversations_dialogue_count.py b/api/migrations/versions/2024_08_14_1354-8782057ff0dc_add_conversations_dialogue_count.py new file mode 100644 index 00000000000000..eba78e2e77d5d8 --- /dev/null +++ b/api/migrations/versions/2024_08_14_1354-8782057ff0dc_add_conversations_dialogue_count.py @@ -0,0 +1,33 @@ +"""add conversations.dialogue_count + +Revision ID: 8782057ff0dc +Revises: 63a83fcf12ba +Create Date: 2024-08-14 13:54:25.161324 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '8782057ff0dc' +down_revision = '63a83fcf12ba' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.add_column(sa.Column('dialogue_count', sa.Integer(), server_default='0', nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.drop_column('dialogue_count') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_08_15_1001-a6be81136580_app_and_site_icon_type.py b/api/migrations/versions/2024_08_15_1001-a6be81136580_app_and_site_icon_type.py new file mode 100644 index 00000000000000..d814666eefd2f2 --- /dev/null +++ b/api/migrations/versions/2024_08_15_1001-a6be81136580_app_and_site_icon_type.py @@ -0,0 +1,39 @@ +"""app and site icon type + +Revision ID: a6be81136580 +Revises: 8782057ff0dc +Create Date: 2024-08-15 10:01:24.697888 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = 'a6be81136580' +down_revision = '8782057ff0dc' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.add_column(sa.Column('icon_type', sa.String(length=255), nullable=True)) + + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.add_column(sa.Column('icon_type', sa.String(length=255), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.drop_column('icon_type') + + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.drop_column('icon_type') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_08_20_0455-2dbe42621d96_rename_workflow__conversation_variables_.py b/api/migrations/versions/2024_08_20_0455-2dbe42621d96_rename_workflow__conversation_variables_.py new file mode 100644 index 00000000000000..3dc7fed818ea2b --- /dev/null +++ b/api/migrations/versions/2024_08_20_0455-2dbe42621d96_rename_workflow__conversation_variables_.py @@ -0,0 +1,28 @@ +"""rename workflow__conversation_variables to workflow_conversation_variables + +Revision ID: 2dbe42621d96 +Revises: a6be81136580 +Create Date: 2024-08-20 04:55:38.160010 + +""" +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '2dbe42621d96' +down_revision = 'a6be81136580' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.rename_table('workflow__conversation_variables', 'workflow_conversation_variables') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.rename_table('workflow_conversation_variables', 'workflow__conversation_variables') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_08_25_0441-d0187d6a88dd_add_created_by_and_updated_by_to_app_.py b/api/migrations/versions/2024_08_25_0441-d0187d6a88dd_add_created_by_and_updated_by_to_app_.py new file mode 100644 index 00000000000000..e0066a302cd5a2 --- /dev/null +++ b/api/migrations/versions/2024_08_25_0441-d0187d6a88dd_add_created_by_and_updated_by_to_app_.py @@ -0,0 +1,52 @@ +"""add created_by and updated_by to app, modelconfig, and site + +Revision ID: d0187d6a88dd +Revises: 2dbe42621d96 +Create Date: 2024-08-25 04:41:18.157397 + +""" + +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = "d0187d6a88dd" +down_revision = "2dbe42621d96" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("app_model_configs", schema=None) as batch_op: + batch_op.add_column(sa.Column("created_by", models.types.StringUUID(), nullable=True)) + batch_op.add_column(sa.Column("updated_by", models.types.StringUUID(), nullable=True)) + + with op.batch_alter_table("apps", schema=None) as batch_op: + batch_op.add_column(sa.Column("created_by", models.types.StringUUID(), nullable=True)) + batch_op.add_column(sa.Column("updated_by", models.types.StringUUID(), nullable=True)) + + with op.batch_alter_table("sites", schema=None) as batch_op: + batch_op.add_column(sa.Column("created_by", models.types.StringUUID(), nullable=True)) + batch_op.add_column(sa.Column("updated_by", models.types.StringUUID(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("sites", schema=None) as batch_op: + batch_op.drop_column("updated_by") + batch_op.drop_column("created_by") + + with op.batch_alter_table("apps", schema=None) as batch_op: + batch_op.drop_column("updated_by") + batch_op.drop_column("created_by") + + with op.batch_alter_table("app_model_configs", schema=None) as batch_op: + batch_op.drop_column("updated_by") + batch_op.drop_column("created_by") + + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index 3b832cd22d8120..4012611471c337 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -1,15 +1,19 @@ from enum import Enum -from sqlalchemy import CHAR, TypeDecorator -from sqlalchemy.dialects.postgresql import UUID +from .model import App, AppMode, Message +from .types import StringUUID +from .workflow import ConversationVariable, Workflow, WorkflowNodeExecutionStatus + +__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus', 'Workflow', 'App', 'Message'] class CreatedByRole(Enum): """ Enum class for createdByRole """ - ACCOUNT = "account" - END_USER = "end_user" + + ACCOUNT = 'account' + END_USER = 'end_user' @classmethod def value_of(cls, value: str) -> 'CreatedByRole': @@ -23,49 +27,3 @@ def value_of(cls, value: str) -> 'CreatedByRole': if role.value == value: return role raise ValueError(f'invalid createdByRole value {value}') - - -class CreatedFrom(Enum): - """ - Enum class for createdFrom - """ - SERVICE_API = "service-api" - WEB_APP = "web-app" - EXPLORE = "explore" - - @classmethod - def value_of(cls, value: str) -> 'CreatedFrom': - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for role in cls: - if role.value == value: - return role - raise ValueError(f'invalid createdFrom value {value}') - - -class StringUUID(TypeDecorator): - impl = CHAR - cache_ok = True - - def process_bind_param(self, value, dialect): - if value is None: - return value - elif dialect.name == 'postgresql': - return str(value) - else: - return value.hex - - def load_dialect_impl(self, dialect): - if dialect.name == 'postgresql': - return dialect.type_descriptor(UUID()) - else: - return dialect.type_descriptor(CHAR(36)) - - def process_result_value(self, value, dialect): - if value is None: - return value - return str(value) diff --git a/api/models/account.py b/api/models/account.py index d36b2b9fda3278..67d940b7b7190e 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -4,7 +4,8 @@ from flask_login import UserMixin from extensions.ext_database import db -from models import StringUUID + +from .types import StringUUID class AccountStatus(str, enum.Enum): diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index d1f9cd78a72e45..7f69323628a7cc 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -1,7 +1,8 @@ import enum from extensions.ext_database import db -from models import StringUUID + +from .types import StringUUID class APIBasedExtensionPoint(enum.Enum): diff --git a/api/models/dataset.py b/api/models/dataset.py index 40f9f4cf83ae96..203031c7b98587 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1,4 +1,5 @@ import base64 +import enum import hashlib import hmac import json @@ -16,10 +17,16 @@ from core.rag.retrieval.retrival_methods import RetrievalMethod from extensions.ext_database import db from extensions.ext_storage import storage -from models import StringUUID -from models.account import Account -from models.model import App, Tag, TagBinding, UploadFile +from .account import Account +from .model import App, Tag, TagBinding, UploadFile +from .types import StringUUID + + +class DatasetPermissionEnum(str, enum.Enum): + ONLY_ME = 'only_me' + ALL_TEAM = 'all_team_members' + PARTIAL_TEAM = 'partial_members' class Dataset(db.Model): __tablename__ = 'datasets' diff --git a/api/models/model.py b/api/models/model.py index a6f517ea6b181f..e2d1fcfc237a4d 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -7,6 +7,7 @@ from flask import request from flask_login import UserMixin from sqlalchemy import Float, func, text +from sqlalchemy.orm import Mapped, mapped_column from configs import dify_config from core.file.tool_file_parser import ToolFileParser @@ -14,8 +15,8 @@ from extensions.ext_database import db from libs.helper import generate_string -from . import StringUUID from .account import Account, Tenant +from .types import StringUUID class DifySetup(db.Model): @@ -50,6 +51,10 @@ def value_of(cls, value: str) -> 'AppMode': raise ValueError(f'invalid mode value {value}') +class IconType(Enum): + IMAGE = "image" + EMOJI = "emoji" + class App(db.Model): __tablename__ = 'apps' __table_args__ = ( @@ -62,6 +67,7 @@ class App(db.Model): name = db.Column(db.String(255), nullable=False) description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) mode = db.Column(db.String(255), nullable=False) + icon_type = db.Column(db.String(255), nullable=True) icon = db.Column(db.String(255)) icon_background = db.Column(db.String(255)) app_model_config_id = db.Column(StringUUID, nullable=True) @@ -76,7 +82,9 @@ class App(db.Model): is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) tracing = db.Column(db.Text, nullable=True) max_active_requests = db.Column(db.Integer, nullable=True) + created_by = db.Column(StringUUID, nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_by = db.Column(StringUUID, nullable=True) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) @property @@ -215,7 +223,9 @@ class AppModelConfig(db.Model): provider = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True) configs = db.Column(db.JSON, nullable=True) + created_by = db.Column(StringUUID, nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_by = db.Column(StringUUID, nullable=True) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) opening_statement = db.Column(db.Text) suggested_questions = db.Column(db.Text) @@ -484,7 +494,6 @@ def tenant(self): return tenant - class Conversation(db.Model): __tablename__ = 'conversations' __table_args__ = ( @@ -512,12 +521,12 @@ class Conversation(db.Model): from_account_id = db.Column(StringUUID) read_at = db.Column(db.DateTime) read_account_id = db.Column(StringUUID) + dialogue_count: Mapped[int] = mapped_column(default=0) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all") - message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', - passive_deletes="all") + message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all") is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) @@ -617,6 +626,15 @@ def from_end_user_session_id(self): return None + @property + def from_account_name(self): + if self.from_account_id: + account = db.session.query(Account).filter(Account.id == self.from_account_id).first() + if account: + return account.name + + return None + @property def in_debug_mode(self): return self.override_model_configs is not None @@ -1086,6 +1104,7 @@ class Site(db.Model): id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) app_id = db.Column(StringUUID, nullable=False) title = db.Column(db.String(255), nullable=False) + icon_type = db.Column(db.String(255), nullable=True) icon = db.Column(db.String(255)) icon_background = db.Column(db.String(255)) description = db.Column(db.Text) @@ -1100,7 +1119,9 @@ class Site(db.Model): customize_token_strategy = db.Column(db.String(255), nullable=False) prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + created_by = db.Column(StringUUID, nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_by = db.Column(StringUUID, nullable=True) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) code = db.Column(db.String(255)) @@ -1116,7 +1137,7 @@ def generate_code(n): @property def app_base_url(self): return ( - dify_config.APP_WEB_URL if dify_config.APP_WEB_URL else request.host_url.rstrip('/')) + dify_config.APP_WEB_URL if dify_config.APP_WEB_URL else request.url_root.rstrip('/')) class ApiToken(db.Model): @@ -1315,9 +1336,7 @@ def tool_outputs_dict(self) -> dict: } except Exception as e: if self.observation: - return { - tool: self.observation for tool in tools - } + return dict.fromkeys(tools, self.observation) class DatasetRetrieverResource(db.Model): diff --git a/api/models/provider.py b/api/models/provider.py index 4c14c33f095cee..5d92ee6eb60d18 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,7 +1,8 @@ from enum import Enum from extensions.ext_database import db -from models import StringUUID + +from .types import StringUUID class ProviderType(Enum): diff --git a/api/models/source.py b/api/models/source.py index 265e68f014c6c2..adc00028bee43b 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -3,7 +3,8 @@ from sqlalchemy.dialects.postgresql import JSONB from extensions.ext_database import db -from models import StringUUID + +from .types import StringUUID class DataSourceOauthBinding(db.Model): diff --git a/api/models/tool.py b/api/models/tool.py index f322944f5f0a8e..79a70c6b1f2d22 100644 --- a/api/models/tool.py +++ b/api/models/tool.py @@ -2,7 +2,8 @@ from enum import Enum from extensions.ext_database import db -from models import StringUUID + +from .types import StringUUID class ToolProviderName(Enum): diff --git a/api/models/tools.py b/api/models/tools.py index 49212916ec5195..069dc5bad083c8 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -6,8 +6,9 @@ from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration from extensions.ext_database import db -from models import StringUUID -from models.model import Account, App, Tenant + +from .model import Account, App, Tenant +from .types import StringUUID class BuiltinToolProvider(db.Model): @@ -299,4 +300,4 @@ class ToolFile(db.Model): # mime type mimetype = db.Column(db.String(255), nullable=False) # original url - original_url = db.Column(db.String(255), nullable=True) \ No newline at end of file + original_url = db.Column(db.String(2048), nullable=True) \ No newline at end of file diff --git a/api/models/types.py b/api/models/types.py new file mode 100644 index 00000000000000..1614ec20188541 --- /dev/null +++ b/api/models/types.py @@ -0,0 +1,26 @@ +from sqlalchemy import CHAR, TypeDecorator +from sqlalchemy.dialects.postgresql import UUID + + +class StringUUID(TypeDecorator): + impl = CHAR + cache_ok = True + + def process_bind_param(self, value, dialect): + if value is None: + return value + elif dialect.name == 'postgresql': + return str(value) + else: + return value.hex + + def load_dialect_impl(self, dialect): + if dialect.name == 'postgresql': + return dialect.type_descriptor(UUID()) + else: + return dialect.type_descriptor(CHAR(36)) + + def process_result_value(self, value, dialect): + if value is None: + return value + return str(value) \ No newline at end of file diff --git a/api/models/web.py b/api/models/web.py index 6fd27206a972db..0e901d5f842691 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -1,7 +1,8 @@ from extensions.ext_database import db -from models import StringUUID -from models.model import Message + +from .model import Message +from .types import StringUUID class SavedMessage(db.Model): diff --git a/api/models/workflow.py b/api/models/workflow.py index df2269cd0fb6cc..cdd5e1992ded49 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,20 +1,21 @@ import json from collections.abc import Mapping, Sequence +from datetime import datetime from enum import Enum from typing import Any, Optional, Union +from sqlalchemy import func +from sqlalchemy.orm import Mapped + import contexts from constants import HIDDEN_VALUE -from core.app.segments import ( - SecretVariable, - Variable, - factory, -) +from core.app.segments import SecretVariable, Variable, factory from core.helper import encrypter from extensions.ext_database import db from libs import helper -from models import StringUUID -from models.account import Account + +from .account import Account +from .types import StringUUID class CreatedByRole(Enum): @@ -110,18 +111,32 @@ class Workflow(db.Model): db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'version'), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(StringUUID, nullable=False) - app_id = db.Column(StringUUID, nullable=False) - type = db.Column(db.String(255), nullable=False) - version = db.Column(db.String(255), nullable=False) - graph = db.Column(db.Text) - features = db.Column(db.Text) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_by = db.Column(StringUUID) - updated_at = db.Column(db.DateTime) - _environment_variables = db.Column('environment_variables', db.Text, nullable=False, server_default='{}') + id: Mapped[str] = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) + app_id: Mapped[str] = db.Column(StringUUID, nullable=False) + type: Mapped[str] = db.Column(db.String(255), nullable=False) + version: Mapped[str] = db.Column(db.String(255), nullable=False) + graph: Mapped[str] = db.Column(db.Text) + features: Mapped[str] = db.Column(db.Text) + created_by: Mapped[str] = db.Column(StringUUID, nullable=False) + created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_by: Mapped[str] = db.Column(StringUUID) + updated_at: Mapped[datetime] = db.Column(db.DateTime) + _environment_variables: Mapped[str] = db.Column('environment_variables', db.Text, nullable=False, server_default='{}') + _conversation_variables: Mapped[str] = db.Column('conversation_variables', db.Text, nullable=False, server_default='{}') + + def __init__(self, *, tenant_id: str, app_id: str, type: str, version: str, graph: str, + features: str, created_by: str, environment_variables: Sequence[Variable], + conversation_variables: Sequence[Variable]): + self.tenant_id = tenant_id + self.app_id = app_id + self.type = type + self.version = version + self.graph = graph + self.features = features + self.created_by = created_by + self.environment_variables = environment_variables or [] + self.conversation_variables = conversation_variables or [] @property def created_by_account(self): @@ -249,9 +264,27 @@ def to_dict(self, *, include_secret: bool = False) -> Mapping[str, Any]: 'graph': self.graph_dict, 'features': self.features_dict, 'environment_variables': [var.model_dump(mode='json') for var in environment_variables], + 'conversation_variables': [var.model_dump(mode='json') for var in self.conversation_variables], } return result + @property + def conversation_variables(self) -> Sequence[Variable]: + # TODO: find some way to init `self._conversation_variables` when instance created. + if self._conversation_variables is None: + self._conversation_variables = '{}' + + variables_dict: dict[str, Any] = json.loads(self._conversation_variables) + results = [factory.build_variable_from_mapping(v) for v in variables_dict.values()] + return results + + @conversation_variables.setter + def conversation_variables(self, value: Sequence[Variable]) -> None: + self._conversation_variables = json.dumps( + {var.name: var.model_dump() for var in value}, + ensure_ascii=False, + ) + class WorkflowRunTriggeredFrom(Enum): """ @@ -702,3 +735,34 @@ def created_by_end_user(self): created_by_role = CreatedByRole.value_of(self.created_by_role) return db.session.get(EndUser, self.created_by) \ if created_by_role == CreatedByRole.END_USER else None + + +class ConversationVariable(db.Model): + __tablename__ = 'workflow_conversation_variables' + + id: Mapped[str] = db.Column(StringUUID, primary_key=True) + conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True) + app_id: Mapped[str] = db.Column(StringUUID, nullable=False, index=True) + data = db.Column(db.Text, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()) + + def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> None: + self.id = id + self.app_id = app_id + self.conversation_id = conversation_id + self.data = data + + @classmethod + def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> 'ConversationVariable': + obj = cls( + id=variable.id, + app_id=app_id, + conversation_id=conversation_id, + data=variable.model_dump_json(), + ) + return obj + + def to_variable(self) -> Variable: + mapping = json.loads(self.data) + return factory.build_variable_from_mapping(mapping) diff --git a/api/poetry.lock b/api/poetry.lock index 6e3b63ae59fba5..7d26dbdc57c7df 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -1,91 +1,118 @@ # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] -name = "aiohttp" -version = "3.9.5" -description = "Async http client/server framework (asyncio)" +name = "aiohappyeyeballs" +version = "2.4.0" +description = "Happy Eyeballs for asyncio" optional = false python-versions = ">=3.8" files = [ - {file = "aiohttp-3.9.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:fcde4c397f673fdec23e6b05ebf8d4751314fa7c24f93334bf1f1364c1c69ac7"}, - {file = "aiohttp-3.9.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5d6b3f1fabe465e819aed2c421a6743d8debbde79b6a8600739300630a01bf2c"}, - {file = "aiohttp-3.9.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6ae79c1bc12c34082d92bf9422764f799aee4746fd7a392db46b7fd357d4a17a"}, - {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4d3ebb9e1316ec74277d19c5f482f98cc65a73ccd5430540d6d11682cd857430"}, - {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84dabd95154f43a2ea80deffec9cb44d2e301e38a0c9d331cc4aa0166fe28ae3"}, - {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c8a02fbeca6f63cb1f0475c799679057fc9268b77075ab7cf3f1c600e81dd46b"}, - {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c26959ca7b75ff768e2776d8055bf9582a6267e24556bb7f7bd29e677932be72"}, - {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:714d4e5231fed4ba2762ed489b4aec07b2b9953cf4ee31e9871caac895a839c0"}, - {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e7a6a8354f1b62e15d48e04350f13e726fa08b62c3d7b8401c0a1314f02e3558"}, - {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:c413016880e03e69d166efb5a1a95d40f83d5a3a648d16486592c49ffb76d0db"}, - {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:ff84aeb864e0fac81f676be9f4685f0527b660f1efdc40dcede3c251ef1e867f"}, - {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ad7f2919d7dac062f24d6f5fe95d401597fbb015a25771f85e692d043c9d7832"}, - {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:702e2c7c187c1a498a4e2b03155d52658fdd6fda882d3d7fbb891a5cf108bb10"}, - {file = "aiohttp-3.9.5-cp310-cp310-win32.whl", hash = "sha256:67c3119f5ddc7261d47163ed86d760ddf0e625cd6246b4ed852e82159617b5fb"}, - {file = "aiohttp-3.9.5-cp310-cp310-win_amd64.whl", hash = "sha256:471f0ef53ccedec9995287f02caf0c068732f026455f07db3f01a46e49d76bbb"}, - {file = "aiohttp-3.9.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e0ae53e33ee7476dd3d1132f932eeb39bf6125083820049d06edcdca4381f342"}, - {file = "aiohttp-3.9.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c088c4d70d21f8ca5c0b8b5403fe84a7bc8e024161febdd4ef04575ef35d474d"}, - {file = "aiohttp-3.9.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:639d0042b7670222f33b0028de6b4e2fad6451462ce7df2af8aee37dcac55424"}, - {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f26383adb94da5e7fb388d441bf09c61e5e35f455a3217bfd790c6b6bc64b2ee"}, - {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:66331d00fb28dc90aa606d9a54304af76b335ae204d1836f65797d6fe27f1ca2"}, - {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4ff550491f5492ab5ed3533e76b8567f4b37bd2995e780a1f46bca2024223233"}, - {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f22eb3a6c1080d862befa0a89c380b4dafce29dc6cd56083f630073d102eb595"}, - {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a81b1143d42b66ffc40a441379387076243ef7b51019204fd3ec36b9f69e77d6"}, - {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f64fd07515dad67f24b6ea4a66ae2876c01031de91c93075b8093f07c0a2d93d"}, - {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:93e22add827447d2e26d67c9ac0161756007f152fdc5210277d00a85f6c92323"}, - {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:55b39c8684a46e56ef8c8d24faf02de4a2b2ac60d26cee93bc595651ff545de9"}, - {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4715a9b778f4293b9f8ae7a0a7cef9829f02ff8d6277a39d7f40565c737d3771"}, - {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:afc52b8d969eff14e069a710057d15ab9ac17cd4b6753042c407dcea0e40bf75"}, - {file = "aiohttp-3.9.5-cp311-cp311-win32.whl", hash = "sha256:b3df71da99c98534be076196791adca8819761f0bf6e08e07fd7da25127150d6"}, - {file = "aiohttp-3.9.5-cp311-cp311-win_amd64.whl", hash = "sha256:88e311d98cc0bf45b62fc46c66753a83445f5ab20038bcc1b8a1cc05666f428a"}, - {file = "aiohttp-3.9.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:c7a4b7a6cf5b6eb11e109a9755fd4fda7d57395f8c575e166d363b9fc3ec4678"}, - {file = "aiohttp-3.9.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:0a158704edf0abcac8ac371fbb54044f3270bdbc93e254a82b6c82be1ef08f3c"}, - {file = "aiohttp-3.9.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d153f652a687a8e95ad367a86a61e8d53d528b0530ef382ec5aaf533140ed00f"}, - {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82a6a97d9771cb48ae16979c3a3a9a18b600a8505b1115cfe354dfb2054468b4"}, - {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60cdbd56f4cad9f69c35eaac0fbbdf1f77b0ff9456cebd4902f3dd1cf096464c"}, - {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8676e8fd73141ded15ea586de0b7cda1542960a7b9ad89b2b06428e97125d4fa"}, - {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da00da442a0e31f1c69d26d224e1efd3a1ca5bcbf210978a2ca7426dfcae9f58"}, - {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18f634d540dd099c262e9f887c8bbacc959847cfe5da7a0e2e1cf3f14dbf2daf"}, - {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:320e8618eda64e19d11bdb3bd04ccc0a816c17eaecb7e4945d01deee2a22f95f"}, - {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:2faa61a904b83142747fc6a6d7ad8fccff898c849123030f8e75d5d967fd4a81"}, - {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:8c64a6dc3fe5db7b1b4d2b5cb84c4f677768bdc340611eca673afb7cf416ef5a"}, - {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:393c7aba2b55559ef7ab791c94b44f7482a07bf7640d17b341b79081f5e5cd1a"}, - {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c671dc117c2c21a1ca10c116cfcd6e3e44da7fcde37bf83b2be485ab377b25da"}, - {file = "aiohttp-3.9.5-cp312-cp312-win32.whl", hash = "sha256:5a7ee16aab26e76add4afc45e8f8206c95d1d75540f1039b84a03c3b3800dd59"}, - {file = "aiohttp-3.9.5-cp312-cp312-win_amd64.whl", hash = "sha256:5ca51eadbd67045396bc92a4345d1790b7301c14d1848feaac1d6a6c9289e888"}, - {file = "aiohttp-3.9.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:694d828b5c41255e54bc2dddb51a9f5150b4eefa9886e38b52605a05d96566e8"}, - {file = "aiohttp-3.9.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0605cc2c0088fcaae79f01c913a38611ad09ba68ff482402d3410bf59039bfb8"}, - {file = "aiohttp-3.9.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4558e5012ee03d2638c681e156461d37b7a113fe13970d438d95d10173d25f78"}, - {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9dbc053ac75ccc63dc3a3cc547b98c7258ec35a215a92bd9f983e0aac95d3d5b"}, - {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4109adee842b90671f1b689901b948f347325045c15f46b39797ae1bf17019de"}, - {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a6ea1a5b409a85477fd8e5ee6ad8f0e40bf2844c270955e09360418cfd09abac"}, - {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3c2890ca8c59ee683fd09adf32321a40fe1cf164e3387799efb2acebf090c11"}, - {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3916c8692dbd9d55c523374a3b8213e628424d19116ac4308e434dbf6d95bbdd"}, - {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8d1964eb7617907c792ca00b341b5ec3e01ae8c280825deadbbd678447b127e1"}, - {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d5ab8e1f6bee051a4bf6195e38a5c13e5e161cb7bad83d8854524798bd9fcd6e"}, - {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:52c27110f3862a1afbcb2af4281fc9fdc40327fa286c4625dfee247c3ba90156"}, - {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:7f64cbd44443e80094309875d4f9c71d0401e966d191c3d469cde4642bc2e031"}, - {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8b4f72fbb66279624bfe83fd5eb6aea0022dad8eec62b71e7bf63ee1caadeafe"}, - {file = "aiohttp-3.9.5-cp38-cp38-win32.whl", hash = "sha256:6380c039ec52866c06d69b5c7aad5478b24ed11696f0e72f6b807cfb261453da"}, - {file = "aiohttp-3.9.5-cp38-cp38-win_amd64.whl", hash = "sha256:da22dab31d7180f8c3ac7c7635f3bcd53808f374f6aa333fe0b0b9e14b01f91a"}, - {file = "aiohttp-3.9.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1732102949ff6087589408d76cd6dea656b93c896b011ecafff418c9661dc4ed"}, - {file = "aiohttp-3.9.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c6021d296318cb6f9414b48e6a439a7f5d1f665464da507e8ff640848ee2a58a"}, - {file = "aiohttp-3.9.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:239f975589a944eeb1bad26b8b140a59a3a320067fb3cd10b75c3092405a1372"}, - {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b7b30258348082826d274504fbc7c849959f1989d86c29bc355107accec6cfb"}, - {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cd2adf5c87ff6d8b277814a28a535b59e20bfea40a101db6b3bdca7e9926bc24"}, - {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e9a3d838441bebcf5cf442700e3963f58b5c33f015341f9ea86dcd7d503c07e2"}, - {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e3a1ae66e3d0c17cf65c08968a5ee3180c5a95920ec2731f53343fac9bad106"}, - {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9c69e77370cce2d6df5d12b4e12bdcca60c47ba13d1cbbc8645dd005a20b738b"}, - {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0cbf56238f4bbf49dab8c2dc2e6b1b68502b1e88d335bea59b3f5b9f4c001475"}, - {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d1469f228cd9ffddd396d9948b8c9cd8022b6d1bf1e40c6f25b0fb90b4f893ed"}, - {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:45731330e754f5811c314901cebdf19dd776a44b31927fa4b4dbecab9e457b0c"}, - {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:3fcb4046d2904378e3aeea1df51f697b0467f2aac55d232c87ba162709478c46"}, - {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8cf142aa6c1a751fcb364158fd710b8a9be874b81889c2bd13aa8893197455e2"}, - {file = "aiohttp-3.9.5-cp39-cp39-win32.whl", hash = "sha256:7b179eea70833c8dee51ec42f3b4097bd6370892fa93f510f76762105568cf09"}, - {file = "aiohttp-3.9.5-cp39-cp39-win_amd64.whl", hash = "sha256:38d80498e2e169bc61418ff36170e0aad0cd268da8b38a17c4cf29d254a8b3f1"}, - {file = "aiohttp-3.9.5.tar.gz", hash = "sha256:edea7d15772ceeb29db4aff55e482d4bcfb6ae160ce144f2682de02f6d693551"}, + {file = "aiohappyeyeballs-2.4.0-py3-none-any.whl", hash = "sha256:7ce92076e249169a13c2f49320d1967425eaf1f407522d707d59cac7628d62bd"}, + {file = "aiohappyeyeballs-2.4.0.tar.gz", hash = "sha256:55a1714f084e63d49639800f95716da97a1f173d46a16dfcfda0016abb93b6b2"}, ] -[package.dependencies] +[[package]] +name = "aiohttp" +version = "3.10.5" +description = "Async http client/server framework (asyncio)" +optional = false +python-versions = ">=3.8" +files = [ + {file = "aiohttp-3.10.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:18a01eba2574fb9edd5f6e5fb25f66e6ce061da5dab5db75e13fe1558142e0a3"}, + {file = "aiohttp-3.10.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:94fac7c6e77ccb1ca91e9eb4cb0ac0270b9fb9b289738654120ba8cebb1189c6"}, + {file = "aiohttp-3.10.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2f1f1c75c395991ce9c94d3e4aa96e5c59c8356a15b1c9231e783865e2772699"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f7acae3cf1a2a2361ec4c8e787eaaa86a94171d2417aae53c0cca6ca3118ff6"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:94c4381ffba9cc508b37d2e536b418d5ea9cfdc2848b9a7fea6aebad4ec6aac1"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c31ad0c0c507894e3eaa843415841995bf8de4d6b2d24c6e33099f4bc9fc0d4f"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0912b8a8fadeb32ff67a3ed44249448c20148397c1ed905d5dac185b4ca547bb"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d93400c18596b7dc4794d48a63fb361b01a0d8eb39f28800dc900c8fbdaca91"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d00f3c5e0d764a5c9aa5a62d99728c56d455310bcc288a79cab10157b3af426f"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:d742c36ed44f2798c8d3f4bc511f479b9ceef2b93f348671184139e7d708042c"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:814375093edae5f1cb31e3407997cf3eacefb9010f96df10d64829362ae2df69"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:8224f98be68a84b19f48e0bdc14224b5a71339aff3a27df69989fa47d01296f3"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:d9a487ef090aea982d748b1b0d74fe7c3950b109df967630a20584f9a99c0683"}, + {file = "aiohttp-3.10.5-cp310-cp310-win32.whl", hash = "sha256:d9ef084e3dc690ad50137cc05831c52b6ca428096e6deb3c43e95827f531d5ef"}, + {file = "aiohttp-3.10.5-cp310-cp310-win_amd64.whl", hash = "sha256:66bf9234e08fe561dccd62083bf67400bdbf1c67ba9efdc3dac03650e97c6088"}, + {file = "aiohttp-3.10.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8c6a4e5e40156d72a40241a25cc226051c0a8d816610097a8e8f517aeacd59a2"}, + {file = "aiohttp-3.10.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2c634a3207a5445be65536d38c13791904fda0748b9eabf908d3fe86a52941cf"}, + {file = "aiohttp-3.10.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4aff049b5e629ef9b3e9e617fa6e2dfeda1bf87e01bcfecaf3949af9e210105e"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1942244f00baaacaa8155eca94dbd9e8cc7017deb69b75ef67c78e89fdad3c77"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e04a1f2a65ad2f93aa20f9ff9f1b672bf912413e5547f60749fa2ef8a644e061"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7f2bfc0032a00405d4af2ba27f3c429e851d04fad1e5ceee4080a1c570476697"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:424ae21498790e12eb759040bbb504e5e280cab64693d14775c54269fd1d2bb7"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:975218eee0e6d24eb336d0328c768ebc5d617609affaca5dbbd6dd1984f16ed0"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:4120d7fefa1e2d8fb6f650b11489710091788de554e2b6f8347c7a20ceb003f5"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b90078989ef3fc45cf9221d3859acd1108af7560c52397ff4ace8ad7052a132e"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:ba5a8b74c2a8af7d862399cdedce1533642fa727def0b8c3e3e02fcb52dca1b1"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:02594361128f780eecc2a29939d9dfc870e17b45178a867bf61a11b2a4367277"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:8fb4fc029e135859f533025bc82047334e24b0d489e75513144f25408ecaf058"}, + {file = "aiohttp-3.10.5-cp311-cp311-win32.whl", hash = "sha256:e1ca1ef5ba129718a8fc827b0867f6aa4e893c56eb00003b7367f8a733a9b072"}, + {file = "aiohttp-3.10.5-cp311-cp311-win_amd64.whl", hash = "sha256:349ef8a73a7c5665cca65c88ab24abe75447e28aa3bc4c93ea5093474dfdf0ff"}, + {file = "aiohttp-3.10.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:305be5ff2081fa1d283a76113b8df7a14c10d75602a38d9f012935df20731487"}, + {file = "aiohttp-3.10.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3a1c32a19ee6bbde02f1cb189e13a71b321256cc1d431196a9f824050b160d5a"}, + {file = "aiohttp-3.10.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:61645818edd40cc6f455b851277a21bf420ce347baa0b86eaa41d51ef58ba23d"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c225286f2b13bab5987425558baa5cbdb2bc925b2998038fa028245ef421e75"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8ba01ebc6175e1e6b7275c907a3a36be48a2d487549b656aa90c8a910d9f3178"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8eaf44ccbc4e35762683078b72bf293f476561d8b68ec8a64f98cf32811c323e"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1c43eb1ab7cbf411b8e387dc169acb31f0ca0d8c09ba63f9eac67829585b44f"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de7a5299827253023c55ea549444e058c0eb496931fa05d693b95140a947cb73"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4790f0e15f00058f7599dab2b206d3049d7ac464dc2e5eae0e93fa18aee9e7bf"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:44b324a6b8376a23e6ba25d368726ee3bc281e6ab306db80b5819999c737d820"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:0d277cfb304118079e7044aad0b76685d30ecb86f83a0711fc5fb257ffe832ca"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:54d9ddea424cd19d3ff6128601a4a4d23d54a421f9b4c0fff740505813739a91"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:4f1c9866ccf48a6df2b06823e6ae80573529f2af3a0992ec4fe75b1a510df8a6"}, + {file = "aiohttp-3.10.5-cp312-cp312-win32.whl", hash = "sha256:dc4826823121783dccc0871e3f405417ac116055bf184ac04c36f98b75aacd12"}, + {file = "aiohttp-3.10.5-cp312-cp312-win_amd64.whl", hash = "sha256:22c0a23a3b3138a6bf76fc553789cb1a703836da86b0f306b6f0dc1617398abc"}, + {file = "aiohttp-3.10.5-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:7f6b639c36734eaa80a6c152a238242bedcee9b953f23bb887e9102976343092"}, + {file = "aiohttp-3.10.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f29930bc2921cef955ba39a3ff87d2c4398a0394ae217f41cb02d5c26c8b1b77"}, + {file = "aiohttp-3.10.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f489a2c9e6455d87eabf907ac0b7d230a9786be43fbe884ad184ddf9e9c1e385"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:123dd5b16b75b2962d0fff566effb7a065e33cd4538c1692fb31c3bda2bfb972"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b98e698dc34966e5976e10bbca6d26d6724e6bdea853c7c10162a3235aba6e16"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c3b9162bab7e42f21243effc822652dc5bb5e8ff42a4eb62fe7782bcbcdfacf6"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1923a5c44061bffd5eebeef58cecf68096e35003907d8201a4d0d6f6e387ccaa"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d55f011da0a843c3d3df2c2cf4e537b8070a419f891c930245f05d329c4b0689"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:afe16a84498441d05e9189a15900640a2d2b5e76cf4efe8cbb088ab4f112ee57"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:f8112fb501b1e0567a1251a2fd0747baae60a4ab325a871e975b7bb67e59221f"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:1e72589da4c90337837fdfe2026ae1952c0f4a6e793adbbfbdd40efed7c63599"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:4d46c7b4173415d8e583045fbc4daa48b40e31b19ce595b8d92cf639396c15d5"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:33e6bc4bab477c772a541f76cd91e11ccb6d2efa2b8d7d7883591dfb523e5987"}, + {file = "aiohttp-3.10.5-cp313-cp313-win32.whl", hash = "sha256:c58c6837a2c2a7cf3133983e64173aec11f9c2cd8e87ec2fdc16ce727bcf1a04"}, + {file = "aiohttp-3.10.5-cp313-cp313-win_amd64.whl", hash = "sha256:38172a70005252b6893088c0f5e8a47d173df7cc2b2bd88650957eb84fcf5022"}, + {file = "aiohttp-3.10.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:f6f18898ace4bcd2d41a122916475344a87f1dfdec626ecde9ee802a711bc569"}, + {file = "aiohttp-3.10.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5ede29d91a40ba22ac1b922ef510aab871652f6c88ef60b9dcdf773c6d32ad7a"}, + {file = "aiohttp-3.10.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:673f988370f5954df96cc31fd99c7312a3af0a97f09e407399f61583f30da9bc"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58718e181c56a3c02d25b09d4115eb02aafe1a732ce5714ab70326d9776457c3"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4b38b1570242fbab8d86a84128fb5b5234a2f70c2e32f3070143a6d94bc854cf"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:074d1bff0163e107e97bd48cad9f928fa5a3eb4b9d33366137ffce08a63e37fe"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd31f176429cecbc1ba499d4aba31aaccfea488f418d60376b911269d3b883c5"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7384d0b87d4635ec38db9263e6a3f1eb609e2e06087f0aa7f63b76833737b471"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:8989f46f3d7ef79585e98fa991e6ded55d2f48ae56d2c9fa5e491a6e4effb589"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:c83f7a107abb89a227d6c454c613e7606c12a42b9a4ca9c5d7dad25d47c776ae"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:cde98f323d6bf161041e7627a5fd763f9fd829bcfcd089804a5fdce7bb6e1b7d"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:676f94c5480d8eefd97c0c7e3953315e4d8c2b71f3b49539beb2aa676c58272f"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:2d21ac12dc943c68135ff858c3a989f2194a709e6e10b4c8977d7fcd67dfd511"}, + {file = "aiohttp-3.10.5-cp38-cp38-win32.whl", hash = "sha256:17e997105bd1a260850272bfb50e2a328e029c941c2708170d9d978d5a30ad9a"}, + {file = "aiohttp-3.10.5-cp38-cp38-win_amd64.whl", hash = "sha256:1c19de68896747a2aa6257ae4cf6ef59d73917a36a35ee9d0a6f48cff0f94db8"}, + {file = "aiohttp-3.10.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7e2fe37ac654032db1f3499fe56e77190282534810e2a8e833141a021faaab0e"}, + {file = "aiohttp-3.10.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f5bf3ead3cb66ab990ee2561373b009db5bc0e857549b6c9ba84b20bc462e172"}, + {file = "aiohttp-3.10.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1b2c16a919d936ca87a3c5f0e43af12a89a3ce7ccbce59a2d6784caba945b68b"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad146dae5977c4dd435eb31373b3fe9b0b1bf26858c6fc452bf6af394067e10b"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8c5c6fa16412b35999320f5c9690c0f554392dc222c04e559217e0f9ae244b92"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:95c4dc6f61d610bc0ee1edc6f29d993f10febfe5b76bb470b486d90bbece6b22"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da452c2c322e9ce0cfef392e469a26d63d42860f829026a63374fde6b5c5876f"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:898715cf566ec2869d5cb4d5fb4be408964704c46c96b4be267442d265390f32"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:391cc3a9c1527e424c6865e087897e766a917f15dddb360174a70467572ac6ce"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:380f926b51b92d02a34119d072f178d80bbda334d1a7e10fa22d467a66e494db"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ce91db90dbf37bb6fa0997f26574107e1b9d5ff939315247b7e615baa8ec313b"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:9093a81e18c45227eebe4c16124ebf3e0d893830c6aca7cc310bfca8fe59d857"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:ee40b40aa753d844162dcc80d0fe256b87cba48ca0054f64e68000453caead11"}, + {file = "aiohttp-3.10.5-cp39-cp39-win32.whl", hash = "sha256:03f2645adbe17f274444953bdea69f8327e9d278d961d85657cb0d06864814c1"}, + {file = "aiohttp-3.10.5-cp39-cp39-win_amd64.whl", hash = "sha256:d17920f18e6ee090bdd3d0bfffd769d9f2cb4c8ffde3eb203777a3895c128862"}, + {file = "aiohttp-3.10.5.tar.gz", hash = "sha256:f071854b47d39591ce9a17981c46790acb30518e2f83dfca8db2dfa091178691"}, +] + +[package.dependencies] +aiohappyeyeballs = ">=2.3.0" aiosignal = ">=1.1.2" async-timeout = {version = ">=4.0,<5.0", markers = "python_version < \"3.11\""} attrs = ">=17.3.0" @@ -94,7 +121,7 @@ multidict = ">=4.5,<7.0" yarl = ">=1.0,<2.0" [package.extras] -speedups = ["Brotli", "aiodns", "brotlicffi"] +speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"] [[package]] name = "aiohttp-retry" @@ -145,16 +172,16 @@ tz = ["backports.zoneinfo"] [[package]] name = "alibabacloud-credentials" -version = "0.3.4" +version = "0.3.5" description = "The alibabacloud credentials module of alibabaCloud Python SDK." optional = false python-versions = ">=3.6" files = [ - {file = "alibabacloud_credentials-0.3.4.tar.gz", hash = "sha256:c15a34fe782c318d4cf24cb041a0385ac4ccd2548e524e5d7fe1cff56a9a6acc"}, + {file = "alibabacloud_credentials-0.3.5.tar.gz", hash = "sha256:ad065ec95921eaf51939195485d0e5cc9e0ea050282059c7d8bf74bdb5496177"}, ] [package.dependencies] -alibabacloud-tea = "*" +alibabacloud-tea = ">=0.3.9" [[package]] name = "alibabacloud-endpoint-util" @@ -171,16 +198,16 @@ alibabacloud-tea = ">=0.0.1" [[package]] name = "alibabacloud-gateway-spi" -version = "0.0.1" +version = "0.0.2" description = "Alibaba Cloud Gateway SPI SDK Library for Python" optional = false python-versions = ">=3.6" files = [ - {file = "alibabacloud_gateway_spi-0.0.1.tar.gz", hash = "sha256:1b259855708afc3c04d8711d8530c63f7645e1edc0cf97e2fd15461b08e11c30"}, + {file = "alibabacloud_gateway_spi-0.0.2.tar.gz", hash = "sha256:f932c8ba67291531dfbee6ca521dcf3523eb4ff93512bf0aaf135f2d4fc4704d"}, ] [package.dependencies] -alibabacloud_credentials = ">=0.2.0,<1.0.0" +alibabacloud_credentials = ">=0.3.4,<1.0.0" [[package]] name = "alibabacloud-gpdb20160503" @@ -294,19 +321,19 @@ alibabacloud-tea = ">=0.0.1" [[package]] name = "alibabacloud-tea-openapi" -version = "0.3.10" +version = "0.3.11" description = "Alibaba Cloud openapi SDK Library for Python" optional = false python-versions = ">=3.6" files = [ - {file = "alibabacloud_tea_openapi-0.3.10.tar.gz", hash = "sha256:46e9c54ea857346306cd5c628dc33479349b559179ed2fdb2251dbe6ec9a1cf1"}, + {file = "alibabacloud_tea_openapi-0.3.11.tar.gz", hash = "sha256:3f5cace1b1aeb8a64587574097403cfd066b86ee4c3c9abde587f9abfcad38de"}, ] [package.dependencies] alibabacloud_credentials = ">=0.3.1,<1.0.0" alibabacloud_gateway_spi = ">=0.0.1,<1.0.0" alibabacloud_openapi_util = ">=0.2.1,<1.0.0" -alibabacloud_tea_util = ">=0.3.12,<1.0.0" +alibabacloud_tea_util = ">=0.3.13,<1.0.0" alibabacloud_tea_xml = ">=0.0.2,<1.0.0" [[package]] @@ -351,13 +378,13 @@ jmespath = ">=0.9.3,<1.0.0" [[package]] name = "aliyun-python-sdk-kms" -version = "2.16.3" +version = "2.16.4" description = "The kms module of Aliyun Python sdk." optional = false python-versions = "*" files = [ - {file = "aliyun-python-sdk-kms-2.16.3.tar.gz", hash = "sha256:c31b7d24e153271a3043e801e7b6b6b3f0db47e95a83c8d10cdab8c11662fc39"}, - {file = "aliyun_python_sdk_kms-2.16.3-py2.py3-none-any.whl", hash = "sha256:8bb8c293be94e0cc9114a5286a503d2ec215eaf8a1fb51de5d6c8bcac209d4a1"}, + {file = "aliyun-python-sdk-kms-2.16.4.tar.gz", hash = "sha256:0d5bb165c07b6a972939753a128507393f48011792ee0ec4f59b6021eabd9752"}, + {file = "aliyun_python_sdk_kms-2.16.4-py2.py3-none-any.whl", hash = "sha256:6d412663ef8c35dc3bb42be6a3ee76a9bc07acdadca6dd26815131062bedf4c5"}, ] [package.dependencies] @@ -493,22 +520,22 @@ files = [ [[package]] name = "attrs" -version = "23.2.0" +version = "24.2.0" description = "Classes Without Boilerplate" optional = false python-versions = ">=3.7" files = [ - {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"}, - {file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"}, + {file = "attrs-24.2.0-py3-none-any.whl", hash = "sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2"}, + {file = "attrs-24.2.0.tar.gz", hash = "sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346"}, ] [package.extras] -cov = ["attrs[tests]", "coverage[toml] (>=5.3)"] -dev = ["attrs[tests]", "pre-commit"] -docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] -tests = ["attrs[tests-no-zope]", "zope-interface"] -tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"] -tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"] +benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier (<24.7)"] +tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] [[package]] name = "authlib" @@ -524,6 +551,69 @@ files = [ [package.dependencies] cryptography = "*" +[[package]] +name = "azure-ai-inference" +version = "1.0.0b3" +description = "Microsoft Azure Ai Inference Client Library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "azure-ai-inference-1.0.0b3.tar.gz", hash = "sha256:1e99dc74c3b335a457500311bbbadb348f54dc4c12252a93cb8ab78d6d217ff0"}, + {file = "azure_ai_inference-1.0.0b3-py3-none-any.whl", hash = "sha256:6734ca7334c809a170beb767f1f1455724ab3f006cb60045e42a833c0e764403"}, +] + +[package.dependencies] +azure-core = ">=1.30.0" +isodate = ">=0.6.1" +typing-extensions = ">=4.6.0" + +[[package]] +name = "azure-ai-ml" +version = "1.19.0" +description = "Microsoft Azure Machine Learning Client Library for Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "azure-ai-ml-1.19.0.tar.gz", hash = "sha256:94bb1afbb0497e539ae75455fc4a51b6942b5b68b3a275727ecce6ceb250eff9"}, + {file = "azure_ai_ml-1.19.0-py3-none-any.whl", hash = "sha256:f0385af06efbeae1f83113613e45343508d1288fd2f05857619e7c7d4d4f5302"}, +] + +[package.dependencies] +azure-common = ">=1.1" +azure-core = ">=1.23.0" +azure-mgmt-core = ">=1.3.0" +azure-storage-blob = ">=12.10.0" +azure-storage-file-datalake = ">=12.2.0" +azure-storage-file-share = "*" +colorama = "*" +isodate = "*" +jsonschema = ">=4.0.0" +marshmallow = ">=3.5" +msrest = ">=0.6.18" +opencensus-ext-azure = "*" +opencensus-ext-logging = "*" +pydash = ">=6.0.0" +pyjwt = "*" +pyyaml = ">=5.1.0" +strictyaml = "*" +tqdm = "*" +typing-extensions = "*" + +[package.extras] +designer = ["mldesigner"] +mount = ["azureml-dataprep-rslex (>=2.22.0)"] + +[[package]] +name = "azure-common" +version = "1.1.28" +description = "Microsoft Azure Client Library for Python (Common)" +optional = false +python-versions = "*" +files = [ + {file = "azure-common-1.1.28.zip", hash = "sha256:4ac0cd3214e36b6a1b6a442686722a5d8cc449603aa833f3f0f40bda836704a3"}, + {file = "azure_common-1.1.28-py2.py3-none-any.whl", hash = "sha256:5c12d3dcf4ec20599ca6b0d3e09e86e146353d443e7fcc050c9a19c1f9df20ad"}, +] + [[package]] name = "azure-core" version = "1.30.2" @@ -560,6 +650,20 @@ cryptography = ">=2.5" msal = ">=1.24.0" msal-extensions = ">=0.3.0" +[[package]] +name = "azure-mgmt-core" +version = "1.4.0" +description = "Microsoft Azure Management Core Library for Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "azure-mgmt-core-1.4.0.zip", hash = "sha256:d195208340094f98e5a6661b781cde6f6a051e79ce317caabd8ff97030a9b3ae"}, + {file = "azure_mgmt_core-1.4.0-py3-none-any.whl", hash = "sha256:81071675f186a585555ef01816f2774d49c1c9024cb76e5720c3c0f6b337bb7d"}, +] + +[package.dependencies] +azure-core = ">=1.26.2,<2.0.0" + [[package]] name = "azure-storage-blob" version = "12.13.0" @@ -576,6 +680,42 @@ azure-core = ">=1.23.1,<2.0.0" cryptography = ">=2.1.4" msrest = ">=0.6.21" +[[package]] +name = "azure-storage-file-datalake" +version = "12.8.0" +description = "Microsoft Azure File DataLake Storage Client Library for Python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "azure-storage-file-datalake-12.8.0.zip", hash = "sha256:12e6306e5efb5ca28e0ccd9fa79a2c61acd589866d6109fe5601b18509da92f4"}, + {file = "azure_storage_file_datalake-12.8.0-py3-none-any.whl", hash = "sha256:b6cf5733fe794bf3c866efbe3ce1941409e35b6b125028ac558b436bf90f2de7"}, +] + +[package.dependencies] +azure-core = ">=1.23.1,<2.0.0" +azure-storage-blob = ">=12.13.0,<13.0.0" +msrest = ">=0.6.21" + +[[package]] +name = "azure-storage-file-share" +version = "12.17.0" +description = "Microsoft Azure Azure File Share Storage Client Library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "azure-storage-file-share-12.17.0.tar.gz", hash = "sha256:f7b2c6cfc1b7cb80097a53b1ed2efa9e545b49a291430d369cdb49fafbc841d6"}, + {file = "azure_storage_file_share-12.17.0-py3-none-any.whl", hash = "sha256:c4652759a9d529bf08881bb53275bf38774bb643746b849d27c47118f9cf923d"}, +] + +[package.dependencies] +azure-core = ">=1.28.0" +cryptography = ">=2.1.4" +isodate = ">=0.6.1" +typing-extensions = ">=4.6.0" + +[package.extras] +aio = ["azure-core[aio] (>=1.28.0)"] + [[package]] name = "backoff" version = "2.2.1" @@ -669,17 +809,17 @@ files = [ [[package]] name = "boto3" -version = "1.34.136" +version = "1.34.148" description = "The AWS SDK for Python" optional = false python-versions = ">=3.8" files = [ - {file = "boto3-1.34.136-py3-none-any.whl", hash = "sha256:d41037e2c680ab8d6c61a0a4ee6bf1fdd9e857f43996672830a95d62d6f6fa79"}, - {file = "boto3-1.34.136.tar.gz", hash = "sha256:0314e6598f59ee0f34eb4e6d1a0f69fa65c146d2b88a6e837a527a9956ec2731"}, + {file = "boto3-1.34.148-py3-none-any.whl", hash = "sha256:d63d36e5a34533ba69188d56f96da132730d5e9932c4e11c02d79319cd1afcec"}, + {file = "boto3-1.34.148.tar.gz", hash = "sha256:2058397f0a92c301e3116e9e65fbbc70ea49270c250882d65043d19b7c6e2d17"}, ] [package.dependencies] -botocore = ">=1.34.136,<1.35.0" +botocore = ">=1.34.148,<1.35.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -688,13 +828,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.34.147" +version = "1.34.162" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.34.147-py3-none-any.whl", hash = "sha256:be94a2f4874b1d1705cae2bd512c475047497379651678593acb6c61c50d91de"}, - {file = "botocore-1.34.147.tar.gz", hash = "sha256:2e8f000b77e4ca345146cb2edab6403769a517b564f627bb084ab335417f3dbe"}, + {file = "botocore-1.34.162-py3-none-any.whl", hash = "sha256:2d918b02db88d27a75b48275e6fb2506e9adaaddbec1ffa6a8a0898b34e769be"}, + {file = "botocore-1.34.162.tar.gz", hash = "sha256:adc23be4fb99ad31961236342b7cbf3c0bfc62532cd02852196032e8c0d682f3"}, ] [package.dependencies] @@ -703,7 +843,7 @@ python-dateutil = ">=2.1,<3.0.0" urllib3 = {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""} [package.extras] -crt = ["awscrt (==0.20.11)"] +crt = ["awscrt (==0.21.2)"] [[package]] name = "bottleneck" @@ -1011,63 +1151,78 @@ files = [ [[package]] name = "cffi" -version = "1.16.0" +version = "1.17.0" description = "Foreign Function Interface for Python calling C code." optional = false python-versions = ">=3.8" files = [ - {file = "cffi-1.16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6b3d6606d369fc1da4fd8c357d026317fbb9c9b75d36dc16e90e84c26854b088"}, - {file = "cffi-1.16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ac0f5edd2360eea2f1daa9e26a41db02dd4b0451b48f7c318e217ee092a213e9"}, - {file = "cffi-1.16.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7e61e3e4fa664a8588aa25c883eab612a188c725755afff6289454d6362b9673"}, - {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a72e8961a86d19bdb45851d8f1f08b041ea37d2bd8d4fd19903bc3083d80c896"}, - {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5b50bf3f55561dac5438f8e70bfcdfd74543fd60df5fa5f62d94e5867deca684"}, - {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7651c50c8c5ef7bdb41108b7b8c5a83013bfaa8a935590c5d74627c047a583c7"}, - {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4108df7fe9b707191e55f33efbcb2d81928e10cea45527879a4749cbe472614"}, - {file = "cffi-1.16.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:32c68ef735dbe5857c810328cb2481e24722a59a2003018885514d4c09af9743"}, - {file = "cffi-1.16.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:673739cb539f8cdaa07d92d02efa93c9ccf87e345b9a0b556e3ecc666718468d"}, - {file = "cffi-1.16.0-cp310-cp310-win32.whl", hash = "sha256:9f90389693731ff1f659e55c7d1640e2ec43ff725cc61b04b2f9c6d8d017df6a"}, - {file = "cffi-1.16.0-cp310-cp310-win_amd64.whl", hash = "sha256:e6024675e67af929088fda399b2094574609396b1decb609c55fa58b028a32a1"}, - {file = "cffi-1.16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b84834d0cf97e7d27dd5b7f3aca7b6e9263c56308ab9dc8aae9784abb774d404"}, - {file = "cffi-1.16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1b8ebc27c014c59692bb2664c7d13ce7a6e9a629be20e54e7271fa696ff2b417"}, - {file = "cffi-1.16.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ee07e47c12890ef248766a6e55bd38ebfb2bb8edd4142d56db91b21ea68b7627"}, - {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8a9d3ebe49f084ad71f9269834ceccbf398253c9fac910c4fd7053ff1386936"}, - {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e70f54f1796669ef691ca07d046cd81a29cb4deb1e5f942003f401c0c4a2695d"}, - {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5bf44d66cdf9e893637896c7faa22298baebcd18d1ddb6d2626a6e39793a1d56"}, - {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7b78010e7b97fef4bee1e896df8a4bbb6712b7f05b7ef630f9d1da00f6444d2e"}, - {file = "cffi-1.16.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:c6a164aa47843fb1b01e941d385aab7215563bb8816d80ff3a363a9f8448a8dc"}, - {file = "cffi-1.16.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e09f3ff613345df5e8c3667da1d918f9149bd623cd9070c983c013792a9a62eb"}, - {file = "cffi-1.16.0-cp311-cp311-win32.whl", hash = "sha256:2c56b361916f390cd758a57f2e16233eb4f64bcbeee88a4881ea90fca14dc6ab"}, - {file = "cffi-1.16.0-cp311-cp311-win_amd64.whl", hash = "sha256:db8e577c19c0fda0beb7e0d4e09e0ba74b1e4c092e0e40bfa12fe05b6f6d75ba"}, - {file = "cffi-1.16.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:fa3a0128b152627161ce47201262d3140edb5a5c3da88d73a1b790a959126956"}, - {file = "cffi-1.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:68e7c44931cc171c54ccb702482e9fc723192e88d25a0e133edd7aff8fcd1f6e"}, - {file = "cffi-1.16.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:abd808f9c129ba2beda4cfc53bde801e5bcf9d6e0f22f095e45327c038bfe68e"}, - {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88e2b3c14bdb32e440be531ade29d3c50a1a59cd4e51b1dd8b0865c54ea5d2e2"}, - {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fcc8eb6d5902bb1cf6dc4f187ee3ea80a1eba0a89aba40a5cb20a5087d961357"}, - {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b7be2d771cdba2942e13215c4e340bfd76398e9227ad10402a8767ab1865d2e6"}, - {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e715596e683d2ce000574bae5d07bd522c781a822866c20495e52520564f0969"}, - {file = "cffi-1.16.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2d92b25dbf6cae33f65005baf472d2c245c050b1ce709cc4588cdcdd5495b520"}, - {file = "cffi-1.16.0-cp312-cp312-win32.whl", hash = "sha256:b2ca4e77f9f47c55c194982e10f058db063937845bb2b7a86c84a6cfe0aefa8b"}, - {file = "cffi-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:68678abf380b42ce21a5f2abde8efee05c114c2fdb2e9eef2efdb0257fba1235"}, - {file = "cffi-1.16.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0c9ef6ff37e974b73c25eecc13952c55bceed9112be2d9d938ded8e856138bcc"}, - {file = "cffi-1.16.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a09582f178759ee8128d9270cd1344154fd473bb77d94ce0aeb2a93ebf0feaf0"}, - {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e760191dd42581e023a68b758769e2da259b5d52e3103c6060ddc02c9edb8d7b"}, - {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:80876338e19c951fdfed6198e70bc88f1c9758b94578d5a7c4c91a87af3cf31c"}, - {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a6a14b17d7e17fa0d207ac08642c8820f84f25ce17a442fd15e27ea18d67c59b"}, - {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6602bc8dc6f3a9e02b6c22c4fc1e47aa50f8f8e6d3f78a5e16ac33ef5fefa324"}, - {file = "cffi-1.16.0-cp38-cp38-win32.whl", hash = "sha256:131fd094d1065b19540c3d72594260f118b231090295d8c34e19a7bbcf2e860a"}, - {file = "cffi-1.16.0-cp38-cp38-win_amd64.whl", hash = "sha256:31d13b0f99e0836b7ff893d37af07366ebc90b678b6664c955b54561fc36ef36"}, - {file = "cffi-1.16.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:582215a0e9adbe0e379761260553ba11c58943e4bbe9c36430c4ca6ac74b15ed"}, - {file = "cffi-1.16.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b29ebffcf550f9da55bec9e02ad430c992a87e5f512cd63388abb76f1036d8d2"}, - {file = "cffi-1.16.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dc9b18bf40cc75f66f40a7379f6a9513244fe33c0e8aa72e2d56b0196a7ef872"}, - {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9cb4a35b3642fc5c005a6755a5d17c6c8b6bcb6981baf81cea8bfbc8903e8ba8"}, - {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b86851a328eedc692acf81fb05444bdf1891747c25af7529e39ddafaf68a4f3f"}, - {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c0f31130ebc2d37cdd8e44605fb5fa7ad59049298b3f745c74fa74c62fbfcfc4"}, - {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f8e709127c6c77446a8c0a8c8bf3c8ee706a06cd44b1e827c3e6a2ee6b8c098"}, - {file = "cffi-1.16.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:748dcd1e3d3d7cd5443ef03ce8685043294ad6bd7c02a38d1bd367cfd968e000"}, - {file = "cffi-1.16.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8895613bcc094d4a1b2dbe179d88d7fb4a15cee43c052e8885783fac397d91fe"}, - {file = "cffi-1.16.0-cp39-cp39-win32.whl", hash = "sha256:ed86a35631f7bfbb28e108dd96773b9d5a6ce4811cf6ea468bb6a359b256b1e4"}, - {file = "cffi-1.16.0-cp39-cp39-win_amd64.whl", hash = "sha256:3686dffb02459559c74dd3d81748269ffb0eb027c39a6fc99502de37d501faa8"}, - {file = "cffi-1.16.0.tar.gz", hash = "sha256:bcb3ef43e58665bbda2fb198698fcae6776483e0c4a631aa5647806c25e02cc0"}, + {file = "cffi-1.17.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f9338cc05451f1942d0d8203ec2c346c830f8e86469903d5126c1f0a13a2bcbb"}, + {file = "cffi-1.17.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a0ce71725cacc9ebf839630772b07eeec220cbb5f03be1399e0457a1464f8e1a"}, + {file = "cffi-1.17.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c815270206f983309915a6844fe994b2fa47e5d05c4c4cef267c3b30e34dbe42"}, + {file = "cffi-1.17.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d6bdcd415ba87846fd317bee0774e412e8792832e7805938987e4ede1d13046d"}, + {file = "cffi-1.17.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8a98748ed1a1df4ee1d6f927e151ed6c1a09d5ec21684de879c7ea6aa96f58f2"}, + {file = "cffi-1.17.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0a048d4f6630113e54bb4b77e315e1ba32a5a31512c31a273807d0027a7e69ab"}, + {file = "cffi-1.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:24aa705a5f5bd3a8bcfa4d123f03413de5d86e497435693b638cbffb7d5d8a1b"}, + {file = "cffi-1.17.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:856bf0924d24e7f93b8aee12a3a1095c34085600aa805693fb7f5d1962393206"}, + {file = "cffi-1.17.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:4304d4416ff032ed50ad6bb87416d802e67139e31c0bde4628f36a47a3164bfa"}, + {file = "cffi-1.17.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:331ad15c39c9fe9186ceaf87203a9ecf5ae0ba2538c9e898e3a6967e8ad3db6f"}, + {file = "cffi-1.17.0-cp310-cp310-win32.whl", hash = "sha256:669b29a9eca6146465cc574659058ed949748f0809a2582d1f1a324eb91054dc"}, + {file = "cffi-1.17.0-cp310-cp310-win_amd64.whl", hash = "sha256:48b389b1fd5144603d61d752afd7167dfd205973a43151ae5045b35793232aa2"}, + {file = "cffi-1.17.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c5d97162c196ce54af6700949ddf9409e9833ef1003b4741c2b39ef46f1d9720"}, + {file = "cffi-1.17.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5ba5c243f4004c750836f81606a9fcb7841f8874ad8f3bf204ff5e56332b72b9"}, + {file = "cffi-1.17.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bb9333f58fc3a2296fb1d54576138d4cf5d496a2cc118422bd77835e6ae0b9cb"}, + {file = "cffi-1.17.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:435a22d00ec7d7ea533db494da8581b05977f9c37338c80bc86314bec2619424"}, + {file = "cffi-1.17.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d1df34588123fcc88c872f5acb6f74ae59e9d182a2707097f9e28275ec26a12d"}, + {file = "cffi-1.17.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:df8bb0010fdd0a743b7542589223a2816bdde4d94bb5ad67884348fa2c1c67e8"}, + {file = "cffi-1.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8b5b9712783415695663bd463990e2f00c6750562e6ad1d28e072a611c5f2a6"}, + {file = "cffi-1.17.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ffef8fd58a36fb5f1196919638f73dd3ae0db1a878982b27a9a5a176ede4ba91"}, + {file = "cffi-1.17.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4e67d26532bfd8b7f7c05d5a766d6f437b362c1bf203a3a5ce3593a645e870b8"}, + {file = "cffi-1.17.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:45f7cd36186db767d803b1473b3c659d57a23b5fa491ad83c6d40f2af58e4dbb"}, + {file = "cffi-1.17.0-cp311-cp311-win32.whl", hash = "sha256:a9015f5b8af1bb6837a3fcb0cdf3b874fe3385ff6274e8b7925d81ccaec3c5c9"}, + {file = "cffi-1.17.0-cp311-cp311-win_amd64.whl", hash = "sha256:b50aaac7d05c2c26dfd50c3321199f019ba76bb650e346a6ef3616306eed67b0"}, + {file = "cffi-1.17.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aec510255ce690d240f7cb23d7114f6b351c733a74c279a84def763660a2c3bc"}, + {file = "cffi-1.17.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2770bb0d5e3cc0e31e7318db06efcbcdb7b31bcb1a70086d3177692a02256f59"}, + {file = "cffi-1.17.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:db9a30ec064129d605d0f1aedc93e00894b9334ec74ba9c6bdd08147434b33eb"}, + {file = "cffi-1.17.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a47eef975d2b8b721775a0fa286f50eab535b9d56c70a6e62842134cf7841195"}, + {file = "cffi-1.17.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f3e0992f23bbb0be00a921eae5363329253c3b86287db27092461c887b791e5e"}, + {file = "cffi-1.17.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6107e445faf057c118d5050560695e46d272e5301feffda3c41849641222a828"}, + {file = "cffi-1.17.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eb862356ee9391dc5a0b3cbc00f416b48c1b9a52d252d898e5b7696a5f9fe150"}, + {file = "cffi-1.17.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c1c13185b90bbd3f8b5963cd8ce7ad4ff441924c31e23c975cb150e27c2bf67a"}, + {file = "cffi-1.17.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:17c6d6d3260c7f2d94f657e6872591fe8733872a86ed1345bda872cfc8c74885"}, + {file = "cffi-1.17.0-cp312-cp312-win32.whl", hash = "sha256:c3b8bd3133cd50f6b637bb4322822c94c5ce4bf0d724ed5ae70afce62187c492"}, + {file = "cffi-1.17.0-cp312-cp312-win_amd64.whl", hash = "sha256:dca802c8db0720ce1c49cce1149ff7b06e91ba15fa84b1d59144fef1a1bc7ac2"}, + {file = "cffi-1.17.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:6ce01337d23884b21c03869d2f68c5523d43174d4fc405490eb0091057943118"}, + {file = "cffi-1.17.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:cab2eba3830bf4f6d91e2d6718e0e1c14a2f5ad1af68a89d24ace0c6b17cced7"}, + {file = "cffi-1.17.0-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:14b9cbc8f7ac98a739558eb86fabc283d4d564dafed50216e7f7ee62d0d25377"}, + {file = "cffi-1.17.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b00e7bcd71caa0282cbe3c90966f738e2db91e64092a877c3ff7f19a1628fdcb"}, + {file = "cffi-1.17.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:41f4915e09218744d8bae14759f983e466ab69b178de38066f7579892ff2a555"}, + {file = "cffi-1.17.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4760a68cab57bfaa628938e9c2971137e05ce48e762a9cb53b76c9b569f1204"}, + {file = "cffi-1.17.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:011aff3524d578a9412c8b3cfaa50f2c0bd78e03eb7af7aa5e0df59b158efb2f"}, + {file = "cffi-1.17.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:a003ac9edc22d99ae1286b0875c460351f4e101f8c9d9d2576e78d7e048f64e0"}, + {file = "cffi-1.17.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ef9528915df81b8f4c7612b19b8628214c65c9b7f74db2e34a646a0a2a0da2d4"}, + {file = "cffi-1.17.0-cp313-cp313-win32.whl", hash = "sha256:70d2aa9fb00cf52034feac4b913181a6e10356019b18ef89bc7c12a283bf5f5a"}, + {file = "cffi-1.17.0-cp313-cp313-win_amd64.whl", hash = "sha256:b7b6ea9e36d32582cda3465f54c4b454f62f23cb083ebc7a94e2ca6ef011c3a7"}, + {file = "cffi-1.17.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:964823b2fc77b55355999ade496c54dde161c621cb1f6eac61dc30ed1b63cd4c"}, + {file = "cffi-1.17.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:516a405f174fd3b88829eabfe4bb296ac602d6a0f68e0d64d5ac9456194a5b7e"}, + {file = "cffi-1.17.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dec6b307ce928e8e112a6bb9921a1cb00a0e14979bf28b98e084a4b8a742bd9b"}, + {file = "cffi-1.17.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e4094c7b464cf0a858e75cd14b03509e84789abf7b79f8537e6a72152109c76e"}, + {file = "cffi-1.17.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2404f3de742f47cb62d023f0ba7c5a916c9c653d5b368cc966382ae4e57da401"}, + {file = "cffi-1.17.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3aa9d43b02a0c681f0bfbc12d476d47b2b2b6a3f9287f11ee42989a268a1833c"}, + {file = "cffi-1.17.0-cp38-cp38-win32.whl", hash = "sha256:0bb15e7acf8ab35ca8b24b90af52c8b391690ef5c4aec3d31f38f0d37d2cc499"}, + {file = "cffi-1.17.0-cp38-cp38-win_amd64.whl", hash = "sha256:93a7350f6706b31f457c1457d3a3259ff9071a66f312ae64dc024f049055f72c"}, + {file = "cffi-1.17.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1a2ddbac59dc3716bc79f27906c010406155031a1c801410f1bafff17ea304d2"}, + {file = "cffi-1.17.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6327b572f5770293fc062a7ec04160e89741e8552bf1c358d1a23eba68166759"}, + {file = "cffi-1.17.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dbc183e7bef690c9abe5ea67b7b60fdbca81aa8da43468287dae7b5c046107d4"}, + {file = "cffi-1.17.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5bdc0f1f610d067c70aa3737ed06e2726fd9d6f7bfee4a351f4c40b6831f4e82"}, + {file = "cffi-1.17.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6d872186c1617d143969defeadac5a904e6e374183e07977eedef9c07c8953bf"}, + {file = "cffi-1.17.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0d46ee4764b88b91f16661a8befc6bfb24806d885e27436fdc292ed7e6f6d058"}, + {file = "cffi-1.17.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f76a90c345796c01d85e6332e81cab6d70de83b829cf1d9762d0a3da59c7932"}, + {file = "cffi-1.17.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0e60821d312f99d3e1569202518dddf10ae547e799d75aef3bca3a2d9e8ee693"}, + {file = "cffi-1.17.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:eb09b82377233b902d4c3fbeeb7ad731cdab579c6c6fda1f763cd779139e47c3"}, + {file = "cffi-1.17.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:24658baf6224d8f280e827f0a50c46ad819ec8ba380a42448e24459daf809cf4"}, + {file = "cffi-1.17.0-cp39-cp39-win32.whl", hash = "sha256:0fdacad9e0d9fc23e519efd5ea24a70348305e8d7d85ecbb1a5fa66dc834e7fb"}, + {file = "cffi-1.17.0-cp39-cp39-win_amd64.whl", hash = "sha256:7cbc78dc018596315d4e7841c8c3a7ae31cc4d638c9b627f87d52e8abaaf2d29"}, + {file = "cffi-1.17.0.tar.gz", hash = "sha256:f3157624b7558b914cb039fd1af735e5e8049a87c817cc215109ad1c8779df76"}, ] [package.dependencies] @@ -1343,77 +1498,77 @@ testing = ["pytest (>=7.2.1)", "pytest-cov (>=4.0.0)", "tox (>=4.4.3)"] [[package]] name = "clickhouse-connect" -version = "0.7.16" +version = "0.7.19" description = "ClickHouse Database Core Driver for Python, Pandas, and Superset" optional = false python-versions = "~=3.8" files = [ - {file = "clickhouse-connect-0.7.16.tar.gz", hash = "sha256:253a2089efad5729903d00382f73fa8da2cbbfdb118db498cf708ee9f4a2134f"}, - {file = "clickhouse_connect-0.7.16-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:00413deb9e086aabf661d18ac3a3539f25eb773c3675f49353e0d7e6ef1205fc"}, - {file = "clickhouse_connect-0.7.16-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:faadaf206ea7753782db017daedbf592e4edc7c71cb985aad787eb9dc516bf21"}, - {file = "clickhouse_connect-0.7.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1db8f1168f33fda78adddb733913b211ddf648984d8fef8d934e30df876e5f23"}, - {file = "clickhouse_connect-0.7.16-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8fa630bf50fb064cc53b7ea5d862066476d3c6074003f6d39d2594fb1a7abf67"}, - {file = "clickhouse_connect-0.7.16-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2cba9547dad41b2d333458615208a3c7db6f56a63473ffea2c05c44225ffa020"}, - {file = "clickhouse_connect-0.7.16-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:480f7856fcf42a21f17886e0b42d70499067c865fc2a0ea7c0eb5c0bdca281a8"}, - {file = "clickhouse_connect-0.7.16-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:b65f3eb570cbcf9fa383b4e0925d1ceb3efd3deba42a435625cad75b3a9ff7f3"}, - {file = "clickhouse_connect-0.7.16-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b78d3cc0fe42374bb9d5a05ba71578dc69f7e4b4c771e86dcf292ae0412265cc"}, - {file = "clickhouse_connect-0.7.16-cp310-cp310-win32.whl", hash = "sha256:1cb76b26fcde1ba6a8ae68e1db1f9e42d458879a0d4d2c9843cc998f42f445ac"}, - {file = "clickhouse_connect-0.7.16-cp310-cp310-win_amd64.whl", hash = "sha256:9298b344168271e952ea41021963ca1b81b9b3c38be8b036cb64a2556edbb4b7"}, - {file = "clickhouse_connect-0.7.16-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8ae39a765735cc6e786e5f9a0dba799e7f8ee0bbd5dfc5d5ff755dfa9dd13855"}, - {file = "clickhouse_connect-0.7.16-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3f32546f65dd234a49310cda454713a5f7fbc8ba978744e070355c7ea8819a5a"}, - {file = "clickhouse_connect-0.7.16-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:20865c81a5b378625a528ac8960e08cdca316147f87fad6deb9f16c0d5e5f62f"}, - {file = "clickhouse_connect-0.7.16-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:609c076261d779703bf29e7a27dafc8283153403ceab1ec23d50eb2acabc4b9d"}, - {file = "clickhouse_connect-0.7.16-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e07862e75ac7419c5671384055f11ca5e76dc2c0be4a6f3aed7bf419997184bc"}, - {file = "clickhouse_connect-0.7.16-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d5db7da6f20b9a49b288063de9b3224a56634f8cb94d19d435af518ed81872c3"}, - {file = "clickhouse_connect-0.7.16-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:955c567ede68a10325045bb2adf1314ff569dfb7e52f6074c18182f3803279f6"}, - {file = "clickhouse_connect-0.7.16-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:df517bfe23d85f5aeeb17b262c06d0a5c24e0baea09688a96d02dc8589ef8b07"}, - {file = "clickhouse_connect-0.7.16-cp311-cp311-win32.whl", hash = "sha256:7f2c6132fc90df6a8318abb9f257c2b777404908b7d168ac08235d516f65a663"}, - {file = "clickhouse_connect-0.7.16-cp311-cp311-win_amd64.whl", hash = "sha256:ca1dba53da86691a11671d846988dc4f6ad02a66f5a0df9a87a46dc4ec9bb0a1"}, - {file = "clickhouse_connect-0.7.16-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f8f7260073b6ee63e19d442ebb6954bc7741a5ce4ed563eb8074c8c6a0158eca"}, - {file = "clickhouse_connect-0.7.16-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9b3dd93ada1099cb6df244d79973c811e90a4590685e78e60e8846914b3c261e"}, - {file = "clickhouse_connect-0.7.16-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d3c3458bce25fe9c10e1dbf82dbeeeb2f04e382130f9811cc3bedf44c2028ca"}, - {file = "clickhouse_connect-0.7.16-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dcc302390b4ea975efd8d2ca53d295d40dc766179dd5e9fc158e808f01d9280d"}, - {file = "clickhouse_connect-0.7.16-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a94f6d095d7174c55825e0b5c04b77897a1b2a8a8bbb38f3f773fd3113a7be27"}, - {file = "clickhouse_connect-0.7.16-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:6b7e2572993ef2e1dee5012875a7a2d08cede319e32ccdd2db90ed26a0d0c037"}, - {file = "clickhouse_connect-0.7.16-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:e9c35ee425309ed8ef63bae31e1d3c5f35706fa27ae2836e61e7cb9bbe7f00cb"}, - {file = "clickhouse_connect-0.7.16-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:eb0471d5a32d07eaa37772871ee9e6b5eb37ab907c3c154833824ed68ee4795b"}, - {file = "clickhouse_connect-0.7.16-cp312-cp312-win32.whl", hash = "sha256:b531ee18b4ce16f1d2b8f6249859cbd600f7e0f312f80dda8deb969791a90f17"}, - {file = "clickhouse_connect-0.7.16-cp312-cp312-win_amd64.whl", hash = "sha256:38392308344770864843f7f8b914799684c13ce4b272d5a3a55e5512ff8a3ae0"}, - {file = "clickhouse_connect-0.7.16-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:052ca80d66e49c94d103c9842d2a5b0ebf4610981b79164660ef6b1bdc4b5e85"}, - {file = "clickhouse_connect-0.7.16-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b496059d145c68e956aa10cd04e5c7cb4e97312eb3f7829cec8f4f7024f8ced6"}, - {file = "clickhouse_connect-0.7.16-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de1e423fc9c415b9fdcbb6f23eccae981e3f0f0cf142e518efec709bda7c1394"}, - {file = "clickhouse_connect-0.7.16-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:555c64719cbc72675d58ea6dfc144fa8064ea1d673a54afd2d54e34c58f17c6b"}, - {file = "clickhouse_connect-0.7.16-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a0c3c063ab23df8f71a36505880bf5de6c18aee246938d787447e52b4d9d5531"}, - {file = "clickhouse_connect-0.7.16-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5ed62e08cfe445d0430b91c26fb276e2a5175e456e9786594fb6e67c9ebd8c6c"}, - {file = "clickhouse_connect-0.7.16-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d9eb056bd14ca3c1d7e3edd7ca79ea970d45e5e536930dbb6179aeb965d5bc3d"}, - {file = "clickhouse_connect-0.7.16-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:54e0a03b685ee6c138954846dafb6ec0e0baf8257f2587c61e34c017f3dc9d63"}, - {file = "clickhouse_connect-0.7.16-cp38-cp38-win32.whl", hash = "sha256:d8402c3145387726bd19f916ca2890576be70c4493f030c068f6f03a75addff7"}, - {file = "clickhouse_connect-0.7.16-cp38-cp38-win_amd64.whl", hash = "sha256:70e376d2ebc0f092fae35f7b50ff7296ee8ffd2dda3536238f6c39a5c949d115"}, - {file = "clickhouse_connect-0.7.16-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cee4f91ad22401c3b96f5df3f3149ef2894e7c2d00b5abd9da80119e7b6592f7"}, - {file = "clickhouse_connect-0.7.16-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a3009145f35e9ac2535dbd8fdbdc218abfe0971c9bc9b730eb5c3f6c40faeb5f"}, - {file = "clickhouse_connect-0.7.16-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d0ef9f877ffbcb0f526ce9c35c657fc54930d043e45c077d9d886c0f1add727"}, - {file = "clickhouse_connect-0.7.16-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:acc437b3ff2f7991b209b861a89c003ac1971c890775190178438780e967a9d3"}, - {file = "clickhouse_connect-0.7.16-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7ed836dcee4ac097bd83714abe0af987b1ef767675a555e7643d793164c3f1cc"}, - {file = "clickhouse_connect-0.7.16-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4c4e0d173239c0b4594c8703fae5c8ba3241c4e0763a8cf436b94564692671f9"}, - {file = "clickhouse_connect-0.7.16-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:a17a348dd8c00df343a01128497e8c3a6ae431f13c7a88e363ac12c035316ce0"}, - {file = "clickhouse_connect-0.7.16-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:805ae7ad39c043af13e2b5af45abb70330f0907749dc87ad4a2481a4ac209cc6"}, - {file = "clickhouse_connect-0.7.16-cp39-cp39-win32.whl", hash = "sha256:38fc6ca1bd73cf4dcebd22fbb8dceda267908ff674fc57fbc23c3b5df9c21ac1"}, - {file = "clickhouse_connect-0.7.16-cp39-cp39-win_amd64.whl", hash = "sha256:3dc67e99e40b5a8bc493a21016830b0f3800006a6038c1fd881f7cae6246cc44"}, - {file = "clickhouse_connect-0.7.16-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:b7f526fef71bd5265f47915340a6369a5b5685278b72b5aff281cc521a8ec376"}, - {file = "clickhouse_connect-0.7.16-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e00f87ba68bbc63dd32d7a304fd629b759f24b09f88fbc2bac0a9ed1fe7b2938"}, - {file = "clickhouse_connect-0.7.16-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:09c84f3b64d6bebedcfbbd19e8369b3df2cb7d313afb2a0d64a3e151d344c1c1"}, - {file = "clickhouse_connect-0.7.16-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18d104ab78edee26e8cef056e2db83f03e1da918df0946e1ef1ad9a27a024dd0"}, - {file = "clickhouse_connect-0.7.16-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:cc1ad53e282ff5b4288fdfcf6df72cda542d9d997de5889d66a1f8e2b9f477f0"}, - {file = "clickhouse_connect-0.7.16-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:fddc99322054f5d3df8715ab3724bd36ac636f8ceaed4f5f3f60d377abd22d22"}, - {file = "clickhouse_connect-0.7.16-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:765a2de98197d1b4f6424611ceaca2ae896a1d7093b943403973888cb7c144e6"}, - {file = "clickhouse_connect-0.7.16-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1540e0a93e5f2147400f644606a399c91705066f05d5a91429616ee9812f4521"}, - {file = "clickhouse_connect-0.7.16-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba928c4178b0d4a513e1b0ad32a464ab56cb1bc27736a7f41b32e4eb70eb08d6"}, - {file = "clickhouse_connect-0.7.16-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:a17ffc22e905081f002173b30959089de6987fd40c87e7794da9d978d723e610"}, - {file = "clickhouse_connect-0.7.16-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:26df09787232b495285d8358db145b9770f472e2e30147912634c5b56392e73f"}, - {file = "clickhouse_connect-0.7.16-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a2a3ce33241441dc7c718c19e31645323e6c5da793d46bbb670fd4e8557b8605"}, - {file = "clickhouse_connect-0.7.16-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:29f9dc9cc1f4ec4a333bf119abb5cee13563e89bc990d4d77b8f43cf630e9fb1"}, - {file = "clickhouse_connect-0.7.16-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a16a7ada11996a6fa0959c83e2e46ff32773e57eca40eff86176fd62a30054ca"}, - {file = "clickhouse_connect-0.7.16-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:ead20e1d4f3c5493dd075b7dc81b5d21be4b876aca6952e1c155824876c621f3"}, + {file = "clickhouse-connect-0.7.19.tar.gz", hash = "sha256:ce8f21f035781c5ef6ff57dc162e8150779c009b59f14030ba61f8c9c10c06d0"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6ac74eb9e8d6331bae0303d0fc6bdc2125aa4c421ef646348b588760b38c29e9"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:300f3dea7dd48b2798533ed2486e4b0c3bb03c8d9df9aed3fac44161b92a30f9"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c72629f519105e21600680c791459d729889a290440bbdc61e43cd5eb61d928"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ece0fb202cd9267b3872210e8e0974e4c33c8f91ca9f1c4d92edea997189c72"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a6e5adf0359043d4d21c9a668cc1b6323a1159b3e1a77aea6f82ce528b5e4c5b"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:63432180179e90f6f3c18861216f902d1693979e3c26a7f9ef9912c92ce00d14"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:754b9c58b032835caaa9177b69059dc88307485d2cf6d0d545b3dedb13cb512a"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:24e2694e89d12bba405a14b84c36318620dc50f90adbc93182418742d8f6d73f"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-win32.whl", hash = "sha256:52929826b39b5b0f90f423b7a035930b8894b508768e620a5086248bcbad3707"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-win_amd64.whl", hash = "sha256:5c301284c87d132963388b6e8e4a690c0776d25acc8657366eccab485e53738f"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ee47af8926a7ec3a970e0ebf29a82cbbe3b1b7eae43336a81b3a0ca18091de5f"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ce429233b2d21a8a149c8cd836a2555393cbcf23d61233520db332942ffb8964"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:617c04f5c46eed3344a7861cd96fb05293e70d3b40d21541b1e459e7574efa96"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f08e33b8cc2dc1873edc5ee4088d4fc3c0dbb69b00e057547bcdc7e9680b43e5"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:921886b887f762e5cc3eef57ef784d419a3f66df85fd86fa2e7fbbf464c4c54a"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6ad0cf8552a9e985cfa6524b674ae7c8f5ba51df5bd3ecddbd86c82cdbef41a7"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:70f838ef0861cdf0e2e198171a1f3fd2ee05cf58e93495eeb9b17dfafb278186"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c5f0d207cb0dcc1adb28ced63f872d080924b7562b263a9d54d4693b670eb066"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-win32.whl", hash = "sha256:8c96c4c242b98fcf8005e678a26dbd4361748721b6fa158c1fe84ad15c7edbbe"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-win_amd64.whl", hash = "sha256:bda092bab224875ed7c7683707d63f8a2322df654c4716e6611893a18d83e908"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8f170d08166438d29f0dcfc8a91b672c783dc751945559e65eefff55096f9274"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:26b80cb8f66bde9149a9a2180e2cc4895c1b7d34f9dceba81630a9b9a9ae66b2"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9ba80e3598acf916c4d1b2515671f65d9efee612a783c17c56a5a646f4db59b9"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d38c30bd847af0ce7ff738152478f913854db356af4d5824096394d0eab873d"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d41d4b159071c0e4f607563932d4fa5c2a8fc27d3ba1200d0929b361e5191864"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3682c2426f5dbda574611210e3c7c951b9557293a49eb60a7438552435873889"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6d492064dca278eb61be3a2d70a5f082e2ebc8ceebd4f33752ae234116192020"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:62612da163b934c1ff35df6155a47cf17ac0e2d2f9f0f8f913641e5c02cdf39f"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-win32.whl", hash = "sha256:196e48c977affc045794ec7281b4d711e169def00535ecab5f9fdeb8c177f149"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-win_amd64.whl", hash = "sha256:b771ca6a473d65103dcae82810d3a62475c5372fc38d8f211513c72b954fb020"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:85a016eebff440b76b90a4725bb1804ddc59e42bba77d21c2a2ec4ac1df9e28d"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f059d3e39be1bafbf3cf0e12ed19b3cbf30b468a4840ab85166fd023ce8c3a17"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39ed54ba0998fd6899fcc967af2b452da28bd06de22e7ebf01f15acbfd547eac"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e4b4d786572cb695a087a71cfdc53999f76b7f420f2580c9cffa8cc51442058"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3710ca989ceae03d5ae56a436b4fe246094dbc17a2946ff318cb460f31b69450"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:d104f25a054cb663495a51ccb26ea11bcdc53e9b54c6d47a914ee6fba7523e62"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:ee23b80ee4c5b05861582dd4cd11f0ca0d215a899e9ba299a6ec6e9196943b1b"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:942ec21211d369068ab0ac082312d4df53c638bfc41545d02c41a9055e212df8"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-win32.whl", hash = "sha256:cb8f0a59d1521a6b30afece7c000f6da2cd9f22092e90981aa83342032e5df99"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-win_amd64.whl", hash = "sha256:98d5779dba942459d5dc6aa083e3a8a83e1cf6191eaa883832118ad7a7e69c87"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9f57aaa32d90f3bd18aa243342b3e75f062dc56a7f988012a22f65fb7946e81d"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5fb25143e4446d3a73fdc1b7d976a0805f763c37bf8f9b2d612a74f65d647830"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b4e19c9952b7b9fe24a99cca0b36a37e17e2a0e59b14457a2ce8868aa32e30e"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9876509aa25804f1377cb1b54dd55c1f5f37a9fbc42fa0c4ac8ac51b38db5926"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:04cfb1dae8fb93117211cfe4e04412b075e47580391f9eee9a77032d8e7d46f4"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b04f7c57f61b5dfdbf49d4b5e4fa5e91ce86bee09bb389b641268afa8f511ab4"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:e5b563f32dcc9cb6ff1f6ed238e83c3e80eb15814b1ea130817c004c241a3c2e"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6018675a231130bd03a7b39a3e875e683286d98115085bfa3ac0918f555f4bfe"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-win32.whl", hash = "sha256:5cb67ae3309396033b825626d60fe2cd789c1d2a183faabef8ffdbbef153d7fb"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-win_amd64.whl", hash = "sha256:fd225af60478c068cde0952e8df8f731f24c828b75cc1a2e61c21057ff546ecd"}, + {file = "clickhouse_connect-0.7.19-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:6f31898e0281f820e35710b5c4ad1d40a6c01ffae5278afaef4a16877ac8cbfb"}, + {file = "clickhouse_connect-0.7.19-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51c911b0b8281ab4a909320f41dd9c0662796bec157c8f2704de702c552104db"}, + {file = "clickhouse_connect-0.7.19-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1088da11789c519f9bb8927a14b16892e3c65e2893abe2680eae68bf6c63835"}, + {file = "clickhouse_connect-0.7.19-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:03953942cc073078b40619a735ebeaed9bf98efc71c6f43ce92a38540b1308ce"}, + {file = "clickhouse_connect-0.7.19-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:4ac0602fa305d097a0cd40cebbe10a808f6478c9f303d57a48a3a0ad09659544"}, + {file = "clickhouse_connect-0.7.19-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:4fdefe9eb2d38063835f8f1f326d666c3f61de9d6c3a1607202012c386ca7631"}, + {file = "clickhouse_connect-0.7.19-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ff6469822fe8d83f272ffbb3fb99dcc614e20b1d5cddd559505029052eff36e7"}, + {file = "clickhouse_connect-0.7.19-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46298e23f7e7829f0aa880a99837a82390c1371a643b21f8feb77702707b9eaa"}, + {file = "clickhouse_connect-0.7.19-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c6409390b13e09c19435ff65e2ebfcf01f9b2382e4b946191979a5d54ef8625c"}, + {file = "clickhouse_connect-0.7.19-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:cd7e7097b30b70eb695b7b3b6c79ba943548c053cc465fa74efa67a2354f6acd"}, + {file = "clickhouse_connect-0.7.19-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:15e080aead66e43c1f214b3e76ab26e3f342a4a4f50e3bbc3118bdd013d12e5f"}, + {file = "clickhouse_connect-0.7.19-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:194d2a32ba1b370cb5ac375dd4153871bb0394ff040344d8f449cb36ea951a96"}, + {file = "clickhouse_connect-0.7.19-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ac93aafd6a542fdcad4a2b6778575eab6dbdbf8806e86d92e1c1aa00d91cfee"}, + {file = "clickhouse_connect-0.7.19-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b208dd3e29db7154b02652c26157a1903bea03d27867ca5b749edc2285c62161"}, + {file = "clickhouse_connect-0.7.19-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:9724fdf3563b2335791443cb9e2114be7f77c20c8c4bbfb3571a3020606f0773"}, ] [package.dependencies] @@ -1433,115 +1588,115 @@ tzlocal = ["tzlocal (>=4.0)"] [[package]] name = "clickhouse-driver" -version = "0.2.8" +version = "0.2.9" description = "Python driver with native interface for ClickHouse" optional = false python-versions = "<4,>=3.7" files = [ - {file = "clickhouse-driver-0.2.8.tar.gz", hash = "sha256:844b3080e558acbacd42ee569ec83ca7aaa3728f7077b9314c8d09aaa393d752"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3a3a708e020ed2df59e424631f1822ffef4353912fcee143f3b7fc34e866621d"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d258d3c3ac0f03527e295eeaf3cebb0a976bc643f6817ccd1d0d71ce970641b4"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f63fb64a55dea29ed6a7d1d6805ebc95c37108c8a36677bc045d904ad600828"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1b16d5dbd53fe32a99d3c4ab6c478c8aa9ae02aec5a2bd2f24180b0b4c03e1a5"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ad2e1850ce91301ae203bc555fb83272dfebb09ad4df99db38c608d45fc22fa4"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ae9239f61a18050164185ec0a3e92469d084377a66ae033cc6b4efa15922867"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8f222f2577bf304e86eec73dbca9c19d7daa6abcafc0bef68bbf31dd461890b"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:451ac3de1191531d030751b05f122219b93b3c509e781fad81c2c91f0e9256b6"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:5a2c4fea88e91f1d5217b760ffea84631e647d8db2265b821cbe7b0e015c7807"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:19825a3044c48ab65dc6659eb9763e2f0821887bdd9ee14a2f9ae8c539281ebf"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ae13044a10015225297868658a6f1843c2e34b9fcaa6268880e25c4fca9f3c4d"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:548a77efb86012800e76db6d45b3dcffea9a1a26fa3d5fd42021298f0b9a6f16"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-win32.whl", hash = "sha256:ebe4328eaaf937365114b5bab5626600ee57e57d4d099ba2ddbae48c2493f73d"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-win_amd64.whl", hash = "sha256:7beaeb4d7e6c3aba7e02375eeca85b20cc8e54dc31fcdb25d3c4308f2cd9465f"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8e06ef6bb701c8e42a9c686d77ad30805cf431bb79fa8fe0f4d3dee819e9a12c"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4afbcfa557419ed1783ecde3abbee1134e09b26c3ab0ada5b2118ae587357c2b"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85f628b4bf6db0fe8fe13da8576a9b95c23b463dff59f4c7aa58cedf529d7d97"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:036f4b3283796ca51610385c7b24bdac1bb873f8a2e97a179f66544594aa9840"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2c8916d3d324ce8fd31f8dedd293dc2c29204b94785a5398d1ec1e7ea4e16a26"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:30bee7cddd85c04ec49c753b53580364d907cc05c44daafe31b924a352e5e525"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:03c8a844f6b128348d099dc5d75fad70f4e85802d1649c1b835916ac94ae750a"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:33965329393fd7740b445758787ddacdf70f35fa3411f98a1a86918fff679a46"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:8cf85a7ebb0a56182c5b659602e20bae6b36c48a0edf518a6e6f56042d3fcee0"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:c10fd1f921ff82638cb9513b9b4acfb575b421c44ef6bf6cf57ee3c487b9d538"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:0a30d49bb6c34e3f5fe42e43dd6a7da0523ddfd05834ef02bd70b9363ea7de7e"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ea32c377a347b0801fc7f2b242f2ec7d78df58047097352672d0de5fbfa9e390"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-win32.whl", hash = "sha256:2a85529d1c0c3f2eedf7a4f736d0efc6e6c8032ac90ca5a63f7a067db58384fe"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-win_amd64.whl", hash = "sha256:1f438f83a7473ce7fe9c16cda8750e2fdda1b09fb87f0ec6b87a2b89acb13f24"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9b71bbef6ee08252cee0593329c8ca8e623547627807d38195331f476eaf8136"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f30b3dd388f28eb4052851effe671354db55aea87de748aaf607e7048f72413e"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e3bb27ce7ca61089c04dc04dbf207c9165d62a85eb9c99d1451fd686b6b773f9"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:59c04ec0b45602b6a63e0779ca7c3d3614be4710ec5ac7214da1b157d43527c5"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a61b14244993c7e0f312983455b7851576a85ab5a9fcc6374e75d2680a985e76"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c99a1b0b7759ccd1bf44c65210543c228ba704e3153014fd3aabfe56a227b1a5"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9f14d860088ab2c7eeb3782c9490ad3f6bf6b1e9235e9db9c3b0079cd4751ffa"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:303887a14a71faddcdee150bc8cde498c25c446b0a72ae586bd67d0c366dbff5"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:359814e4f989c138bfb83e3c81f8f88c8449721dcf32cb8cc25fdb86f4b53c99"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:42de61b4cf9053698b14dbe29e1e3d78cb0a7aaef874fd854df390de5c9cc1f1"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:3bf3089f220480e5a69cbec79f3b65c23afb5c2836e7285234140e5f237f2768"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:41daa4ae5ada22f10c758b0b3b477a51f5df56eef8569cff8e2275de6d9b1b96"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-win32.whl", hash = "sha256:03ea71c7167c6c38c3ba2bbed43615ce0c41ebf3bfa28d96ffcd93cd1cdd07d8"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-win_amd64.whl", hash = "sha256:76985286e10adb2115da116ae25647319bc485ad9e327cbc27296ccf0b052180"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:271529124914c439a5bbcf8a90e3101311d60c1813e03c0467e01fbabef489ee"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f8f499746bc027c6d05de09efa7b2e4f2241f66c1ac2d6b7748f90709b00e10"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f29f256520bb718c532e7fcd85250d4001f49acbaa9e6896bdf4a70d5557e2ef"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:104d062bdf7eab74e92efcbf72088b3241365242b4f119b3fe91057c4d80825c"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ee34ed08592a6eff5e176f42897c6ab4dfd8c07df16e9f392e18f1f2ee3fe3ca"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f5be9a8d89de881d5ea9d46f9d293caa72dbc7f40b105374cafd88f52b2099ea"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c57efc768fa87e83d6778e7bbd180dd1ff5d647044983ec7d238a8577bd25fa5"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:e1a003475f2d54e9fea8de86b57bc26b409c9efea3d298409ab831f194d62c3b"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:fba71cf41934a23156290a70ef794a5dadc642b21cc25eb13e1f99f2512c8594"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:7289b0e9d1019fed418c577963edd66770222554d1da0c491ca436593667256e"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:16e810cc9be18fdada545b9a521054214dd607bb7aa2f280ca488da23a077e48"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-win32.whl", hash = "sha256:ed4a6590015f18f414250149255dc2ae81ae956b6e670b290d52c2ecb61ed517"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-win_amd64.whl", hash = "sha256:9d454f16ccf1b2185cc630f6fb2160b1abde27759c4e94c42e30b9ea911d58f0"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2e487d49c24448873a6802c34aa21858b9e3fb4a2605268a980a5c02b54a6bae"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e877de75b97ddb11a027a7499171ea0aa9cad569b18fce53c9d508353000cfae"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c60dcefddf6e2c65c92b7e6096c222ff6ed73b01b6c5712f9ce8a23f2ec80f1a"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:422cbbabfad3f9b533d9f517f6f4e174111a613cba878402f7ef632b0eadec3a"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ff8a8e25ff6051ff3d0528dbe36305b0140075d2fa49432149ee2a7841f23ed"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19c7a5960d4f7f9a8f9a560ae05020ff5afe874b565cce06510586a0096bb626"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f5b3333257b46f307b713ba507e4bf11b7531ba3765a4150924532298d645ffd"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:bbc2252a697c674e1b8b6123cf205d2b15979eddf74e7ada0e62a0ecc81a75c3"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:af7f1a9a99dafb0f2a91d1a2d4a3e37f86076147d59abbe69b28d39308fe20fb"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:580c34cc505c492a8abeacbd863ce46158643bece914d8fe2fadea0e94c4e0c1"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:5b905eaa6fd3b453299f946a2c8f4a6392f379597e51e46297c6a37699226cda"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6e2b5891c52841aedf803b8054085eb8a611ad4bf57916787a1a9aabf618fb77"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-win32.whl", hash = "sha256:b58a5612db8b3577dc2ae6fda4c783d61c2376396bb364545530aa6a767f166d"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-win_amd64.whl", hash = "sha256:96b0424bb5dd698c10b899091562a78f4933a9a039409f310fb74db405d73854"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:22cbed52daa584ca9a93efd772ee5c8c1f68ceaaeb21673985004ec2fd411c49"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e36156fe8a355fc830cc0ea1267c804c631c9dbd9b6accdca868a426213e5929"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c1341325f4180e1318d0d2cf0b268008ea250715c6f30a5ccce586860c000b5"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cb52161276f7d77d4af09f1aab97a16edf86014a89e3d9923f0a6b8fdaa12438"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0d1ccd47040c0a8753684a20a0f83b8a0820386889fdf460a3248e0eed142032"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fcda48e938d011e5f4dcebf965e6ec19e020e8efa207b98eeb99c12fa873236d"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2252ab3f8b3bbd705e1d7dc80395c7bea14f5ae51a268fc7be5328da77c0e200"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:e1b9ef3fa0cc6c9de77daa74a2f183186d0b5556c4f6870fc966a41fde6cae2b"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d0afa3c68fed6b5e6f23eb3f053d3aba86d09dbbc7706a0120ab5595d5c37003"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:102027bb687ff7a978f7110348f39f0dce450ab334787edbc64b8a9927238e32"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:9fc1ae52a171ded7d9f1f971b9b5bb0ce4d0490a54e102f3717cea51011d0308"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5a62c691be83b1da72ff3455790b50b0f894b7932ac962a8133f3f9c04c943b3"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-win32.whl", hash = "sha256:8b5068cef07cfba5be25a9a461c010ce7a0fe2de5b0b0262c6030684f43fa7f5"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-win_amd64.whl", hash = "sha256:cd71965d00b0f3ba992652d577b1d46b87100a67b3e0dc5c191c88092e484c81"}, - {file = "clickhouse_driver-0.2.8-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:4db0812c43f67e7b1805c05e2bc08f7d670ddfd8d8c671c9b47cdb52f4f74129"}, - {file = "clickhouse_driver-0.2.8-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56622ffefe94a82d9a30747e3486819104d1310d7a94f0e37da461d7112e9864"}, - {file = "clickhouse_driver-0.2.8-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c47c8ed61b2f35bb29d991f66d6e03d5cc786def56533480331b2a584854dd5"}, - {file = "clickhouse_driver-0.2.8-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dec001a1a49b993522dd134d2fca161352f139d42edcda0e983b8ea8f5023cda"}, - {file = "clickhouse_driver-0.2.8-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:c03bd486540a6c03aa5a164b7ec6c50980df9642ab1ce22cb70327e4090bdc60"}, - {file = "clickhouse_driver-0.2.8-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:c059c3da454f0cc0a6f056b542a0c1784cd0398613d25326b11fd1c6f9f7e8d2"}, - {file = "clickhouse_driver-0.2.8-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc7f9677c637b710046ec6c6c0cab25b4c4ff21620e44f462041d7455e9e8d13"}, - {file = "clickhouse_driver-0.2.8-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba3f6b8fdd7a2e6a831ebbcaaf346f7c8c5eb5085a350c9d4d1ce7053a050b70"}, - {file = "clickhouse_driver-0.2.8-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:20c2db3ae29950c80837d270b5ab63c74597afce226b474930060cac7969287b"}, - {file = "clickhouse_driver-0.2.8-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:b7767019a301dad314e7b515046535a45eda84bd9c29590bc3e99b1c334f69e7"}, - {file = "clickhouse_driver-0.2.8-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:ba8b8b80fa8850546aa40acc952835b1f149af17182cdf3db4f2133b2a241fe8"}, - {file = "clickhouse_driver-0.2.8-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:924f11e87e3dcbbc1c9e8158af9917f182cd5e96d37385485d6268f59b564142"}, - {file = "clickhouse_driver-0.2.8-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c39e1477ad310a4d276db17c1e1cf6fb059c29eb8d21351afefd5a22de381c6"}, - {file = "clickhouse_driver-0.2.8-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e950b9a63af5fa233e3da0e57a7ebd85d4b319e65eef5f9daac84532836f4123"}, - {file = "clickhouse_driver-0.2.8-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:0698dc57373b2f42f3a95bd419d9fa07f2d02150f13a0db2909a2651208262b9"}, - {file = "clickhouse_driver-0.2.8-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:e0694ca2fb459c23e44036d975fe89544a7c9918618b5d8bda9a8aa2d24e5c37"}, - {file = "clickhouse_driver-0.2.8-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:62620348aeae5a905ccb8f7e6bff8d76aae9a95d81aa8c8f6fce0f2af7e104b8"}, - {file = "clickhouse_driver-0.2.8-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66276fd5092cccdd6f3123df4357a068fb1972b7e2622fab6f235948c50b6eed"}, - {file = "clickhouse_driver-0.2.8-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f86fe87327662b597824d0d7505cc600b0919473b22bbbd178a1a4d4e29283e1"}, - {file = "clickhouse_driver-0.2.8-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:54b9c6ff0aaabdcf7e80a6d9432459611b3413d6a66bec41cbcdad7212721cc7"}, + {file = "clickhouse-driver-0.2.9.tar.gz", hash = "sha256:050ea4870ead993910b39e7fae965dc1c347b2e8191dcd977cd4b385f9e19f87"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6ce04e9d0d0f39561f312d1ac1a8147bc9206e4267e1a23e20e0423ebac95534"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7ae5c8931bf290b9d85582e7955b9aad7f19ff9954e48caa4f9a180ea4d01078"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e51792f3bd12c32cb15a907f12de3c9d264843f0bb33dce400e3966c9f09a3f"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:42fc546c31e4a04c97b749769335a679c9044dc693fa7a93e38c97fd6727173d"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6a383a403d185185c64e49edd6a19b2ec973c5adcb8ebff7ed2fc539a2cc65a5"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f05321a97e816afc75b3e4f9eda989848fecf14ecf1a91d0f22c04258123d1f7"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:be47e793846aac28442b6b1c6554e0731b848a5a7759a54aa2489997354efe4a"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:780e42a215d1ae2f6d695d74dd6f087781fb2fa51c508b58f79e68c24c5364e0"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:9e28f1fe850675e173db586e9f1ac790e8f7edd507a4227cd54cd7445f8e75b6"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:125aae7f1308d3083dadbb3c78f828ae492e060f13e4007a0cf53a8169ed7b39"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:2f3c4fbb61e75c62a1ab93a1070d362de4cb5682f82833b2c12deccb3bae888d"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0dc03196a84e32d23b88b665be69afae98f57426f5fdf203e16715b756757961"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-win32.whl", hash = "sha256:25695d78a1d7ad6e221e800612eac08559f6182bf6dee0a220d08de7b612d993"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-win_amd64.whl", hash = "sha256:367acac95398d721a0a2a6cf87e93638c5588b79498a9848676ce7f182540a6c"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5a7353a7a08eee3aa0001d8a5d771cb1f37e2acae1b48178002431f23892121a"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6af1c6cbc3481205503ab72a34aa76d6519249c904aa3f7a84b31e7b435555be"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48033803abd1100bfff6b9a1769d831b672cd3cda5147e0323b956fd1416d38d"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1f202a58a540c85e47c31dabc8f84b6fe79dca5315c866450a538d58d6fa0571"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4df50fd84bfa4aa1eb7b52d48136066bfb64fabb7ceb62d4c318b45a296200b"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:433a650571a0d7766eb6f402e8f5930222997686c2ee01ded22f1d8fd46af9d4"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:232ee260475611cbf7adb554b81db6b5790b36e634fe2164f4ffcd2ca3e63a71"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:09049f7e71f15c9c9a03f597f77fc1f7b61ababd155c06c0d9e64d1453d945d7"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:424153d1d5f5a807f596a48cc88119f9fb3213ca7e38f57b8d15dcc964dd91f7"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:4f078fd1cf19c4ca63b8d1e0803df665310c8d5b644c5b02bf2465e8d6ef8f55"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:f138d939e26e767537f891170b69a55a88038919f5c10d8865b67b8777fe4848"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9aafabc7e32942f85dcb46f007f447ab69024831575df97cae28c6ed127654d1"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-win32.whl", hash = "sha256:935e16ebf1a1998d8493979d858821a755503c9b8af572d9c450173d4b88868c"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-win_amd64.whl", hash = "sha256:306b3102cba278b5dfec6f5f7dc8b78416c403901510475c74913345b56c9e42"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:fcb2fd00e58650ae206a6d5dbc83117240e622471aa5124733fbf2805eb8bda0"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b7a3e6b0a1eb218e3d870a94c76daaf65da46dca8f6888ea6542f94905c24d88"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a8d8e2888a857d8db3d98765a5ad23ab561241feaef68bbffc5a0bd9c142342"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:85d50c011467f5ff6772c4059345968b854b72e07a0219030b7c3f68419eb7f7"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:93b395c1370629ccce8fb3e14cd5be2646d227bd32018c21f753c543e9a7e96b"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6dbcee870c60d9835e5dce1456ab6b9d807e6669246357f4b321ef747b90fa43"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fffa5a5f317b1ec92e406a30a008929054cf3164d2324a3c465d0a0330273bf8"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:476702740a279744badbd177ae1c4a2d089ec128bd676861219d1f92078e4530"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:5cd6d95fab5ff80e9dc9baedc9a926f62f74072d42d5804388d63b63bec0bb63"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:05027d32d7cf3e46cb8d04f8c984745ae01bd1bc7b3579f9dadf9b3cca735697"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:3d11831842250b4c1b26503a6e9c511fc03db096608b7c6af743818c421a3032"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:81b4b671b785ebb0b8aeabf2432e47072413d81db959eb8cfd8b6ab58c5799c6"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-win32.whl", hash = "sha256:e893bd4e014877174a59e032b0e99809c95ec61328a0e6bd9352c74a2f6111a8"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-win_amd64.whl", hash = "sha256:de6624e28eeffd01668803d28ae89e3d4e359b1bff8b60e4933e1cb3c6f86f18"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:909205324089a9ee59bee7ecbfa94595435118cca310fd62efdf13f225aa2965"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03f31d6e47dc2b0f367f598f5629147ed056d7216c1788e25190fcfbfa02e749"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ed84179914b2b7bb434c2322a6e7fd83daa681c97a050450511b66d917a129bb"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:67d1bf63efb4ba14ae6c6da99622e4a549e68fc3ee14d859bf611d8e6a61b3fa"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eed23ea41dd582d76f7a2ec7e09cbe5e9fec008f11a4799fa35ce44a3ebd283"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a654291132766efa2703058317749d7c69b69f02d89bac75703eaf7f775e20da"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:1c26c5ef16d0ef3cabc5bc03e827e01b0a4afb5b4eaf8850b7cf740cee04a1d4"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b57e83d7986d3cbda6096974a9510eb53cb33ad9072288c87c820ba5eee3370e"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:153cc03b36f22cbde55aa6a5bbe99072a025567a54c48b262eb0da15d8cd7c83"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:83a857d99192936091f495826ae97497cd1873af213b1e069d56369fb182ab8e"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:bb05a9bb22cbe9ad187ad268f86adf7e60df6083331fe59c01571b7b725212dd"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-win32.whl", hash = "sha256:3e282c5c25e32d96ed151e5460d2bf4ecb805ea64449197dd918e84e768016df"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-win_amd64.whl", hash = "sha256:c46dccfb04a9afd61a1b0e60bfefceff917f76da2c863f9b36b39248496d5c77"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:612ca9028c718f362c97f552e63d313cf1a70a616ef8532ddb0effdaf12ebef9"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:471b884d318e012f68d858476052742048918854f7dfe87d78e819f87a848ffb"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58ee63c35e99da887eb035c8d6d9e64fd298a0efc1460395297dd5cc281a6912"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0819bb63d2c5025a1fb9589f57ef82602687cef11081d6dfa6f2ce44606a1772"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f6680ee18870bca1fbab1736c8203a965efaec119ab4c37821ad99add248ee08"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:713c498741b54debd3a10a5529e70b6ed85ca33c3e8629e24ae5cd8160b5a5f2"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:730837b8f63941065c9c955c44286aef0987fb084ffb3f55bf1e4fe07df62269"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:9f4e38b2ea09214c8e7848a19391009a18c56a3640e1ba1a606b9e57aeb63404"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:457f1d6639e0345b717ae603c79bd087a35361ce68c1c308d154b80b841e5e7d"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:49a55aeb8ea625a87965a96e361bbb1ad67d0931bfb2a575f899c1064e70c2da"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:9230058d8c9b1a04079afae4650fb67745f0f1c39db335728f64d48bd2c19246"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8798258bd556542dd9c6b8ebe62f9c5110c9dcdf97c57fb077e7b8b6d6da0826"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-win32.whl", hash = "sha256:ce8e3f4be46bcc63555863f70ab0035202b082b37e6f16876ef50e7bc4b47056"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-win_amd64.whl", hash = "sha256:2d982959ff628255808d895a67493f2dab0c3a9bfc65eeda0f00c8ae9962a1b3"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a46b227fab4420566ed24ee70d90076226d16fcf09c6ad4d428717efcf536446"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7eaa2ce5ea08cf5fddebb8c274c450e102f329f9e6966b6cd85aa671c48e5552"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f97f0083194d6e23b5ef6156ed0d5388c37847b298118199d7937ba26412a9e2"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a6cab5cdbb0f8ee51d879d977b78f07068b585225ac656f3c081896c362e8f83"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cdb1b011a53ee71539e9dc655f268b111bac484db300da92829ed59e910a8fd0"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7bf51bb761b281d20910b4b689c699ef98027845467daa5bb5dfdb53bd6ee404"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8ea462e3cebb121ff55002e9c8a9a0a3fd9b5bbbf688b4960f0a83c0172fb31"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:70bee21c245226ad0d637bf470472e2d487b86911b6d673a862127b934336ff4"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:253a3c223b944d691bf0abbd599f592ea3b36f0a71d2526833b1718f37eca5c2"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:a6549b53fc5c403dc556cb39b2ae94d73f9b113daa00438a660bb1dd5380ae4d"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:1c685cd4abe61af1c26279ff04b9f567eb4d6c1ec7fb265af7481b1f153043aa"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:7e25144219577491929d032a6c3ddd63c6cd7fa764af829a5637f798190d9b26"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-win32.whl", hash = "sha256:0b9925610d25405a8e6d83ff4f54fc2456a121adb0155999972f5edd6ba3efc8"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-win_amd64.whl", hash = "sha256:b243de483cfa02716053b0148d73558f4694f3c27b97fc1eaa97d7079563a14d"}, + {file = "clickhouse_driver-0.2.9-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:45a3d5b1d06750fd6a18c29b871494a2635670099ec7693e756a5885a4a70dbf"}, + {file = "clickhouse_driver-0.2.9-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8415ffebd6ca9eef3024763abc450f8659f1716d015bd563c537d01c7fbc3569"}, + {file = "clickhouse_driver-0.2.9-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ace48db993aa4bd31c42de0fa8d38c94ad47405916d6b61f7a7168a48fb52ac1"}, + {file = "clickhouse_driver-0.2.9-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b07123334fe143bfe6fa4e3d4b732d647d5fd2cfb9ec7f2f76104b46fe9d20c6"}, + {file = "clickhouse_driver-0.2.9-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:e2af3efa73d296420ce6362789f5b1febf75d4aa159a479393f01549115509d5"}, + {file = "clickhouse_driver-0.2.9-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:baf57eede88d07a1eb04352d26fc58a4d97991ca3d8840f7c5d48691dec9f251"}, + {file = "clickhouse_driver-0.2.9-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:275d0ccdab9c3571bdb3e9acfab4497930aa584ff2766b035bb2f854deaf8b82"}, + {file = "clickhouse_driver-0.2.9-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:293da77bfcac3168fb35b27c242f97c1a05502435c0686ecbb8e2e4abcb3de26"}, + {file = "clickhouse_driver-0.2.9-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8d6c2e5830705e4eeef33070ca4d5a24dfa221f28f2f540e5e6842c26e70b10b"}, + {file = "clickhouse_driver-0.2.9-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:11934bd78d97dd7e1a23a6222b5edd1e1b4d34e1ead5c846dc2b5c56fdc35ff5"}, + {file = "clickhouse_driver-0.2.9-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:b802b6f0fbdcc3ab81b87f09b694dde91ab049f44d1d2c08c3dc8ea9a5950cfa"}, + {file = "clickhouse_driver-0.2.9-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7af871c5315eb829ecf4533c790461ea8f73b3bfd5f533b0467e479fdf6ddcfd"}, + {file = "clickhouse_driver-0.2.9-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d577dd4867b9e26cf60590e1f500990c8701a6e3cfbb9e644f4d0c0fb607028"}, + {file = "clickhouse_driver-0.2.9-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2ed3dea2d1eca85fef5b8564ddd76dedb15a610c77d55d555b49d9f7c896b64b"}, + {file = "clickhouse_driver-0.2.9-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:91ec96f2c48e5bdeac9eea43a9bc9cc19acb2d2c59df0a13d5520dfc32457605"}, + {file = "clickhouse_driver-0.2.9-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:7667ab423452754f36ba8fb41e006a46baace9c94e2aca2a745689b9f2753dfb"}, + {file = "clickhouse_driver-0.2.9-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:653583b1f3b088d106f180d6f02c90917ecd669ec956b62903a05df4a7f44863"}, + {file = "clickhouse_driver-0.2.9-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ef3dd0cbdf2f0171caab90389af0ede068ec802bf46c6a77f14e6edc86671bc"}, + {file = "clickhouse_driver-0.2.9-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:11b1833ee8ff8d5df39a34a895e060b57bd81e05ea68822bc60476daff4ce1c8"}, + {file = "clickhouse_driver-0.2.9-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:8a3195639e6393b9d4aafe736036881ff86b6be5855d4bf7d9f5c31637181ec3"}, ] [package.dependencies] @@ -1916,13 +2071,13 @@ typing-inspect = ">=0.4.0,<1" [[package]] name = "db-dtypes" -version = "1.2.0" +version = "1.3.0" description = "Pandas Data Types for SQL systems (BigQuery, Spanner)" optional = false python-versions = ">=3.7" files = [ - {file = "db-dtypes-1.2.0.tar.gz", hash = "sha256:3531bb1fb8b5fbab33121fe243ccc2ade16ab2524f4c113b05cc702a1908e6ea"}, - {file = "db_dtypes-1.2.0-py2.py3-none-any.whl", hash = "sha256:6320bddd31d096447ef749224d64aab00972ed20e4392d86f7d8b81ad79f7ff0"}, + {file = "db_dtypes-1.3.0-py2.py3-none-any.whl", hash = "sha256:7e65c59f849ccbe6f7bc4d0253edcc212a7907662906921caba3e4aadd0bc277"}, + {file = "db_dtypes-1.3.0.tar.gz", hash = "sha256:7bcbc8858b07474dc85b77bb2f3ae488978d1336f5ea73b58c39d9118bc3e91b"}, ] [package.dependencies] @@ -1970,26 +2125,6 @@ files = [ {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, ] -[[package]] -name = "dnspython" -version = "2.6.1" -description = "DNS toolkit" -optional = false -python-versions = ">=3.8" -files = [ - {file = "dnspython-2.6.1-py3-none-any.whl", hash = "sha256:5ef3b9680161f6fa89daf8ad451b5f1a33b18ae8a1c6778cdf4b43f08c0a6e50"}, - {file = "dnspython-2.6.1.tar.gz", hash = "sha256:e8f0f9c23a7b7cb99ded64e6c3a6f3e701d78f50c55e002b839dea7225cff7cc"}, -] - -[package.extras] -dev = ["black (>=23.1.0)", "coverage (>=7.0)", "flake8 (>=7)", "mypy (>=1.8)", "pylint (>=3)", "pytest (>=7.4)", "pytest-cov (>=4.1.0)", "sphinx (>=7.2.0)", "twine (>=4.0.0)", "wheel (>=0.42.0)"] -dnssec = ["cryptography (>=41)"] -doh = ["h2 (>=4.1.0)", "httpcore (>=1.0.0)", "httpx (>=0.26.0)"] -doq = ["aioquic (>=0.9.25)"] -idna = ["idna (>=3.6)"] -trio = ["trio (>=0.23)"] -wmi = ["wmi (>=1.5.1)"] - [[package]] name = "docstring-parser" version = "0.16" @@ -2076,37 +2211,60 @@ files = [ [[package]] name = "duckduckgo-search" -version = "6.2.1" +version = "6.2.10" description = "Search for words, documents, images, news, maps and text translation using the DuckDuckGo.com search engine." optional = false python-versions = ">=3.8" files = [ - {file = "duckduckgo_search-6.2.1-py3-none-any.whl", hash = "sha256:1a03f799b85fdfa08d5e6478624683f373b9dc35e6f145544b9cab72a4f575fa"}, - {file = "duckduckgo_search-6.2.1.tar.gz", hash = "sha256:d664ec096193e3fb43bdfae4b0ad9c04e44094b58f41998adcdd20a86ee1ed74"}, + {file = "duckduckgo_search-6.2.10-py3-none-any.whl", hash = "sha256:266c1528dcbc90931b7c800a2c1041a0cb447c83c485414d77a7e443be717ed6"}, + {file = "duckduckgo_search-6.2.10.tar.gz", hash = "sha256:53057368480ca496fc4e331a34648124711580cf43fbb65336eaa6fd2ee37cec"}, ] [package.dependencies] click = ">=8.1.7" -pyreqwest-impersonate = ">=0.5.0" +primp = ">=0.6.1" [package.extras] -dev = ["mypy (>=1.10.1)", "pytest (>=8.2.2)", "pytest-asyncio (>=0.23.7)", "ruff (>=0.5.2)"] +dev = ["mypy (>=1.11.1)", "pytest (>=8.3.1)", "pytest-asyncio (>=0.23.8)", "ruff (>=0.6.1)"] lxml = ["lxml (>=5.2.2)"] [[package]] -name = "email-validator" -version = "2.2.0" -description = "A robust email address syntax and deliverability validation library." +name = "elastic-transport" +version = "8.15.0" +description = "Transport classes and utilities shared among Python Elastic client libraries" optional = false python-versions = ">=3.8" files = [ - {file = "email_validator-2.2.0-py3-none-any.whl", hash = "sha256:561977c2d73ce3611850a06fa56b414621e0c8faa9d66f2611407d87465da631"}, - {file = "email_validator-2.2.0.tar.gz", hash = "sha256:cb690f344c617a714f22e66ae771445a1ceb46821152df8e165c5f9a364582b7"}, + {file = "elastic_transport-8.15.0-py3-none-any.whl", hash = "sha256:d7080d1dada2b4eee69e7574f9c17a76b42f2895eff428e562f94b0360e158c0"}, + {file = "elastic_transport-8.15.0.tar.gz", hash = "sha256:85d62558f9baafb0868c801233a59b235e61d7b4804c28c2fadaa866b6766233"}, +] + +[package.dependencies] +certifi = "*" +urllib3 = ">=1.26.2,<3" + +[package.extras] +develop = ["aiohttp", "furo", "httpx", "opentelemetry-api", "opentelemetry-sdk", "orjson", "pytest", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "pytest-mock", "requests", "respx", "sphinx (>2)", "sphinx-autodoc-typehints", "trustme"] + +[[package]] +name = "elasticsearch" +version = "8.14.0" +description = "Python client for Elasticsearch" +optional = false +python-versions = ">=3.7" +files = [ + {file = "elasticsearch-8.14.0-py3-none-any.whl", hash = "sha256:cef8ef70a81af027f3da74a4f7d9296b390c636903088439087b8262a468c130"}, + {file = "elasticsearch-8.14.0.tar.gz", hash = "sha256:aa2490029dd96f4015b333c1827aa21fd6c0a4d223b00dfb0fe933b8d09a511b"}, ] [package.dependencies] -dnspython = ">=2.0.0" -idna = ">=2.0.0" +elastic-transport = ">=8.13,<9" + +[package.extras] +async = ["aiohttp (>=3,<4)"] +orjson = ["orjson (>=3)"] +requests = ["requests (>=2.4.0,!=2.32.2,<3.0.0)"] +vectorstore-mmr = ["numpy (>=1)", "simsimd (>=3)"] [[package]] name = "emoji" @@ -2173,45 +2331,23 @@ test = ["pytest (>=6)"] [[package]] name = "fastapi" -version = "0.111.1" +version = "0.112.1" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" optional = false python-versions = ">=3.8" files = [ - {file = "fastapi-0.111.1-py3-none-any.whl", hash = "sha256:4f51cfa25d72f9fbc3280832e84b32494cf186f50158d364a8765aabf22587bf"}, - {file = "fastapi-0.111.1.tar.gz", hash = "sha256:ddd1ac34cb1f76c2e2d7f8545a4bcb5463bce4834e81abf0b189e0c359ab2413"}, + {file = "fastapi-0.112.1-py3-none-any.whl", hash = "sha256:bcbd45817fc2a1cd5da09af66815b84ec0d3d634eb173d1ab468ae3103e183e4"}, + {file = "fastapi-0.112.1.tar.gz", hash = "sha256:b2537146f8c23389a7faa8b03d0bd38d4986e6983874557d95eed2acc46448ef"}, ] [package.dependencies] -email_validator = ">=2.0.0" -fastapi-cli = ">=0.0.2" -httpx = ">=0.23.0" -jinja2 = ">=2.11.2" pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0" -python-multipart = ">=0.0.7" -starlette = ">=0.37.2,<0.38.0" +starlette = ">=0.37.2,<0.39.0" typing-extensions = ">=4.8.0" -uvicorn = {version = ">=0.12.0", extras = ["standard"]} [package.extras] -all = ["email_validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] - -[[package]] -name = "fastapi-cli" -version = "0.0.4" -description = "Run and manage FastAPI apps from the command line with FastAPI CLI. 🚀" -optional = false -python-versions = ">=3.8" -files = [ - {file = "fastapi_cli-0.0.4-py3-none-any.whl", hash = "sha256:a2552f3a7ae64058cdbb530be6fa6dbfc975dc165e4fa66d224c3d396e25e809"}, - {file = "fastapi_cli-0.0.4.tar.gz", hash = "sha256:e2e9ffaffc1f7767f488d6da34b6f5a377751c996f397902eb6abb99a67bde32"}, -] - -[package.dependencies] -typer = ">=0.12.3" - -[package.extras] -standard = ["fastapi", "uvicorn[standard] (>=0.15.0)"] +all = ["email_validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +standard = ["email_validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "jinja2 (>=2.11.2)", "python-multipart (>=0.0.7)", "uvicorn[standard] (>=0.12.0)"] [[package]] name = "fastavro" @@ -3023,13 +3159,13 @@ grpc = ["grpcio (>=1.38.0,<2.0dev)", "grpcio-status (>=1.38.0,<2.0.dev0)"] [[package]] name = "google-cloud-resource-manager" -version = "1.12.4" +version = "1.12.5" description = "Google Cloud Resource Manager API client library" optional = false python-versions = ">=3.7" files = [ - {file = "google-cloud-resource-manager-1.12.4.tar.gz", hash = "sha256:3eda914a925e92465ef80faaab7e0f7a9312d486dd4e123d2c76e04bac688ff0"}, - {file = "google_cloud_resource_manager-1.12.4-py2.py3-none-any.whl", hash = "sha256:0b6663585f7f862166c0fb4c55fdda721fce4dc2dc1d5b52d03ee4bf2653a85f"}, + {file = "google_cloud_resource_manager-1.12.5-py2.py3-none-any.whl", hash = "sha256:2708a718b45c79464b7b21559c701b5c92e6b0b1ab2146d0a256277a623dc175"}, + {file = "google_cloud_resource_manager-1.12.5.tar.gz", hash = "sha256:b7af4254401ed4efa3aba3a929cb3ddb803fa6baf91a78485e45583597de5891"}, ] [package.dependencies] @@ -3166,13 +3302,13 @@ dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "py [[package]] name = "google-resumable-media" -version = "2.7.1" +version = "2.7.2" description = "Utilities for Google Media Downloads and Resumable Uploads" optional = false python-versions = ">=3.7" files = [ - {file = "google-resumable-media-2.7.1.tar.gz", hash = "sha256:eae451a7b2e2cdbaaa0fd2eb00cc8a1ee5e95e16b55597359cbc3d27d7d90e33"}, - {file = "google_resumable_media-2.7.1-py2.py3-none-any.whl", hash = "sha256:103ebc4ba331ab1bfdac0250f8033627a2cd7cde09e7ccff9181e31ba4315b2c"}, + {file = "google_resumable_media-2.7.2-py2.py3-none-any.whl", hash = "sha256:3ce7551e9fe6d99e9a126101d2536612bb73486721951e9562fee0f90c6ababa"}, + {file = "google_resumable_media-2.7.2.tar.gz", hash = "sha256:5280aed4629f2b60b847b0d42f9857fd4935c11af266744df33d8074cae92fe0"}, ] [package.dependencies] @@ -3289,151 +3425,137 @@ protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4 [[package]] name = "grpcio" -version = "1.62.2" +version = "1.63.0" description = "HTTP/2-based RPC framework" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "grpcio-1.62.2-cp310-cp310-linux_armv7l.whl", hash = "sha256:66344ea741124c38588a664237ac2fa16dfd226964cca23ddc96bd4accccbde5"}, - {file = "grpcio-1.62.2-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:5dab7ac2c1e7cb6179c6bfad6b63174851102cbe0682294e6b1d6f0981ad7138"}, - {file = "grpcio-1.62.2-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:3ad00f3f0718894749d5a8bb0fa125a7980a2f49523731a9b1fabf2b3522aa43"}, - {file = "grpcio-1.62.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e72ddfee62430ea80133d2cbe788e0d06b12f865765cb24a40009668bd8ea05"}, - {file = "grpcio-1.62.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:53d3a59a10af4c2558a8e563aed9f256259d2992ae0d3037817b2155f0341de1"}, - {file = "grpcio-1.62.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a1511a303f8074f67af4119275b4f954189e8313541da7b88b1b3a71425cdb10"}, - {file = "grpcio-1.62.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b94d41b7412ef149743fbc3178e59d95228a7064c5ab4760ae82b562bdffb199"}, - {file = "grpcio-1.62.2-cp310-cp310-win32.whl", hash = "sha256:a75af2fc7cb1fe25785be7bed1ab18cef959a376cdae7c6870184307614caa3f"}, - {file = "grpcio-1.62.2-cp310-cp310-win_amd64.whl", hash = "sha256:80407bc007754f108dc2061e37480238b0dc1952c855e86a4fc283501ee6bb5d"}, - {file = "grpcio-1.62.2-cp311-cp311-linux_armv7l.whl", hash = "sha256:c1624aa686d4b36790ed1c2e2306cc3498778dffaf7b8dd47066cf819028c3ad"}, - {file = "grpcio-1.62.2-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:1c1bb80299bdef33309dff03932264636450c8fdb142ea39f47e06a7153d3063"}, - {file = "grpcio-1.62.2-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:db068bbc9b1fa16479a82e1ecf172a93874540cb84be69f0b9cb9b7ac3c82670"}, - {file = "grpcio-1.62.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2cc8a308780edbe2c4913d6a49dbdb5befacdf72d489a368566be44cadaef1a"}, - {file = "grpcio-1.62.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d0695ae31a89f1a8fc8256050329a91a9995b549a88619263a594ca31b76d756"}, - {file = "grpcio-1.62.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:88b4f9ee77191dcdd8810241e89340a12cbe050be3e0d5f2f091c15571cd3930"}, - {file = "grpcio-1.62.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2a0204532aa2f1afd467024b02b4069246320405bc18abec7babab03e2644e75"}, - {file = "grpcio-1.62.2-cp311-cp311-win32.whl", hash = "sha256:6e784f60e575a0de554ef9251cbc2ceb8790914fe324f11e28450047f264ee6f"}, - {file = "grpcio-1.62.2-cp311-cp311-win_amd64.whl", hash = "sha256:112eaa7865dd9e6d7c0556c8b04ae3c3a2dc35d62ad3373ab7f6a562d8199200"}, - {file = "grpcio-1.62.2-cp312-cp312-linux_armv7l.whl", hash = "sha256:65034473fc09628a02fb85f26e73885cf1ed39ebd9cf270247b38689ff5942c5"}, - {file = "grpcio-1.62.2-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:d2c1771d0ee3cf72d69bb5e82c6a82f27fbd504c8c782575eddb7839729fbaad"}, - {file = "grpcio-1.62.2-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:3abe6838196da518863b5d549938ce3159d809218936851b395b09cad9b5d64a"}, - {file = "grpcio-1.62.2-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c5ffeb269f10cedb4f33142b89a061acda9f672fd1357331dbfd043422c94e9e"}, - {file = "grpcio-1.62.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:404d3b4b6b142b99ba1cff0b2177d26b623101ea2ce51c25ef6e53d9d0d87bcc"}, - {file = "grpcio-1.62.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:262cda97efdabb20853d3b5a4c546a535347c14b64c017f628ca0cc7fa780cc6"}, - {file = "grpcio-1.62.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:17708db5b11b966373e21519c4c73e5a750555f02fde82276ea2a267077c68ad"}, - {file = "grpcio-1.62.2-cp312-cp312-win32.whl", hash = "sha256:b7ec9e2f8ffc8436f6b642a10019fc513722858f295f7efc28de135d336ac189"}, - {file = "grpcio-1.62.2-cp312-cp312-win_amd64.whl", hash = "sha256:aa787b83a3cd5e482e5c79be030e2b4a122ecc6c5c6c4c42a023a2b581fdf17b"}, - {file = "grpcio-1.62.2-cp37-cp37m-linux_armv7l.whl", hash = "sha256:cfd23ad29bfa13fd4188433b0e250f84ec2c8ba66b14a9877e8bce05b524cf54"}, - {file = "grpcio-1.62.2-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:af15e9efa4d776dfcecd1d083f3ccfb04f876d613e90ef8432432efbeeac689d"}, - {file = "grpcio-1.62.2-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:f4aa94361bb5141a45ca9187464ae81a92a2a135ce2800b2203134f7a1a1d479"}, - {file = "grpcio-1.62.2-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82af3613a219512a28ee5c95578eb38d44dd03bca02fd918aa05603c41018051"}, - {file = "grpcio-1.62.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:55ddaf53474e8caeb29eb03e3202f9d827ad3110475a21245f3c7712022882a9"}, - {file = "grpcio-1.62.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c79b518c56dddeec79e5500a53d8a4db90da995dfe1738c3ac57fe46348be049"}, - {file = "grpcio-1.62.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a5eb4844e5e60bf2c446ef38c5b40d7752c6effdee882f716eb57ae87255d20a"}, - {file = "grpcio-1.62.2-cp37-cp37m-win_amd64.whl", hash = "sha256:aaae70364a2d1fb238afd6cc9fcb10442b66e397fd559d3f0968d28cc3ac929c"}, - {file = "grpcio-1.62.2-cp38-cp38-linux_armv7l.whl", hash = "sha256:1bcfe5070e4406f489e39325b76caeadab28c32bf9252d3ae960c79935a4cc36"}, - {file = "grpcio-1.62.2-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:da6a7b6b938c15fa0f0568e482efaae9c3af31963eec2da4ff13a6d8ec2888e4"}, - {file = "grpcio-1.62.2-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:41955b641c34db7d84db8d306937b72bc4968eef1c401bea73081a8d6c3d8033"}, - {file = "grpcio-1.62.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c772f225483905f675cb36a025969eef9712f4698364ecd3a63093760deea1bc"}, - {file = "grpcio-1.62.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07ce1f775d37ca18c7a141300e5b71539690efa1f51fe17f812ca85b5e73262f"}, - {file = "grpcio-1.62.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:26f415f40f4a93579fd648f48dca1c13dfacdfd0290f4a30f9b9aeb745026811"}, - {file = "grpcio-1.62.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:db707e3685ff16fc1eccad68527d072ac8bdd2e390f6daa97bc394ea7de4acea"}, - {file = "grpcio-1.62.2-cp38-cp38-win32.whl", hash = "sha256:589ea8e75de5fd6df387de53af6c9189c5231e212b9aa306b6b0d4f07520fbb9"}, - {file = "grpcio-1.62.2-cp38-cp38-win_amd64.whl", hash = "sha256:3c3ed41f4d7a3aabf0f01ecc70d6b5d00ce1800d4af652a549de3f7cf35c4abd"}, - {file = "grpcio-1.62.2-cp39-cp39-linux_armv7l.whl", hash = "sha256:162ccf61499c893831b8437120600290a99c0bc1ce7b51f2c8d21ec87ff6af8b"}, - {file = "grpcio-1.62.2-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:f27246d7da7d7e3bd8612f63785a7b0c39a244cf14b8dd9dd2f2fab939f2d7f1"}, - {file = "grpcio-1.62.2-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:2507006c8a478f19e99b6fe36a2464696b89d40d88f34e4b709abe57e1337467"}, - {file = "grpcio-1.62.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a90ac47a8ce934e2c8d71e317d2f9e7e6aaceb2d199de940ce2c2eb611b8c0f4"}, - {file = "grpcio-1.62.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99701979bcaaa7de8d5f60476487c5df8f27483624f1f7e300ff4669ee44d1f2"}, - {file = "grpcio-1.62.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:af7dc3f7a44f10863b1b0ecab4078f0a00f561aae1edbd01fd03ad4dcf61c9e9"}, - {file = "grpcio-1.62.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:fa63245271920786f4cb44dcada4983a3516be8f470924528cf658731864c14b"}, - {file = "grpcio-1.62.2-cp39-cp39-win32.whl", hash = "sha256:c6ad9c39704256ed91a1cffc1379d63f7d0278d6a0bad06b0330f5d30291e3a3"}, - {file = "grpcio-1.62.2-cp39-cp39-win_amd64.whl", hash = "sha256:16da954692fd61aa4941fbeda405a756cd96b97b5d95ca58a92547bba2c1624f"}, - {file = "grpcio-1.62.2.tar.gz", hash = "sha256:c77618071d96b7a8be2c10701a98537823b9c65ba256c0b9067e0594cdbd954d"}, -] - -[package.extras] -protobuf = ["grpcio-tools (>=1.62.2)"] + {file = "grpcio-1.63.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:2e93aca840c29d4ab5db93f94ed0a0ca899e241f2e8aec6334ab3575dc46125c"}, + {file = "grpcio-1.63.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:91b73d3f1340fefa1e1716c8c1ec9930c676d6b10a3513ab6c26004cb02d8b3f"}, + {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b3afbd9d6827fa6f475a4f91db55e441113f6d3eb9b7ebb8fb806e5bb6d6bd0d"}, + {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f3f6883ce54a7a5f47db43289a0a4c776487912de1a0e2cc83fdaec9685cc9f"}, + {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf8dae9cc0412cb86c8de5a8f3be395c5119a370f3ce2e69c8b7d46bb9872c8d"}, + {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:08e1559fd3b3b4468486b26b0af64a3904a8dbc78d8d936af9c1cf9636eb3e8b"}, + {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5c039ef01516039fa39da8a8a43a95b64e288f79f42a17e6c2904a02a319b357"}, + {file = "grpcio-1.63.0-cp310-cp310-win32.whl", hash = "sha256:ad2ac8903b2eae071055a927ef74121ed52d69468e91d9bcbd028bd0e554be6d"}, + {file = "grpcio-1.63.0-cp310-cp310-win_amd64.whl", hash = "sha256:b2e44f59316716532a993ca2966636df6fbe7be4ab6f099de6815570ebe4383a"}, + {file = "grpcio-1.63.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:f28f8b2db7b86c77916829d64ab21ff49a9d8289ea1564a2b2a3a8ed9ffcccd3"}, + {file = "grpcio-1.63.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:65bf975639a1f93bee63ca60d2e4951f1b543f498d581869922910a476ead2f5"}, + {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:b5194775fec7dc3dbd6a935102bb156cd2c35efe1685b0a46c67b927c74f0cfb"}, + {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4cbb2100ee46d024c45920d16e888ee5d3cf47c66e316210bc236d5bebc42b3"}, + {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ff737cf29b5b801619f10e59b581869e32f400159e8b12d7a97e7e3bdeee6a2"}, + {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cd1e68776262dd44dedd7381b1a0ad09d9930ffb405f737d64f505eb7f77d6c7"}, + {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:93f45f27f516548e23e4ec3fbab21b060416007dbe768a111fc4611464cc773f"}, + {file = "grpcio-1.63.0-cp311-cp311-win32.whl", hash = "sha256:878b1d88d0137df60e6b09b74cdb73db123f9579232c8456f53e9abc4f62eb3c"}, + {file = "grpcio-1.63.0-cp311-cp311-win_amd64.whl", hash = "sha256:756fed02dacd24e8f488f295a913f250b56b98fb793f41d5b2de6c44fb762434"}, + {file = "grpcio-1.63.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:93a46794cc96c3a674cdfb59ef9ce84d46185fe9421baf2268ccb556f8f81f57"}, + {file = "grpcio-1.63.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:a7b19dfc74d0be7032ca1eda0ed545e582ee46cd65c162f9e9fc6b26ef827dc6"}, + {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8064d986d3a64ba21e498b9a376cbc5d6ab2e8ab0e288d39f266f0fca169b90d"}, + {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:219bb1848cd2c90348c79ed0a6b0ea51866bc7e72fa6e205e459fedab5770172"}, + {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2d60cd1d58817bc5985fae6168d8b5655c4981d448d0f5b6194bbcc038090d2"}, + {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:9e350cb096e5c67832e9b6e018cf8a0d2a53b2a958f6251615173165269a91b0"}, + {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:56cdf96ff82e3cc90dbe8bac260352993f23e8e256e063c327b6cf9c88daf7a9"}, + {file = "grpcio-1.63.0-cp312-cp312-win32.whl", hash = "sha256:3a6d1f9ea965e750db7b4ee6f9fdef5fdf135abe8a249e75d84b0a3e0c668a1b"}, + {file = "grpcio-1.63.0-cp312-cp312-win_amd64.whl", hash = "sha256:d2497769895bb03efe3187fb1888fc20e98a5f18b3d14b606167dacda5789434"}, + {file = "grpcio-1.63.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:fdf348ae69c6ff484402cfdb14e18c1b0054ac2420079d575c53a60b9b2853ae"}, + {file = "grpcio-1.63.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a3abfe0b0f6798dedd2e9e92e881d9acd0fdb62ae27dcbbfa7654a57e24060c0"}, + {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:6ef0ad92873672a2a3767cb827b64741c363ebaa27e7f21659e4e31f4d750280"}, + {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b416252ac5588d9dfb8a30a191451adbf534e9ce5f56bb02cd193f12d8845b7f"}, + {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3b77eaefc74d7eb861d3ffbdf91b50a1bb1639514ebe764c47773b833fa2d91"}, + {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b005292369d9c1f80bf70c1db1c17c6c342da7576f1c689e8eee4fb0c256af85"}, + {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cdcda1156dcc41e042d1e899ba1f5c2e9f3cd7625b3d6ebfa619806a4c1aadda"}, + {file = "grpcio-1.63.0-cp38-cp38-win32.whl", hash = "sha256:01799e8649f9e94ba7db1aeb3452188048b0019dc37696b0f5ce212c87c560c3"}, + {file = "grpcio-1.63.0-cp38-cp38-win_amd64.whl", hash = "sha256:6a1a3642d76f887aa4009d92f71eb37809abceb3b7b5a1eec9c554a246f20e3a"}, + {file = "grpcio-1.63.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:75f701ff645858a2b16bc8c9fc68af215a8bb2d5a9b647448129de6e85d52bce"}, + {file = "grpcio-1.63.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cacdef0348a08e475a721967f48206a2254a1b26ee7637638d9e081761a5ba86"}, + {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:0697563d1d84d6985e40ec5ec596ff41b52abb3fd91ec240e8cb44a63b895094"}, + {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6426e1fb92d006e47476d42b8f240c1d916a6d4423c5258ccc5b105e43438f61"}, + {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48cee31bc5f5a31fb2f3b573764bd563aaa5472342860edcc7039525b53e46a"}, + {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:50344663068041b34a992c19c600236e7abb42d6ec32567916b87b4c8b8833b3"}, + {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:259e11932230d70ef24a21b9fb5bb947eb4703f57865a404054400ee92f42f5d"}, + {file = "grpcio-1.63.0-cp39-cp39-win32.whl", hash = "sha256:a44624aad77bf8ca198c55af811fd28f2b3eaf0a50ec5b57b06c034416ef2d0a"}, + {file = "grpcio-1.63.0-cp39-cp39-win_amd64.whl", hash = "sha256:166e5c460e5d7d4656ff9e63b13e1f6029b122104c1633d5f37eaea348d7356d"}, + {file = "grpcio-1.63.0.tar.gz", hash = "sha256:f3023e14805c61bc439fb40ca545ac3d5740ce66120a678a3c6c2c55b70343d1"}, +] + +[package.extras] +protobuf = ["grpcio-tools (>=1.63.0)"] [[package]] name = "grpcio-status" -version = "1.62.2" +version = "1.62.3" description = "Status proto mapping for gRPC" optional = false python-versions = ">=3.6" files = [ - {file = "grpcio-status-1.62.2.tar.gz", hash = "sha256:62e1bfcb02025a1cd73732a2d33672d3e9d0df4d21c12c51e0bbcaf09bab742a"}, - {file = "grpcio_status-1.62.2-py3-none-any.whl", hash = "sha256:206ddf0eb36bc99b033f03b2c8e95d319f0044defae9b41ae21408e7e0cda48f"}, + {file = "grpcio-status-1.62.3.tar.gz", hash = "sha256:289bdd7b2459794a12cf95dc0cb727bd4a1742c37bd823f760236c937e53a485"}, + {file = "grpcio_status-1.62.3-py3-none-any.whl", hash = "sha256:f9049b762ba8de6b1086789d8315846e094edac2c50beaf462338b301a8fd4b8"}, ] [package.dependencies] googleapis-common-protos = ">=1.5.5" -grpcio = ">=1.62.2" +grpcio = ">=1.62.3" protobuf = ">=4.21.6" [[package]] name = "grpcio-tools" -version = "1.62.2" +version = "1.62.3" description = "Protobuf code generator for gRPC" optional = false python-versions = ">=3.7" files = [ - {file = "grpcio-tools-1.62.2.tar.gz", hash = "sha256:5fd5e1582b678e6b941ee5f5809340be5e0724691df5299aae8226640f94e18f"}, - {file = "grpcio_tools-1.62.2-cp310-cp310-linux_armv7l.whl", hash = "sha256:1679b4903aed2dc5bd8cb22a452225b05dc8470a076f14fd703581efc0740cdb"}, - {file = "grpcio_tools-1.62.2-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:9d41e0e47dd075c075bb8f103422968a65dd0d8dc8613288f573ae91eb1053ba"}, - {file = "grpcio_tools-1.62.2-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:987e774f74296842bbffd55ea8826370f70c499e5b5f71a8cf3103838b6ee9c3"}, - {file = "grpcio_tools-1.62.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40cd4eeea4b25bcb6903b82930d579027d034ba944393c4751cdefd9c49e6989"}, - {file = "grpcio_tools-1.62.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6746bc823958499a3cf8963cc1de00072962fb5e629f26d658882d3f4c35095"}, - {file = "grpcio_tools-1.62.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:2ed775e844566ce9ce089be9a81a8b928623b8ee5820f5e4d58c1a9d33dfc5ae"}, - {file = "grpcio_tools-1.62.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bdc5dd3f57b5368d5d661d5d3703bcaa38bceca59d25955dff66244dbc987271"}, - {file = "grpcio_tools-1.62.2-cp310-cp310-win32.whl", hash = "sha256:3a8d6f07e64c0c7756f4e0c4781d9d5a2b9cc9cbd28f7032a6fb8d4f847d0445"}, - {file = "grpcio_tools-1.62.2-cp310-cp310-win_amd64.whl", hash = "sha256:e33b59fb3efdddeb97ded988a871710033e8638534c826567738d3edce528752"}, - {file = "grpcio_tools-1.62.2-cp311-cp311-linux_armv7l.whl", hash = "sha256:472505d030135d73afe4143b0873efe0dcb385bd6d847553b4f3afe07679af00"}, - {file = "grpcio_tools-1.62.2-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:ec674b4440ef4311ac1245a709e87b36aca493ddc6850eebe0b278d1f2b6e7d1"}, - {file = "grpcio_tools-1.62.2-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:184b4174d4bd82089d706e8223e46c42390a6ebac191073b9772abc77308f9fa"}, - {file = "grpcio_tools-1.62.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c195d74fe98541178ece7a50dad2197d43991e0f77372b9a88da438be2486f12"}, - {file = "grpcio_tools-1.62.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a34d97c62e61bfe9e6cff0410fe144ac8cca2fc979ad0be46b7edf026339d161"}, - {file = "grpcio_tools-1.62.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cbb8453ae83a1db2452b7fe0f4b78e4a8dd32be0f2b2b73591ae620d4d784d3d"}, - {file = "grpcio_tools-1.62.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4f989e5cebead3ae92c6abf6bf7b19949e1563a776aea896ac5933f143f0c45d"}, - {file = "grpcio_tools-1.62.2-cp311-cp311-win32.whl", hash = "sha256:c48fabe40b9170f4e3d7dd2c252e4f1ff395dc24e49ac15fc724b1b6f11724da"}, - {file = "grpcio_tools-1.62.2-cp311-cp311-win_amd64.whl", hash = "sha256:8c616d0ad872e3780693fce6a3ac8ef00fc0963e6d7815ce9dcfae68ba0fc287"}, - {file = "grpcio_tools-1.62.2-cp312-cp312-linux_armv7l.whl", hash = "sha256:10cc3321704ecd17c93cf68c99c35467a8a97ffaaed53207e9b2da6ae0308ee1"}, - {file = "grpcio_tools-1.62.2-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:9be84ff6d47fd61462be7523b49d7ba01adf67ce4e1447eae37721ab32464dd8"}, - {file = "grpcio_tools-1.62.2-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:d82f681c9a9d933a9d8068e8e382977768e7779ddb8870fa0cf918d8250d1532"}, - {file = "grpcio_tools-1.62.2-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:04c607029ae3660fb1624ed273811ffe09d57d84287d37e63b5b802a35897329"}, - {file = "grpcio_tools-1.62.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:72b61332f1b439c14cbd3815174a8f1d35067a02047c32decd406b3a09bb9890"}, - {file = "grpcio_tools-1.62.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8214820990d01b52845f9fbcb92d2b7384a0c321b303e3ac614c219dc7d1d3af"}, - {file = "grpcio_tools-1.62.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:462e0ab8dd7c7b70bfd6e3195eebc177549ede5cf3189814850c76f9a340d7ce"}, - {file = "grpcio_tools-1.62.2-cp312-cp312-win32.whl", hash = "sha256:fa107460c842e4c1a6266150881694fefd4f33baa544ea9489601810c2210ef8"}, - {file = "grpcio_tools-1.62.2-cp312-cp312-win_amd64.whl", hash = "sha256:759c60f24c33a181bbbc1232a6752f9b49fbb1583312a4917e2b389fea0fb0f2"}, - {file = "grpcio_tools-1.62.2-cp37-cp37m-linux_armv7l.whl", hash = "sha256:45db5da2bcfa88f2b86b57ef35daaae85c60bd6754a051d35d9449c959925b57"}, - {file = "grpcio_tools-1.62.2-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:ab84bae88597133f6ea7a2bdc57b2fda98a266fe8d8d4763652cbefd20e73ad7"}, - {file = "grpcio_tools-1.62.2-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:7a49bccae1c7d154b78e991885c3111c9ad8c8fa98e91233de425718f47c6139"}, - {file = "grpcio_tools-1.62.2-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a7e439476b29d6dac363b321781a113794397afceeb97dad85349db5f1cb5e9a"}, - {file = "grpcio_tools-1.62.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ea369c4d1567d1acdf69c8ea74144f4ccad9e545df7f9a4fc64c94fa7684ba3"}, - {file = "grpcio_tools-1.62.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4f955702dc4b530696375251319d05223b729ed24e8673c2129f7a75d2caefbb"}, - {file = "grpcio_tools-1.62.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:3708a747aa4b6b505727282ca887041174e146ae030ebcadaf4c1d346858df62"}, - {file = "grpcio_tools-1.62.2-cp37-cp37m-win_amd64.whl", hash = "sha256:2ce149ea55eadb486a7fb75a20f63ef3ac065ee6a0240ed25f3549ce7954c653"}, - {file = "grpcio_tools-1.62.2-cp38-cp38-linux_armv7l.whl", hash = "sha256:58cbb24b3fa6ae35aa9c210fcea3a51aa5fef0cd25618eb4fd94f746d5a9b703"}, - {file = "grpcio_tools-1.62.2-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:6413581e14a80e0b4532577766cf0586de4dd33766a31b3eb5374a746771c07d"}, - {file = "grpcio_tools-1.62.2-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:47117c8a7e861382470d0e22d336e5a91fdc5f851d1db44fa784b9acea190d87"}, - {file = "grpcio_tools-1.62.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9f1ba79a253df9e553d20319c615fa2b429684580fa042dba618d7f6649ac7e4"}, - {file = "grpcio_tools-1.62.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:04a394cf5e51ba9be412eb9f6c482b6270bd81016e033e8eb7d21b8cc28fe8b5"}, - {file = "grpcio_tools-1.62.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3c53b221378b035ae2f1881cbc3aca42a6075a8e90e1a342c2f205eb1d1aa6a1"}, - {file = "grpcio_tools-1.62.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:c384c838b34d1b67068e51b5bbe49caa6aa3633acd158f1ab16b5da8d226bc53"}, - {file = "grpcio_tools-1.62.2-cp38-cp38-win32.whl", hash = "sha256:19ea69e41c3565932aa28a202d1875ec56786aea46a2eab54a3b28e8a27f9517"}, - {file = "grpcio_tools-1.62.2-cp38-cp38-win_amd64.whl", hash = "sha256:1d768a5c07279a4c461ebf52d0cec1c6ca85c6291c71ec2703fe3c3e7e28e8c4"}, - {file = "grpcio_tools-1.62.2-cp39-cp39-linux_armv7l.whl", hash = "sha256:5b07b5874187e170edfbd7aa2ca3a54ebf3b2952487653e8c0b0d83601c33035"}, - {file = "grpcio_tools-1.62.2-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:d58389fe8be206ddfb4fa703db1e24c956856fcb9a81da62b13577b3a8f7fda7"}, - {file = "grpcio_tools-1.62.2-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:7d8b4e00c3d7237b92260fc18a561cd81f1da82e8be100db1b7d816250defc66"}, - {file = "grpcio_tools-1.62.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fe08d2038f2b7c53259b5c49e0ad08c8e0ce2b548d8185993e7ef67e8592cca"}, - {file = "grpcio_tools-1.62.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19216e1fb26dbe23d12a810517e1b3fbb8d4f98b1a3fbebeec9d93a79f092de4"}, - {file = "grpcio_tools-1.62.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:b8574469ecc4ff41d6bb95f44e0297cdb0d95bade388552a9a444db9cd7485cd"}, - {file = "grpcio_tools-1.62.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4f6f32d39283ea834a493fccf0ebe9cfddee7577bdcc27736ad4be1732a36399"}, - {file = "grpcio_tools-1.62.2-cp39-cp39-win32.whl", hash = "sha256:76eb459bdf3fb666e01883270beee18f3f11ed44488486b61cd210b4e0e17cc1"}, - {file = "grpcio_tools-1.62.2-cp39-cp39-win_amd64.whl", hash = "sha256:217c2ee6a7ce519a55958b8622e21804f6fdb774db08c322f4c9536c35fdce7c"}, -] - -[package.dependencies] -grpcio = ">=1.62.2" + {file = "grpcio-tools-1.62.3.tar.gz", hash = "sha256:7c7136015c3d62c3eef493efabaf9e3380e3e66d24ee8e94c01cb71377f57833"}, + {file = "grpcio_tools-1.62.3-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:2f968b049c2849540751ec2100ab05e8086c24bead769ca734fdab58698408c1"}, + {file = "grpcio_tools-1.62.3-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:0a8c0c4724ae9c2181b7dbc9b186df46e4f62cb18dc184e46d06c0ebeccf569e"}, + {file = "grpcio_tools-1.62.3-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5782883a27d3fae8c425b29a9d3dcf5f47d992848a1b76970da3b5a28d424b26"}, + {file = "grpcio_tools-1.62.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3d812daffd0c2d2794756bd45a353f89e55dc8f91eb2fc840c51b9f6be62667"}, + {file = "grpcio_tools-1.62.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:b47d0dda1bdb0a0ba7a9a6de88e5a1ed61f07fad613964879954961e36d49193"}, + {file = "grpcio_tools-1.62.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ca246dffeca0498be9b4e1ee169b62e64694b0f92e6d0be2573e65522f39eea9"}, + {file = "grpcio_tools-1.62.3-cp310-cp310-win32.whl", hash = "sha256:6a56d344b0bab30bf342a67e33d386b0b3c4e65868ffe93c341c51e1a8853ca5"}, + {file = "grpcio_tools-1.62.3-cp310-cp310-win_amd64.whl", hash = "sha256:710fecf6a171dcbfa263a0a3e7070e0df65ba73158d4c539cec50978f11dad5d"}, + {file = "grpcio_tools-1.62.3-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:703f46e0012af83a36082b5f30341113474ed0d91e36640da713355cd0ea5d23"}, + {file = "grpcio_tools-1.62.3-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:7cc83023acd8bc72cf74c2edbe85b52098501d5b74d8377bfa06f3e929803492"}, + {file = "grpcio_tools-1.62.3-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7ff7d58a45b75df67d25f8f144936a3e44aabd91afec833ee06826bd02b7fbe7"}, + {file = "grpcio_tools-1.62.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f2483ea232bd72d98a6dc6d7aefd97e5bc80b15cd909b9e356d6f3e326b6e43"}, + {file = "grpcio_tools-1.62.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:962c84b4da0f3b14b3cdb10bc3837ebc5f136b67d919aea8d7bb3fd3df39528a"}, + {file = "grpcio_tools-1.62.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8ad0473af5544f89fc5a1ece8676dd03bdf160fb3230f967e05d0f4bf89620e3"}, + {file = "grpcio_tools-1.62.3-cp311-cp311-win32.whl", hash = "sha256:db3bc9fa39afc5e4e2767da4459df82b095ef0cab2f257707be06c44a1c2c3e5"}, + {file = "grpcio_tools-1.62.3-cp311-cp311-win_amd64.whl", hash = "sha256:e0898d412a434e768a0c7e365acabe13ff1558b767e400936e26b5b6ed1ee51f"}, + {file = "grpcio_tools-1.62.3-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:d102b9b21c4e1e40af9a2ab3c6d41afba6bd29c0aa50ca013bf85c99cdc44ac5"}, + {file = "grpcio_tools-1.62.3-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:0a52cc9444df978438b8d2332c0ca99000521895229934a59f94f37ed896b133"}, + {file = "grpcio_tools-1.62.3-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:141d028bf5762d4a97f981c501da873589df3f7e02f4c1260e1921e565b376fa"}, + {file = "grpcio_tools-1.62.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47a5c093ab256dec5714a7a345f8cc89315cb57c298b276fa244f37a0ba507f0"}, + {file = "grpcio_tools-1.62.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:f6831fdec2b853c9daa3358535c55eed3694325889aa714070528cf8f92d7d6d"}, + {file = "grpcio_tools-1.62.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e02d7c1a02e3814c94ba0cfe43d93e872c758bd8fd5c2797f894d0c49b4a1dfc"}, + {file = "grpcio_tools-1.62.3-cp312-cp312-win32.whl", hash = "sha256:b881fd9505a84457e9f7e99362eeedd86497b659030cf57c6f0070df6d9c2b9b"}, + {file = "grpcio_tools-1.62.3-cp312-cp312-win_amd64.whl", hash = "sha256:11c625eebefd1fd40a228fc8bae385e448c7e32a6ae134e43cf13bbc23f902b7"}, + {file = "grpcio_tools-1.62.3-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:ec6fbded0c61afe6f84e3c2a43e6d656791d95747d6d28b73eff1af64108c434"}, + {file = "grpcio_tools-1.62.3-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:bfda6ee8990997a9df95c5606f3096dae65f09af7ca03a1e9ca28f088caca5cf"}, + {file = "grpcio_tools-1.62.3-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b77f9f9cee87cd798f0fe26b7024344d1b03a7cd2d2cba7035f8433b13986325"}, + {file = "grpcio_tools-1.62.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e02d3b96f2d0e4bab9ceaa30f37d4f75571e40c6272e95364bff3125a64d184"}, + {file = "grpcio_tools-1.62.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:1da38070738da53556a4b35ab67c1b9884a5dd48fa2f243db35dc14079ea3d0c"}, + {file = "grpcio_tools-1.62.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ace43b26d88a58dcff16c20d23ff72b04d0a415f64d2820f4ff06b1166f50557"}, + {file = "grpcio_tools-1.62.3-cp37-cp37m-win_amd64.whl", hash = "sha256:350a80485e302daaa95d335a931f97b693e170e02d43767ab06552c708808950"}, + {file = "grpcio_tools-1.62.3-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:c3a1ac9d394f8e229eb28eec2e04b9a6f5433fa19c9d32f1cb6066e3c5114a1d"}, + {file = "grpcio_tools-1.62.3-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:11f363570dea661dde99e04a51bd108a5807b5df32a6f8bdf4860e34e94a4dbf"}, + {file = "grpcio_tools-1.62.3-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dc9ad9950119d8ae27634e68b7663cc8d340ae535a0f80d85a55e56a6973ab1f"}, + {file = "grpcio_tools-1.62.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c5d22b252dcef11dd1e0fbbe5bbfb9b4ae048e8880d33338215e8ccbdb03edc"}, + {file = "grpcio_tools-1.62.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:27cd9ef5c5d68d5ed104b6dcb96fe9c66b82050e546c9e255716903c3d8f0373"}, + {file = "grpcio_tools-1.62.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:f4b1615adf67bd8bb71f3464146a6f9949972d06d21a4f5e87e73f6464d97f57"}, + {file = "grpcio_tools-1.62.3-cp38-cp38-win32.whl", hash = "sha256:e18e15287c31baf574fcdf8251fb7f997d64e96c6ecf467906e576da0a079af6"}, + {file = "grpcio_tools-1.62.3-cp38-cp38-win_amd64.whl", hash = "sha256:6c3064610826f50bd69410c63101954676edc703e03f9e8f978a135f1aaf97c1"}, + {file = "grpcio_tools-1.62.3-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:8e62cc7164b0b7c5128e637e394eb2ef3db0e61fc798e80c301de3b2379203ed"}, + {file = "grpcio_tools-1.62.3-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:c8ad5cce554e2fcaf8842dee5d9462583b601a3a78f8b76a153c38c963f58c10"}, + {file = "grpcio_tools-1.62.3-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ec279dcf3518201fc592c65002754f58a6b542798cd7f3ecd4af086422f33f29"}, + {file = "grpcio_tools-1.62.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c989246c2aebc13253f08be32538a4039a64e12d9c18f6d662d7aee641dc8b5"}, + {file = "grpcio_tools-1.62.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:ca4f5eeadbb57cf03317d6a2857823239a63a59cc935f5bd6cf6e8b7af7a7ecc"}, + {file = "grpcio_tools-1.62.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0cb3a3436ac119cbd37a7d3331d9bdf85dad21a6ac233a3411dff716dcbf401e"}, + {file = "grpcio_tools-1.62.3-cp39-cp39-win32.whl", hash = "sha256:3eae6ea76d62fcac091e1f15c2dcedf1dc3f114f8df1a972a8a0745e89f4cf61"}, + {file = "grpcio_tools-1.62.3-cp39-cp39-win_amd64.whl", hash = "sha256:eec73a005443061f4759b71a056f745e3b000dc0dc125c9f20560232dfbcbd14"}, +] + +[package.dependencies] +grpcio = ">=1.62.3" protobuf = ">=4.21.6,<5.0dev" setuptools = "*" @@ -3798,37 +3920,41 @@ files = [ [[package]] name = "importlib-metadata" -version = "7.1.0" +version = "8.0.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570"}, - {file = "importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2"}, + {file = "importlib_metadata-8.0.0-py3-none-any.whl", hash = "sha256:15584cf2b1bf449d98ff8a6ff1abef57bf20f3ac6454f431736cd3e660921b2f"}, + {file = "importlib_metadata-8.0.0.tar.gz", hash = "sha256:188bd24e4c346d3f0a933f275c2fec67050326a856b9a359881d7c2a697e8812"}, ] [package.dependencies] zipp = ">=0.5" [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] perf = ["ipython"] -testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] +test = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] [[package]] name = "importlib-resources" -version = "6.4.0" +version = "6.4.4" description = "Read resources from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_resources-6.4.0-py3-none-any.whl", hash = "sha256:50d10f043df931902d4194ea07ec57960f66a80449ff867bfe782b4c486ba78c"}, - {file = "importlib_resources-6.4.0.tar.gz", hash = "sha256:cdb2b453b8046ca4e3798eb1d84f3cce1446a0e8e7b5ef4efb600f19fc398145"}, + {file = "importlib_resources-6.4.4-py3-none-any.whl", hash = "sha256:dda242603d1c9cd836c3368b1174ed74cb4049ecd209e7a1a0104620c18c5c11"}, + {file = "importlib_resources-6.4.4.tar.gz", hash = "sha256:20600c8b7361938dc0bb2d5ec0297802e575df486f5a544fa414da65e13721f7"}, ] [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["jaraco.test (>=5.4)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["jaraco.test (>=5.4)", "pytest (>=6,!=8.1.*)", "zipp (>=3.17)"] +type = ["pytest-mypy"] [[package]] name = "iniconfig" @@ -3939,6 +4065,41 @@ files = [ [package.dependencies] ply = "*" +[[package]] +name = "jsonschema" +version = "4.23.0" +description = "An implementation of JSON Schema validation for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jsonschema-4.23.0-py3-none-any.whl", hash = "sha256:fbadb6f8b144a8f8cf9f0b89ba94501d143e50411a1278633f56a7acf7fd5566"}, + {file = "jsonschema-4.23.0.tar.gz", hash = "sha256:d71497fef26351a33265337fa77ffeb82423f3ea21283cd9467bb03999266bc4"}, +] + +[package.dependencies] +attrs = ">=22.2.0" +jsonschema-specifications = ">=2023.03.6" +referencing = ">=0.28.4" +rpds-py = ">=0.7.1" + +[package.extras] +format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"] +format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=24.6.0)"] + +[[package]] +name = "jsonschema-specifications" +version = "2023.12.1" +description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jsonschema_specifications-2023.12.1-py3-none-any.whl", hash = "sha256:87e4fdf3a94858b8a2ba2778d9ba57d8a9cafca7c7489c46ba0d30a8bc6a9c3c"}, + {file = "jsonschema_specifications-2023.12.1.tar.gz", hash = "sha256:48a76787b3e70f5ed53f1160d2b81f586e4ca6d1548c5de7085d1682674764cc"}, +] + +[package.dependencies] +referencing = ">=0.31.0" + [[package]] name = "kaleido" version = "0.2.1" @@ -4069,28 +4230,28 @@ files = [ [[package]] name = "kombu" -version = "5.3.7" +version = "5.4.0" description = "Messaging library for Python." optional = false python-versions = ">=3.8" files = [ - {file = "kombu-5.3.7-py3-none-any.whl", hash = "sha256:5634c511926309c7f9789f1433e9ed402616b56836ef9878f01bd59267b4c7a9"}, - {file = "kombu-5.3.7.tar.gz", hash = "sha256:011c4cd9a355c14a1de8d35d257314a1d2456d52b7140388561acac3cf1a97bf"}, + {file = "kombu-5.4.0-py3-none-any.whl", hash = "sha256:c8dd99820467610b4febbc7a9e8a0d3d7da2d35116b67184418b51cc520ea6b6"}, + {file = "kombu-5.4.0.tar.gz", hash = "sha256:ad200a8dbdaaa2bbc5f26d2ee7d707d9a1fded353a0f4bd751ce8c7d9f449c60"}, ] [package.dependencies] amqp = ">=5.1.1,<6.0.0" -vine = "*" +vine = "5.1.0" [package.extras] azureservicebus = ["azure-servicebus (>=7.10.0)"] azurestoragequeues = ["azure-identity (>=1.12.0)", "azure-storage-queue (>=12.6.0)"] confluentkafka = ["confluent-kafka (>=2.2.0)"] -consul = ["python-consul2"] +consul = ["python-consul2 (==0.1.5)"] librabbitmq = ["librabbitmq (>=2.0.0)"] mongodb = ["pymongo (>=4.1.1)"] -msgpack = ["msgpack"] -pyro = ["pyro4"] +msgpack = ["msgpack (==1.0.8)"] +pyro = ["pyro4 (==4.82)"] qpid = ["qpid-python (>=0.26)", "qpid-tools (>=0.26)"] redis = ["redis (>=4.5.2,!=4.5.5,!=5.0.2)"] slmq = ["softlayer-messaging (>=1.0.3)"] @@ -4141,13 +4302,13 @@ six = "*" [[package]] name = "langfuse" -version = "2.39.3" +version = "2.44.0" description = "A client library for accessing langfuse" optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langfuse-2.39.3-py3-none-any.whl", hash = "sha256:24b12cbb23f866b22706c1ea9631781f99fe37b0b15889d241198c4d1c07516b"}, - {file = "langfuse-2.39.3.tar.gz", hash = "sha256:4d2df8f9344572370703db103ddf97176df518699593254e6d6c2b8ca3bf2f12"}, + {file = "langfuse-2.44.0-py3-none-any.whl", hash = "sha256:adb73400a6ad6d597cc95c31381c82f81face3d5fb69391181f224a26f7e8562"}, + {file = "langfuse-2.44.0.tar.gz", hash = "sha256:dfa5378ff7022ae9fe5b8b842c0365347c98f9ef2b772dcee6a93a45442de28c"}, ] [package.dependencies] @@ -4166,16 +4327,17 @@ openai = ["openai (>=0.27.8)"] [[package]] name = "langsmith" -version = "0.1.93" +version = "0.1.101" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langsmith-0.1.93-py3-none-any.whl", hash = "sha256:811210b9d5f108f36431bd7b997eb9476a9ecf5a2abd7ddbb606c1cdcf0f43ce"}, - {file = "langsmith-0.1.93.tar.gz", hash = "sha256:285b6ad3a54f50fa8eb97b5f600acc57d0e37e139dd8cf2111a117d0435ba9b4"}, + {file = "langsmith-0.1.101-py3-none-any.whl", hash = "sha256:572e2c90709cda1ad837ac86cedda7295f69933f2124c658a92a35fb890477cc"}, + {file = "langsmith-0.1.101.tar.gz", hash = "sha256:caf4d95f314bb6cd3c4e0632eed821fd5cd5d0f18cb824772fce6d7a9113895b"}, ] [package.dependencies] +httpx = ">=0.23.0,<1" orjson = ">=3.9.14,<4.0.0" pydantic = [ {version = ">=1,<3", markers = "python_full_version < \"3.12.4\""}, @@ -4215,96 +4377,157 @@ files = [ [[package]] name = "lxml" -version = "5.1.0" +version = "5.3.0" description = "Powerful and Pythonic XML processing library combining libxml2/libxslt with the ElementTree API." optional = false python-versions = ">=3.6" files = [ - {file = "lxml-5.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:704f5572ff473a5f897745abebc6df40f22d4133c1e0a1f124e4f2bd3330ff7e"}, - {file = "lxml-5.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9d3c0f8567ffe7502d969c2c1b809892dc793b5d0665f602aad19895f8d508da"}, - {file = "lxml-5.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5fcfbebdb0c5d8d18b84118842f31965d59ee3e66996ac842e21f957eb76138c"}, - {file = "lxml-5.1.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2f37c6d7106a9d6f0708d4e164b707037b7380fcd0b04c5bd9cae1fb46a856fb"}, - {file = "lxml-5.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2befa20a13f1a75c751f47e00929fb3433d67eb9923c2c0b364de449121f447c"}, - {file = "lxml-5.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22b7ee4c35f374e2c20337a95502057964d7e35b996b1c667b5c65c567d2252a"}, - {file = "lxml-5.1.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:bf8443781533b8d37b295016a4b53c1494fa9a03573c09ca5104550c138d5c05"}, - {file = "lxml-5.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:82bddf0e72cb2af3cbba7cec1d2fd11fda0de6be8f4492223d4a268713ef2147"}, - {file = "lxml-5.1.0-cp310-cp310-win32.whl", hash = "sha256:b66aa6357b265670bb574f050ffceefb98549c721cf28351b748be1ef9577d93"}, - {file = "lxml-5.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:4946e7f59b7b6a9e27bef34422f645e9a368cb2be11bf1ef3cafc39a1f6ba68d"}, - {file = "lxml-5.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:14deca1460b4b0f6b01f1ddc9557704e8b365f55c63070463f6c18619ebf964f"}, - {file = "lxml-5.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ed8c3d2cd329bf779b7ed38db176738f3f8be637bb395ce9629fc76f78afe3d4"}, - {file = "lxml-5.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:436a943c2900bb98123b06437cdd30580a61340fbdb7b28aaf345a459c19046a"}, - {file = "lxml-5.1.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:acb6b2f96f60f70e7f34efe0c3ea34ca63f19ca63ce90019c6cbca6b676e81fa"}, - {file = "lxml-5.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af8920ce4a55ff41167ddbc20077f5698c2e710ad3353d32a07d3264f3a2021e"}, - {file = "lxml-5.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7cfced4a069003d8913408e10ca8ed092c49a7f6cefee9bb74b6b3e860683b45"}, - {file = "lxml-5.1.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9e5ac3437746189a9b4121db2a7b86056ac8786b12e88838696899328fc44bb2"}, - {file = "lxml-5.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f4c9bda132ad108b387c33fabfea47866af87f4ea6ffb79418004f0521e63204"}, - {file = "lxml-5.1.0-cp311-cp311-win32.whl", hash = "sha256:bc64d1b1dab08f679fb89c368f4c05693f58a9faf744c4d390d7ed1d8223869b"}, - {file = "lxml-5.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:a5ab722ae5a873d8dcee1f5f45ddd93c34210aed44ff2dc643b5025981908cda"}, - {file = "lxml-5.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9aa543980ab1fbf1720969af1d99095a548ea42e00361e727c58a40832439114"}, - {file = "lxml-5.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6f11b77ec0979f7e4dc5ae081325a2946f1fe424148d3945f943ceaede98adb8"}, - {file = "lxml-5.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a36c506e5f8aeb40680491d39ed94670487ce6614b9d27cabe45d94cd5d63e1e"}, - {file = "lxml-5.1.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f643ffd2669ffd4b5a3e9b41c909b72b2a1d5e4915da90a77e119b8d48ce867a"}, - {file = "lxml-5.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16dd953fb719f0ffc5bc067428fc9e88f599e15723a85618c45847c96f11f431"}, - {file = "lxml-5.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:16018f7099245157564d7148165132c70adb272fb5a17c048ba70d9cc542a1a1"}, - {file = "lxml-5.1.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:82cd34f1081ae4ea2ede3d52f71b7be313756e99b4b5f829f89b12da552d3aa3"}, - {file = "lxml-5.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:19a1bc898ae9f06bccb7c3e1dfd73897ecbbd2c96afe9095a6026016e5ca97b8"}, - {file = "lxml-5.1.0-cp312-cp312-win32.whl", hash = "sha256:13521a321a25c641b9ea127ef478b580b5ec82aa2e9fc076c86169d161798b01"}, - {file = "lxml-5.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:1ad17c20e3666c035db502c78b86e58ff6b5991906e55bdbef94977700c72623"}, - {file = "lxml-5.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:24ef5a4631c0b6cceaf2dbca21687e29725b7c4e171f33a8f8ce23c12558ded1"}, - {file = "lxml-5.1.0-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8d2900b7f5318bc7ad8631d3d40190b95ef2aa8cc59473b73b294e4a55e9f30f"}, - {file = "lxml-5.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:601f4a75797d7a770daed8b42b97cd1bb1ba18bd51a9382077a6a247a12aa38d"}, - {file = "lxml-5.1.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4b68c961b5cc402cbd99cca5eb2547e46ce77260eb705f4d117fd9c3f932b95"}, - {file = "lxml-5.1.0-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:afd825e30f8d1f521713a5669b63657bcfe5980a916c95855060048b88e1adb7"}, - {file = "lxml-5.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:262bc5f512a66b527d026518507e78c2f9c2bd9eb5c8aeeb9f0eb43fcb69dc67"}, - {file = "lxml-5.1.0-cp36-cp36m-win32.whl", hash = "sha256:e856c1c7255c739434489ec9c8aa9cdf5179785d10ff20add308b5d673bed5cd"}, - {file = "lxml-5.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:c7257171bb8d4432fe9d6fdde4d55fdbe663a63636a17f7f9aaba9bcb3153ad7"}, - {file = "lxml-5.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b9e240ae0ba96477682aa87899d94ddec1cc7926f9df29b1dd57b39e797d5ab5"}, - {file = "lxml-5.1.0-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a96f02ba1bcd330807fc060ed91d1f7a20853da6dd449e5da4b09bfcc08fdcf5"}, - {file = "lxml-5.1.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e3898ae2b58eeafedfe99e542a17859017d72d7f6a63de0f04f99c2cb125936"}, - {file = "lxml-5.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61c5a7edbd7c695e54fca029ceb351fc45cd8860119a0f83e48be44e1c464862"}, - {file = "lxml-5.1.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:3aeca824b38ca78d9ee2ab82bd9883083d0492d9d17df065ba3b94e88e4d7ee6"}, - {file = "lxml-5.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8f52fe6859b9db71ee609b0c0a70fea5f1e71c3462ecf144ca800d3f434f0764"}, - {file = "lxml-5.1.0-cp37-cp37m-win32.whl", hash = "sha256:d42e3a3fc18acc88b838efded0e6ec3edf3e328a58c68fbd36a7263a874906c8"}, - {file = "lxml-5.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:eac68f96539b32fce2c9b47eb7c25bb2582bdaf1bbb360d25f564ee9e04c542b"}, - {file = "lxml-5.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ae15347a88cf8af0949a9872b57a320d2605ae069bcdf047677318bc0bba45b1"}, - {file = "lxml-5.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c26aab6ea9c54d3bed716b8851c8bfc40cb249b8e9880e250d1eddde9f709bf5"}, - {file = "lxml-5.1.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:342e95bddec3a698ac24378d61996b3ee5ba9acfeb253986002ac53c9a5f6f84"}, - {file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:725e171e0b99a66ec8605ac77fa12239dbe061482ac854d25720e2294652eeaa"}, - {file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d184e0d5c918cff04cdde9dbdf9600e960161d773666958c9d7b565ccc60c45"}, - {file = "lxml-5.1.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:98f3f020a2b736566c707c8e034945c02aa94e124c24f77ca097c446f81b01f1"}, - {file = "lxml-5.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6d48fc57e7c1e3df57be5ae8614bab6d4e7b60f65c5457915c26892c41afc59e"}, - {file = "lxml-5.1.0-cp38-cp38-win32.whl", hash = "sha256:7ec465e6549ed97e9f1e5ed51c657c9ede767bc1c11552f7f4d022c4df4a977a"}, - {file = "lxml-5.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:b21b4031b53d25b0858d4e124f2f9131ffc1530431c6d1321805c90da78388d1"}, - {file = "lxml-5.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:52427a7eadc98f9e62cb1368a5079ae826f94f05755d2d567d93ee1bc3ceb354"}, - {file = "lxml-5.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6a2a2c724d97c1eb8cf966b16ca2915566a4904b9aad2ed9a09c748ffe14f969"}, - {file = "lxml-5.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:843b9c835580d52828d8f69ea4302537337a21e6b4f1ec711a52241ba4a824f3"}, - {file = "lxml-5.1.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9b99f564659cfa704a2dd82d0684207b1aadf7d02d33e54845f9fc78e06b7581"}, - {file = "lxml-5.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f8b0c78e7aac24979ef09b7f50da871c2de2def043d468c4b41f512d831e912"}, - {file = "lxml-5.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9bcf86dfc8ff3e992fed847c077bd875d9e0ba2fa25d859c3a0f0f76f07f0c8d"}, - {file = "lxml-5.1.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:49a9b4af45e8b925e1cd6f3b15bbba2c81e7dba6dce170c677c9cda547411e14"}, - {file = "lxml-5.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:280f3edf15c2a967d923bcfb1f8f15337ad36f93525828b40a0f9d6c2ad24890"}, - {file = "lxml-5.1.0-cp39-cp39-win32.whl", hash = "sha256:ed7326563024b6e91fef6b6c7a1a2ff0a71b97793ac33dbbcf38f6005e51ff6e"}, - {file = "lxml-5.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:8d7b4beebb178e9183138f552238f7e6613162a42164233e2bda00cb3afac58f"}, - {file = "lxml-5.1.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:9bd0ae7cc2b85320abd5e0abad5ccee5564ed5f0cc90245d2f9a8ef330a8deae"}, - {file = "lxml-5.1.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8c1d679df4361408b628f42b26a5d62bd3e9ba7f0c0e7969f925021554755aa"}, - {file = "lxml-5.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:2ad3a8ce9e8a767131061a22cd28fdffa3cd2dc193f399ff7b81777f3520e372"}, - {file = "lxml-5.1.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:304128394c9c22b6569eba2a6d98392b56fbdfbad58f83ea702530be80d0f9df"}, - {file = "lxml-5.1.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d74fcaf87132ffc0447b3c685a9f862ffb5b43e70ea6beec2fb8057d5d2a1fea"}, - {file = "lxml-5.1.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:8cf5877f7ed384dabfdcc37922c3191bf27e55b498fecece9fd5c2c7aaa34c33"}, - {file = "lxml-5.1.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:877efb968c3d7eb2dad540b6cabf2f1d3c0fbf4b2d309a3c141f79c7e0061324"}, - {file = "lxml-5.1.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f14a4fb1c1c402a22e6a341a24c1341b4a3def81b41cd354386dcb795f83897"}, - {file = "lxml-5.1.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:25663d6e99659544ee8fe1b89b1a8c0aaa5e34b103fab124b17fa958c4a324a6"}, - {file = "lxml-5.1.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:8b9f19df998761babaa7f09e6bc169294eefafd6149aaa272081cbddc7ba4ca3"}, - {file = "lxml-5.1.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e53d7e6a98b64fe54775d23a7c669763451340c3d44ad5e3a3b48a1efbdc96f"}, - {file = "lxml-5.1.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:c3cd1fc1dc7c376c54440aeaaa0dcc803d2126732ff5c6b68ccd619f2e64be4f"}, - {file = "lxml-5.1.0.tar.gz", hash = "sha256:3eea6ed6e6c918e468e693c41ef07f3c3acc310b70ddd9cc72d9ef84bc9564ca"}, + {file = "lxml-5.3.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:dd36439be765e2dde7660212b5275641edbc813e7b24668831a5c8ac91180656"}, + {file = "lxml-5.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ae5fe5c4b525aa82b8076c1a59d642c17b6e8739ecf852522c6321852178119d"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:501d0d7e26b4d261fca8132854d845e4988097611ba2531408ec91cf3fd9d20a"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb66442c2546446944437df74379e9cf9e9db353e61301d1a0e26482f43f0dd8"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9e41506fec7a7f9405b14aa2d5c8abbb4dbbd09d88f9496958b6d00cb4d45330"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f7d4a670107d75dfe5ad080bed6c341d18c4442f9378c9f58e5851e86eb79965"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41ce1f1e2c7755abfc7e759dc34d7d05fd221723ff822947132dc934d122fe22"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:44264ecae91b30e5633013fb66f6ddd05c006d3e0e884f75ce0b4755b3e3847b"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_28_ppc64le.whl", hash = "sha256:3c174dc350d3ec52deb77f2faf05c439331d6ed5e702fc247ccb4e6b62d884b7"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_28_s390x.whl", hash = "sha256:2dfab5fa6a28a0b60a20638dc48e6343c02ea9933e3279ccb132f555a62323d8"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b1c8c20847b9f34e98080da785bb2336ea982e7f913eed5809e5a3c872900f32"}, + {file = "lxml-5.3.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:2c86bf781b12ba417f64f3422cfc302523ac9cd1d8ae8c0f92a1c66e56ef2e86"}, + {file = "lxml-5.3.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:c162b216070f280fa7da844531169be0baf9ccb17263cf5a8bf876fcd3117fa5"}, + {file = "lxml-5.3.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:36aef61a1678cb778097b4a6eeae96a69875d51d1e8f4d4b491ab3cfb54b5a03"}, + {file = "lxml-5.3.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f65e5120863c2b266dbcc927b306c5b78e502c71edf3295dfcb9501ec96e5fc7"}, + {file = "lxml-5.3.0-cp310-cp310-win32.whl", hash = "sha256:ef0c1fe22171dd7c7c27147f2e9c3e86f8bdf473fed75f16b0c2e84a5030ce80"}, + {file = "lxml-5.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:052d99051e77a4f3e8482c65014cf6372e61b0a6f4fe9edb98503bb5364cfee3"}, + {file = "lxml-5.3.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:74bcb423462233bc5d6066e4e98b0264e7c1bed7541fff2f4e34fe6b21563c8b"}, + {file = "lxml-5.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a3d819eb6f9b8677f57f9664265d0a10dd6551d227afb4af2b9cd7bdc2ccbf18"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5b8f5db71b28b8c404956ddf79575ea77aa8b1538e8b2ef9ec877945b3f46442"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c3406b63232fc7e9b8783ab0b765d7c59e7c59ff96759d8ef9632fca27c7ee4"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2ecdd78ab768f844c7a1d4a03595038c166b609f6395e25af9b0f3f26ae1230f"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:168f2dfcfdedf611eb285efac1516c8454c8c99caf271dccda8943576b67552e"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa617107a410245b8660028a7483b68e7914304a6d4882b5ff3d2d3eb5948d8c"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:69959bd3167b993e6e710b99051265654133a98f20cec1d9b493b931942e9c16"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_28_ppc64le.whl", hash = "sha256:bd96517ef76c8654446fc3db9242d019a1bb5fe8b751ba414765d59f99210b79"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_28_s390x.whl", hash = "sha256:ab6dd83b970dc97c2d10bc71aa925b84788c7c05de30241b9e96f9b6d9ea3080"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:eec1bb8cdbba2925bedc887bc0609a80e599c75b12d87ae42ac23fd199445654"}, + {file = "lxml-5.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6a7095eeec6f89111d03dabfe5883a1fd54da319c94e0fb104ee8f23616b572d"}, + {file = "lxml-5.3.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:6f651ebd0b21ec65dfca93aa629610a0dbc13dbc13554f19b0113da2e61a4763"}, + {file = "lxml-5.3.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:f422a209d2455c56849442ae42f25dbaaba1c6c3f501d58761c619c7836642ec"}, + {file = "lxml-5.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:62f7fdb0d1ed2065451f086519865b4c90aa19aed51081979ecd05a21eb4d1be"}, + {file = "lxml-5.3.0-cp311-cp311-win32.whl", hash = "sha256:c6379f35350b655fd817cd0d6cbeef7f265f3ae5fedb1caae2eb442bbeae9ab9"}, + {file = "lxml-5.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:9c52100e2c2dbb0649b90467935c4b0de5528833c76a35ea1a2691ec9f1ee7a1"}, + {file = "lxml-5.3.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:e99f5507401436fdcc85036a2e7dc2e28d962550afe1cbfc07c40e454256a859"}, + {file = "lxml-5.3.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:384aacddf2e5813a36495233b64cb96b1949da72bef933918ba5c84e06af8f0e"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:874a216bf6afaf97c263b56371434e47e2c652d215788396f60477540298218f"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:65ab5685d56914b9a2a34d67dd5488b83213d680b0c5d10b47f81da5a16b0b0e"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aac0bbd3e8dd2d9c45ceb82249e8bdd3ac99131a32b4d35c8af3cc9db1657179"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b369d3db3c22ed14c75ccd5af429086f166a19627e84a8fdade3f8f31426e52a"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c24037349665434f375645fa9d1f5304800cec574d0310f618490c871fd902b3"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:62d172f358f33a26d6b41b28c170c63886742f5b6772a42b59b4f0fa10526cb1"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_28_ppc64le.whl", hash = "sha256:c1f794c02903c2824fccce5b20c339a1a14b114e83b306ff11b597c5f71a1c8d"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_28_s390x.whl", hash = "sha256:5d6a6972b93c426ace71e0be9a6f4b2cfae9b1baed2eed2006076a746692288c"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:3879cc6ce938ff4eb4900d901ed63555c778731a96365e53fadb36437a131a99"}, + {file = "lxml-5.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:74068c601baff6ff021c70f0935b0c7bc528baa8ea210c202e03757c68c5a4ff"}, + {file = "lxml-5.3.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ecd4ad8453ac17bc7ba3868371bffb46f628161ad0eefbd0a855d2c8c32dd81a"}, + {file = "lxml-5.3.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:7e2f58095acc211eb9d8b5771bf04df9ff37d6b87618d1cbf85f92399c98dae8"}, + {file = "lxml-5.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e63601ad5cd8f860aa99d109889b5ac34de571c7ee902d6812d5d9ddcc77fa7d"}, + {file = "lxml-5.3.0-cp312-cp312-win32.whl", hash = "sha256:17e8d968d04a37c50ad9c456a286b525d78c4a1c15dd53aa46c1d8e06bf6fa30"}, + {file = "lxml-5.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:c1a69e58a6bb2de65902051d57fde951febad631a20a64572677a1052690482f"}, + {file = "lxml-5.3.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8c72e9563347c7395910de6a3100a4840a75a6f60e05af5e58566868d5eb2d6a"}, + {file = "lxml-5.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e92ce66cd919d18d14b3856906a61d3f6b6a8500e0794142338da644260595cd"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1d04f064bebdfef9240478f7a779e8c5dc32b8b7b0b2fc6a62e39b928d428e51"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c2fb570d7823c2bbaf8b419ba6e5662137f8166e364a8b2b91051a1fb40ab8b"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0c120f43553ec759f8de1fee2f4794452b0946773299d44c36bfe18e83caf002"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:562e7494778a69086f0312ec9689f6b6ac1c6b65670ed7d0267e49f57ffa08c4"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:423b121f7e6fa514ba0c7918e56955a1d4470ed35faa03e3d9f0e3baa4c7e492"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:c00f323cc00576df6165cc9d21a4c21285fa6b9989c5c39830c3903dc4303ef3"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_28_ppc64le.whl", hash = "sha256:1fdc9fae8dd4c763e8a31e7630afef517eab9f5d5d31a278df087f307bf601f4"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_28_s390x.whl", hash = "sha256:658f2aa69d31e09699705949b5fc4719cbecbd4a97f9656a232e7d6c7be1a367"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:1473427aff3d66a3fa2199004c3e601e6c4500ab86696edffdbc84954c72d832"}, + {file = "lxml-5.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a87de7dd873bf9a792bf1e58b1c3887b9264036629a5bf2d2e6579fe8e73edff"}, + {file = "lxml-5.3.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:0d7b36afa46c97875303a94e8f3ad932bf78bace9e18e603f2085b652422edcd"}, + {file = "lxml-5.3.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:cf120cce539453ae086eacc0130a324e7026113510efa83ab42ef3fcfccac7fb"}, + {file = "lxml-5.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:df5c7333167b9674aa8ae1d4008fa4bc17a313cc490b2cca27838bbdcc6bb15b"}, + {file = "lxml-5.3.0-cp313-cp313-win32.whl", hash = "sha256:c802e1c2ed9f0c06a65bc4ed0189d000ada8049312cfeab6ca635e39c9608957"}, + {file = "lxml-5.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:406246b96d552e0503e17a1006fd27edac678b3fcc9f1be71a2f94b4ff61528d"}, + {file = "lxml-5.3.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:8f0de2d390af441fe8b2c12626d103540b5d850d585b18fcada58d972b74a74e"}, + {file = "lxml-5.3.0-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1afe0a8c353746e610bd9031a630a95bcfb1a720684c3f2b36c4710a0a96528f"}, + {file = "lxml-5.3.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56b9861a71575f5795bde89256e7467ece3d339c9b43141dbdd54544566b3b94"}, + {file = "lxml-5.3.0-cp36-cp36m-manylinux_2_28_x86_64.whl", hash = "sha256:9fb81d2824dff4f2e297a276297e9031f46d2682cafc484f49de182aa5e5df99"}, + {file = "lxml-5.3.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:2c226a06ecb8cdef28845ae976da407917542c5e6e75dcac7cc33eb04aaeb237"}, + {file = "lxml-5.3.0-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = "sha256:7d3d1ca42870cdb6d0d29939630dbe48fa511c203724820fc0fd507b2fb46577"}, + {file = "lxml-5.3.0-cp36-cp36m-win32.whl", hash = "sha256:094cb601ba9f55296774c2d57ad68730daa0b13dc260e1f941b4d13678239e70"}, + {file = "lxml-5.3.0-cp36-cp36m-win_amd64.whl", hash = "sha256:eafa2c8658f4e560b098fe9fc54539f86528651f61849b22111a9b107d18910c"}, + {file = "lxml-5.3.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:cb83f8a875b3d9b458cada4f880fa498646874ba4011dc974e071a0a84a1b033"}, + {file = "lxml-5.3.0-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:25f1b69d41656b05885aa185f5fdf822cb01a586d1b32739633679699f220391"}, + {file = "lxml-5.3.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23e0553b8055600b3bf4a00b255ec5c92e1e4aebf8c2c09334f8368e8bd174d6"}, + {file = "lxml-5.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ada35dd21dc6c039259596b358caab6b13f4db4d4a7f8665764d616daf9cc1d"}, + {file = "lxml-5.3.0-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:81b4e48da4c69313192d8c8d4311e5d818b8be1afe68ee20f6385d0e96fc9512"}, + {file = "lxml-5.3.0-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:2bc9fd5ca4729af796f9f59cd8ff160fe06a474da40aca03fcc79655ddee1a8b"}, + {file = "lxml-5.3.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:07da23d7ee08577760f0a71d67a861019103e4812c87e2fab26b039054594cc5"}, + {file = "lxml-5.3.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:ea2e2f6f801696ad7de8aec061044d6c8c0dd4037608c7cab38a9a4d316bfb11"}, + {file = "lxml-5.3.0-cp37-cp37m-win32.whl", hash = "sha256:5c54afdcbb0182d06836cc3d1be921e540be3ebdf8b8a51ee3ef987537455f84"}, + {file = "lxml-5.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:f2901429da1e645ce548bf9171784c0f74f0718c3f6150ce166be39e4dd66c3e"}, + {file = "lxml-5.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c56a1d43b2f9ee4786e4658c7903f05da35b923fb53c11025712562d5cc02753"}, + {file = "lxml-5.3.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ee8c39582d2652dcd516d1b879451500f8db3fe3607ce45d7c5957ab2596040"}, + {file = "lxml-5.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0fdf3a3059611f7585a78ee10399a15566356116a4288380921a4b598d807a22"}, + {file = "lxml-5.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:146173654d79eb1fc97498b4280c1d3e1e5d58c398fa530905c9ea50ea849b22"}, + {file = "lxml-5.3.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0a7056921edbdd7560746f4221dca89bb7a3fe457d3d74267995253f46343f15"}, + {file = "lxml-5.3.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:9e4b47ac0f5e749cfc618efdf4726269441014ae1d5583e047b452a32e221920"}, + {file = "lxml-5.3.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:f914c03e6a31deb632e2daa881fe198461f4d06e57ac3d0e05bbcab8eae01945"}, + {file = "lxml-5.3.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:213261f168c5e1d9b7535a67e68b1f59f92398dd17a56d934550837143f79c42"}, + {file = "lxml-5.3.0-cp38-cp38-win32.whl", hash = "sha256:218c1b2e17a710e363855594230f44060e2025b05c80d1f0661258142b2add2e"}, + {file = "lxml-5.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:315f9542011b2c4e1d280e4a20ddcca1761993dda3afc7a73b01235f8641e903"}, + {file = "lxml-5.3.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1ffc23010330c2ab67fac02781df60998ca8fe759e8efde6f8b756a20599c5de"}, + {file = "lxml-5.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2b3778cb38212f52fac9fe913017deea2fdf4eb1a4f8e4cfc6b009a13a6d3fcc"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4b0c7a688944891086ba192e21c5229dea54382f4836a209ff8d0a660fac06be"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:747a3d3e98e24597981ca0be0fd922aebd471fa99d0043a3842d00cdcad7ad6a"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:86a6b24b19eaebc448dc56b87c4865527855145d851f9fc3891673ff97950540"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b11a5d918a6216e521c715b02749240fb07ae5a1fefd4b7bf12f833bc8b4fe70"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:68b87753c784d6acb8a25b05cb526c3406913c9d988d51f80adecc2b0775d6aa"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:109fa6fede314cc50eed29e6e56c540075e63d922455346f11e4d7a036d2b8cf"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_28_ppc64le.whl", hash = "sha256:02ced472497b8362c8e902ade23e3300479f4f43e45f4105c85ef43b8db85229"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_28_s390x.whl", hash = "sha256:6b038cc86b285e4f9fea2ba5ee76e89f21ed1ea898e287dc277a25884f3a7dfe"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:7437237c6a66b7ca341e868cda48be24b8701862757426852c9b3186de1da8a2"}, + {file = "lxml-5.3.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:7f41026c1d64043a36fda21d64c5026762d53a77043e73e94b71f0521939cc71"}, + {file = "lxml-5.3.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:482c2f67761868f0108b1743098640fbb2a28a8e15bf3f47ada9fa59d9fe08c3"}, + {file = "lxml-5.3.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:1483fd3358963cc5c1c9b122c80606a3a79ee0875bcac0204149fa09d6ff2727"}, + {file = "lxml-5.3.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:2dec2d1130a9cda5b904696cec33b2cfb451304ba9081eeda7f90f724097300a"}, + {file = "lxml-5.3.0-cp39-cp39-win32.whl", hash = "sha256:a0eabd0a81625049c5df745209dc7fcef6e2aea7793e5f003ba363610aa0a3ff"}, + {file = "lxml-5.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:89e043f1d9d341c52bf2af6d02e6adde62e0a46e6755d5eb60dc6e4f0b8aeca2"}, + {file = "lxml-5.3.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:7b1cd427cb0d5f7393c31b7496419da594fe600e6fdc4b105a54f82405e6626c"}, + {file = "lxml-5.3.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51806cfe0279e06ed8500ce19479d757db42a30fd509940b1701be9c86a5ff9a"}, + {file = "lxml-5.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ee70d08fd60c9565ba8190f41a46a54096afa0eeb8f76bd66f2c25d3b1b83005"}, + {file = "lxml-5.3.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:8dc2c0395bea8254d8daebc76dcf8eb3a95ec2a46fa6fae5eaccee366bfe02ce"}, + {file = "lxml-5.3.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:6ba0d3dcac281aad8a0e5b14c7ed6f9fa89c8612b47939fc94f80b16e2e9bc83"}, + {file = "lxml-5.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:6e91cf736959057f7aac7adfc83481e03615a8e8dd5758aa1d95ea69e8931dba"}, + {file = "lxml-5.3.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:94d6c3782907b5e40e21cadf94b13b0842ac421192f26b84c45f13f3c9d5dc27"}, + {file = "lxml-5.3.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c300306673aa0f3ed5ed9372b21867690a17dba38c68c44b287437c362ce486b"}, + {file = "lxml-5.3.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78d9b952e07aed35fe2e1a7ad26e929595412db48535921c5013edc8aa4a35ce"}, + {file = "lxml-5.3.0-pp37-pypy37_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:01220dca0d066d1349bd6a1726856a78f7929f3878f7e2ee83c296c69495309e"}, + {file = "lxml-5.3.0-pp37-pypy37_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:2d9b8d9177afaef80c53c0a9e30fa252ff3036fb1c6494d427c066a4ce6a282f"}, + {file = "lxml-5.3.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:20094fc3f21ea0a8669dc4c61ed7fa8263bd37d97d93b90f28fc613371e7a875"}, + {file = "lxml-5.3.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:ace2c2326a319a0bb8a8b0e5b570c764962e95818de9f259ce814ee666603f19"}, + {file = "lxml-5.3.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:92e67a0be1639c251d21e35fe74df6bcc40cba445c2cda7c4a967656733249e2"}, + {file = "lxml-5.3.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd5350b55f9fecddc51385463a4f67a5da829bc741e38cf689f38ec9023f54ab"}, + {file = "lxml-5.3.0-pp38-pypy38_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:4c1fefd7e3d00921c44dc9ca80a775af49698bbfd92ea84498e56acffd4c5469"}, + {file = "lxml-5.3.0-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:71a8dd38fbd2f2319136d4ae855a7078c69c9a38ae06e0c17c73fd70fc6caad8"}, + {file = "lxml-5.3.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:97acf1e1fd66ab53dacd2c35b319d7e548380c2e9e8c54525c6e76d21b1ae3b1"}, + {file = "lxml-5.3.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:68934b242c51eb02907c5b81d138cb977b2129a0a75a8f8b60b01cb8586c7b21"}, + {file = "lxml-5.3.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b710bc2b8292966b23a6a0121f7a6c51d45d2347edcc75f016ac123b8054d3f2"}, + {file = "lxml-5.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18feb4b93302091b1541221196a2155aa296c363fd233814fa11e181adebc52f"}, + {file = "lxml-5.3.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:3eb44520c4724c2e1a57c0af33a379eee41792595023f367ba3952a2d96c2aab"}, + {file = "lxml-5.3.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:609251a0ca4770e5a8768ff902aa02bf636339c5a93f9349b48eb1f606f7f3e9"}, + {file = "lxml-5.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:516f491c834eb320d6c843156440fe7fc0d50b33e44387fcec5b02f0bc118a4c"}, + {file = "lxml-5.3.0.tar.gz", hash = "sha256:4e109ca30d1edec1ac60cdbe341905dc3b8f55b16855e03a54aaf59e51ec8c6f"}, ] [package.extras] cssselect = ["cssselect (>=0.7)"] +html-clean = ["lxml-html-clean"] html5 = ["html5lib"] htmlsoup = ["BeautifulSoup4"] -source = ["Cython (>=3.0.7)"] +source = ["Cython (>=3.0.11)"] [[package]] name = "lz4" @@ -4502,13 +4725,13 @@ files = [ [[package]] name = "marshmallow" -version = "3.21.3" +version = "3.22.0" description = "A lightweight library for converting complex datatypes to and from native Python datatypes." optional = false python-versions = ">=3.8" files = [ - {file = "marshmallow-3.21.3-py3-none-any.whl", hash = "sha256:86ce7fb914aa865001a4b2092c4c2872d13bc347f3d42673272cabfdbad386f1"}, - {file = "marshmallow-3.21.3.tar.gz", hash = "sha256:4f57c5e050a54d66361e826f94fba213eb10b67b2fdb02c3e0343ce207ba1662"}, + {file = "marshmallow-3.22.0-py3-none-any.whl", hash = "sha256:71a2dce49ef901c3f97ed296ae5051135fd3febd2bf43afe0ae9a82143a494d9"}, + {file = "marshmallow-3.22.0.tar.gz", hash = "sha256:4972f529104a220bb8637d595aa4c9762afbe7f7a77d82dc58c1615d70c5823e"}, ] [package.dependencies] @@ -4516,7 +4739,7 @@ packaging = ">=17.0" [package.extras] dev = ["marshmallow[tests]", "pre-commit (>=3.5,<4.0)", "tox"] -docs = ["alabaster (==0.7.16)", "autodocsumm (==0.2.12)", "sphinx (==7.3.7)", "sphinx-issues (==4.1.0)", "sphinx-version-warning (==1.1.2)"] +docs = ["alabaster (==1.0.0)", "autodocsumm (==0.2.13)", "sphinx (==8.0.2)", "sphinx-issues (==4.1.0)", "sphinx-version-warning (==1.1.2)"] tests = ["pytest", "pytz", "simplejson"] [[package]] @@ -4580,15 +4803,15 @@ files = [ [[package]] name = "milvus-lite" -version = "2.4.8" +version = "2.4.9" description = "A lightweight version of Milvus wrapped with Python." optional = false python-versions = ">=3.7" files = [ - {file = "milvus_lite-2.4.8-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:b7e90b34b214884cd44cdc112ab243d4cb197b775498355e2437b6cafea025fe"}, - {file = "milvus_lite-2.4.8-py3-none-macosx_11_0_arm64.whl", hash = "sha256:519dfc62709d8f642d98a1c5b1dcde7080d107e6e312d677fef5a3412a40ac08"}, - {file = "milvus_lite-2.4.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:b21f36d24cbb0e920b4faad607019bb28c1b2c88b4d04680ac8c7697a4ae8a4d"}, - {file = "milvus_lite-2.4.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:08332a2b9abfe7c4e1d7926068937e46f8fb81f2707928b7bc02c9dc99cebe41"}, + {file = "milvus_lite-2.4.9-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:d3e617b3d68c09ad656d54bc3d8cc4ef6ef56c54015e1563d4fe4bcec6b7c90a"}, + {file = "milvus_lite-2.4.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:6e7029282d6829b277ebb92f64e2370be72b938e34770e1eb649346bda5d1d7f"}, + {file = "milvus_lite-2.4.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9b8e991e4e433596f6a399a165c1a506f823ec9133332e03d7f8a114bff4550d"}, + {file = "milvus_lite-2.4.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:7f53e674602101cfbcf0a4a59d19eaa139dfd5580639f3040ad73d901f24fc0b"}, ] [package.dependencies] @@ -4960,13 +5183,13 @@ twitter = ["twython"] [[package]] name = "novita-client" -version = "0.5.6" +version = "0.5.7" description = "novita SDK for Python" optional = false python-versions = ">=3.6" files = [ - {file = "novita_client-0.5.6-py3-none-any.whl", hash = "sha256:9fa6cfd12f13a75c7da42b27f811a560b0320da24cf256480f517bde479bc57c"}, - {file = "novita_client-0.5.6.tar.gz", hash = "sha256:2e4d956903d5da39d43127a41dcb020ae40322d2a6196413071b94b3d6988b98"}, + {file = "novita_client-0.5.7-py3-none-any.whl", hash = "sha256:844a4c09c98328c8d4f72e1d3f63f76285c2963dcc37ccb2de41cbfdbe7fa51d"}, + {file = "novita_client-0.5.7.tar.gz", hash = "sha256:65baf748757aafd8ab080a64f9ab069a40c0810fc1fa9be9c26596988a0aa4b4"}, ] [package.dependencies] @@ -5139,42 +5362,42 @@ tests = ["pytest", "pytest-cov"] [[package]] name = "onnxruntime" -version = "1.18.1" +version = "1.19.0" description = "ONNX Runtime is a runtime accelerator for Machine Learning models" optional = false python-versions = "*" files = [ - {file = "onnxruntime-1.18.1-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:29ef7683312393d4ba04252f1b287d964bd67d5e6048b94d2da3643986c74d80"}, - {file = "onnxruntime-1.18.1-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fc706eb1df06ddf55776e15a30519fb15dda7697f987a2bbda4962845e3cec05"}, - {file = "onnxruntime-1.18.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b7de69f5ced2a263531923fa68bbec52a56e793b802fcd81a03487b5e292bc3a"}, - {file = "onnxruntime-1.18.1-cp310-cp310-win32.whl", hash = "sha256:221e5b16173926e6c7de2cd437764492aa12b6811f45abd37024e7cf2ae5d7e3"}, - {file = "onnxruntime-1.18.1-cp310-cp310-win_amd64.whl", hash = "sha256:75211b619275199c861ee94d317243b8a0fcde6032e5a80e1aa9ded8ab4c6060"}, - {file = "onnxruntime-1.18.1-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:f26582882f2dc581b809cfa41a125ba71ad9e715738ec6402418df356969774a"}, - {file = "onnxruntime-1.18.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ef36f3a8b768506d02be349ac303fd95d92813ba3ba70304d40c3cd5c25d6a4c"}, - {file = "onnxruntime-1.18.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:170e711393e0618efa8ed27b59b9de0ee2383bd2a1f93622a97006a5ad48e434"}, - {file = "onnxruntime-1.18.1-cp311-cp311-win32.whl", hash = "sha256:9b6a33419b6949ea34e0dc009bc4470e550155b6da644571ecace4b198b0d88f"}, - {file = "onnxruntime-1.18.1-cp311-cp311-win_amd64.whl", hash = "sha256:5c1380a9f1b7788da742c759b6a02ba771fe1ce620519b2b07309decbd1a2fe1"}, - {file = "onnxruntime-1.18.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:31bd57a55e3f983b598675dfc7e5d6f0877b70ec9864b3cc3c3e1923d0a01919"}, - {file = "onnxruntime-1.18.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b9e03c4ba9f734500691a4d7d5b381cd71ee2f3ce80a1154ac8f7aed99d1ecaa"}, - {file = "onnxruntime-1.18.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:781aa9873640f5df24524f96f6070b8c550c66cb6af35710fd9f92a20b4bfbf6"}, - {file = "onnxruntime-1.18.1-cp312-cp312-win32.whl", hash = "sha256:3a2d9ab6254ca62adbb448222e630dc6883210f718065063518c8f93a32432be"}, - {file = "onnxruntime-1.18.1-cp312-cp312-win_amd64.whl", hash = "sha256:ad93c560b1c38c27c0275ffd15cd7f45b3ad3fc96653c09ce2931179982ff204"}, - {file = "onnxruntime-1.18.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:3b55dc9d3c67626388958a3eb7ad87eb7c70f75cb0f7ff4908d27b8b42f2475c"}, - {file = "onnxruntime-1.18.1-cp38-cp38-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f80dbcfb6763cc0177a31168b29b4bd7662545b99a19e211de8c734b657e0669"}, - {file = "onnxruntime-1.18.1-cp38-cp38-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f1ff2c61a16d6c8631796c54139bafea41ee7736077a0fc64ee8ae59432f5c58"}, - {file = "onnxruntime-1.18.1-cp38-cp38-win32.whl", hash = "sha256:219855bd272fe0c667b850bf1a1a5a02499269a70d59c48e6f27f9c8bcb25d02"}, - {file = "onnxruntime-1.18.1-cp38-cp38-win_amd64.whl", hash = "sha256:afdf16aa607eb9a2c60d5ca2d5abf9f448e90c345b6b94c3ed14f4fb7e6a2d07"}, - {file = "onnxruntime-1.18.1-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:128df253ade673e60cea0955ec9d0e89617443a6d9ce47c2d79eb3f72a3be3de"}, - {file = "onnxruntime-1.18.1-cp39-cp39-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9839491e77e5c5a175cab3621e184d5a88925ee297ff4c311b68897197f4cde9"}, - {file = "onnxruntime-1.18.1-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ad3187c1faff3ac15f7f0e7373ef4788c582cafa655a80fdbb33eaec88976c66"}, - {file = "onnxruntime-1.18.1-cp39-cp39-win32.whl", hash = "sha256:34657c78aa4e0b5145f9188b550ded3af626651b15017bf43d280d7e23dbf195"}, - {file = "onnxruntime-1.18.1-cp39-cp39-win_amd64.whl", hash = "sha256:9c14fd97c3ddfa97da5feef595e2c73f14c2d0ec1d4ecbea99c8d96603c89589"}, + {file = "onnxruntime-1.19.0-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:6ce22a98dfec7b646ae305f52d0ce14a189a758b02ea501860ca719f4b0ae04b"}, + {file = "onnxruntime-1.19.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:19019c72873f26927aa322c54cf2bf7312b23451b27451f39b88f57016c94f8b"}, + {file = "onnxruntime-1.19.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8eaa16df99171dc636e30108d15597aed8c4c2dd9dbfdd07cc464d57d73fb275"}, + {file = "onnxruntime-1.19.0-cp310-cp310-win32.whl", hash = "sha256:0eb0f8dbe596fd0f4737fe511fdbb17603853a7d204c5b2ca38d3c7808fc556b"}, + {file = "onnxruntime-1.19.0-cp310-cp310-win_amd64.whl", hash = "sha256:616092d54ba8023b7bc0a5f6d900a07a37cc1cfcc631873c15f8c1d6e9e184d4"}, + {file = "onnxruntime-1.19.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:a2b53b3c287cd933e5eb597273926e899082d8c84ab96e1b34035764a1627e17"}, + {file = "onnxruntime-1.19.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e94984663963e74fbb468bde9ec6f19dcf890b594b35e249c4dc8789d08993c5"}, + {file = "onnxruntime-1.19.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6f379d1f050cfb55ce015d53727b78ee362febc065c38eed81512b22b757da73"}, + {file = "onnxruntime-1.19.0-cp311-cp311-win32.whl", hash = "sha256:4ccb48faea02503275ae7e79e351434fc43c294c4cb5c4d8bcb7479061396614"}, + {file = "onnxruntime-1.19.0-cp311-cp311-win_amd64.whl", hash = "sha256:9cdc8d311289a84e77722de68bd22b8adfb94eea26f4be6f9e017350faac8b18"}, + {file = "onnxruntime-1.19.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:1b59eaec1be9a8613c5fdeaafe67f73a062edce3ac03bbbdc9e2d98b58a30617"}, + {file = "onnxruntime-1.19.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:be4144d014a4b25184e63ce7a463a2e7796e2f3df931fccc6a6aefa6f1365dc5"}, + {file = "onnxruntime-1.19.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10d7e7d4ca7021ce7f29a66dbc6071addf2de5839135339bd855c6d9c2bba371"}, + {file = "onnxruntime-1.19.0-cp312-cp312-win32.whl", hash = "sha256:87f2c58b577a1fb31dc5d92b647ecc588fd5f1ea0c3ad4526f5f80a113357c8d"}, + {file = "onnxruntime-1.19.0-cp312-cp312-win_amd64.whl", hash = "sha256:8a1f50d49676d7b69566536ff039d9e4e95fc482a55673719f46528218ecbb94"}, + {file = "onnxruntime-1.19.0-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:71423c8c4b2d7a58956271534302ec72721c62a41efd0c4896343249b8399ab0"}, + {file = "onnxruntime-1.19.0-cp38-cp38-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9d63630d45e9498f96e75bbeb7fd4a56acb10155de0de4d0e18d1b6cbb0b358a"}, + {file = "onnxruntime-1.19.0-cp38-cp38-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f3bfd15db1e8794d379a86c1a9116889f47f2cca40cc82208fc4f7e8c38e8522"}, + {file = "onnxruntime-1.19.0-cp38-cp38-win32.whl", hash = "sha256:3b098003b6b4cb37cc84942e5f1fe27f945dd857cbd2829c824c26b0ba4a247e"}, + {file = "onnxruntime-1.19.0-cp38-cp38-win_amd64.whl", hash = "sha256:cea067a6541d6787d903ee6843401c5b1332a266585160d9700f9f0939443886"}, + {file = "onnxruntime-1.19.0-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:c4fcff12dc5ca963c5f76b9822bb404578fa4a98c281e8c666b429192799a099"}, + {file = "onnxruntime-1.19.0-cp39-cp39-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f6dcad8a4db908fbe70b98c79cea1c8b6ac3316adf4ce93453136e33a524ac59"}, + {file = "onnxruntime-1.19.0-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4bc449907c6e8d99eee5ae5cc9c8fdef273d801dcd195393d3f9ab8ad3f49522"}, + {file = "onnxruntime-1.19.0-cp39-cp39-win32.whl", hash = "sha256:947febd48405afcf526e45ccff97ff23b15e530434705f734870d22ae7fcf236"}, + {file = "onnxruntime-1.19.0-cp39-cp39-win_amd64.whl", hash = "sha256:f60be47eff5ee77fd28a466b0fd41d7debc42a32179d1ddb21e05d6067d7b48b"}, ] [package.dependencies] coloredlogs = "*" flatbuffers = "*" -numpy = ">=1.21.6,<2.0" +numpy = ">=1.21.6" packaging = "*" protobuf = "*" sympy = "*" @@ -5202,6 +5425,65 @@ typing-extensions = ">=4.7,<5" [package.extras] datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] +[[package]] +name = "opencensus" +version = "0.11.4" +description = "A stats collection and distributed tracing framework" +optional = false +python-versions = "*" +files = [ + {file = "opencensus-0.11.4-py2.py3-none-any.whl", hash = "sha256:a18487ce68bc19900336e0ff4655c5a116daf10c1b3685ece8d971bddad6a864"}, + {file = "opencensus-0.11.4.tar.gz", hash = "sha256:cbef87d8b8773064ab60e5c2a1ced58bbaa38a6d052c41aec224958ce544eff2"}, +] + +[package.dependencies] +google-api-core = {version = ">=1.0.0,<3.0.0", markers = "python_version >= \"3.6\""} +opencensus-context = ">=0.1.3" +six = ">=1.16,<2.0" + +[[package]] +name = "opencensus-context" +version = "0.1.3" +description = "OpenCensus Runtime Context" +optional = false +python-versions = "*" +files = [ + {file = "opencensus-context-0.1.3.tar.gz", hash = "sha256:a03108c3c10d8c80bb5ddf5c8a1f033161fa61972a9917f9b9b3a18517f0088c"}, + {file = "opencensus_context-0.1.3-py2.py3-none-any.whl", hash = "sha256:073bb0590007af276853009fac7e4bab1d523c3f03baf4cb4511ca38967c6039"}, +] + +[[package]] +name = "opencensus-ext-azure" +version = "1.1.13" +description = "OpenCensus Azure Monitor Exporter" +optional = false +python-versions = "*" +files = [ + {file = "opencensus-ext-azure-1.1.13.tar.gz", hash = "sha256:aec30472177005379ba56a702a097d618c5f57558e1bb6676ec75f948130692a"}, + {file = "opencensus_ext_azure-1.1.13-py2.py3-none-any.whl", hash = "sha256:06001fac6f8588ba00726a3a7c6c7f2fc88bc8ad12a65afdca657923085393dd"}, +] + +[package.dependencies] +azure-core = ">=1.12.0,<2.0.0" +azure-identity = ">=1.5.0,<2.0.0" +opencensus = ">=0.11.4,<1.0.0" +psutil = ">=5.6.3" +requests = ">=2.19.0" + +[[package]] +name = "opencensus-ext-logging" +version = "0.1.1" +description = "OpenCensus logging Integration" +optional = false +python-versions = "*" +files = [ + {file = "opencensus-ext-logging-0.1.1.tar.gz", hash = "sha256:c203b70f034151dada529f543af330ba17aaffec27d8a5267d03c713eb1de334"}, + {file = "opencensus_ext_logging-0.1.1-py2.py3-none-any.whl", hash = "sha256:cfdaf5da5d8b195ff3d1af87a4066a6621a28046173f6be4b0b6caec4a3ca89f"}, +] + +[package.dependencies] +opencensus = ">=0.8.0,<1.0.0" + [[package]] name = "openpyxl" version = "3.1.5" @@ -5242,42 +5524,42 @@ kerberos = ["requests-kerberos"] [[package]] name = "opentelemetry-api" -version = "1.25.0" +version = "1.26.0" description = "OpenTelemetry Python API" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_api-1.25.0-py3-none-any.whl", hash = "sha256:757fa1aa020a0f8fa139f8959e53dec2051cc26b832e76fa839a6d76ecefd737"}, - {file = "opentelemetry_api-1.25.0.tar.gz", hash = "sha256:77c4985f62f2614e42ce77ee4c9da5fa5f0bc1e1821085e9a47533a9323ae869"}, + {file = "opentelemetry_api-1.26.0-py3-none-any.whl", hash = "sha256:7d7ea33adf2ceda2dd680b18b1677e4152000b37ca76e679da71ff103b943064"}, + {file = "opentelemetry_api-1.26.0.tar.gz", hash = "sha256:2bd639e4bed5b18486fef0b5a520aaffde5a18fc225e808a1ac4df363f43a1ce"}, ] [package.dependencies] deprecated = ">=1.2.6" -importlib-metadata = ">=6.0,<=7.1" +importlib-metadata = ">=6.0,<=8.0.0" [[package]] name = "opentelemetry-exporter-otlp-proto-common" -version = "1.25.0" +version = "1.26.0" description = "OpenTelemetry Protobuf encoding" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_exporter_otlp_proto_common-1.25.0-py3-none-any.whl", hash = "sha256:15637b7d580c2675f70246563363775b4e6de947871e01d0f4e3881d1848d693"}, - {file = "opentelemetry_exporter_otlp_proto_common-1.25.0.tar.gz", hash = "sha256:c93f4e30da4eee02bacd1e004eb82ce4da143a2f8e15b987a9f603e0a85407d3"}, + {file = "opentelemetry_exporter_otlp_proto_common-1.26.0-py3-none-any.whl", hash = "sha256:ee4d8f8891a1b9c372abf8d109409e5b81947cf66423fd998e56880057afbc71"}, + {file = "opentelemetry_exporter_otlp_proto_common-1.26.0.tar.gz", hash = "sha256:bdbe50e2e22a1c71acaa0c8ba6efaadd58882e5a5978737a44a4c4b10d304c92"}, ] [package.dependencies] -opentelemetry-proto = "1.25.0" +opentelemetry-proto = "1.26.0" [[package]] name = "opentelemetry-exporter-otlp-proto-grpc" -version = "1.25.0" +version = "1.26.0" description = "OpenTelemetry Collector Protobuf over gRPC Exporter" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_exporter_otlp_proto_grpc-1.25.0-py3-none-any.whl", hash = "sha256:3131028f0c0a155a64c430ca600fd658e8e37043cb13209f0109db5c1a3e4eb4"}, - {file = "opentelemetry_exporter_otlp_proto_grpc-1.25.0.tar.gz", hash = "sha256:c0b1661415acec5af87625587efa1ccab68b873745ca0ee96b69bb1042087eac"}, + {file = "opentelemetry_exporter_otlp_proto_grpc-1.26.0-py3-none-any.whl", hash = "sha256:e2be5eff72ebcb010675b818e8d7c2e7d61ec451755b8de67a140bc49b9b0280"}, + {file = "opentelemetry_exporter_otlp_proto_grpc-1.26.0.tar.gz", hash = "sha256:a65b67a9a6b06ba1ec406114568e21afe88c1cdb29c464f2507d529eb906d8ae"}, ] [package.dependencies] @@ -5285,19 +5567,19 @@ deprecated = ">=1.2.6" googleapis-common-protos = ">=1.52,<2.0" grpcio = ">=1.0.0,<2.0.0" opentelemetry-api = ">=1.15,<2.0" -opentelemetry-exporter-otlp-proto-common = "1.25.0" -opentelemetry-proto = "1.25.0" -opentelemetry-sdk = ">=1.25.0,<1.26.0" +opentelemetry-exporter-otlp-proto-common = "1.26.0" +opentelemetry-proto = "1.26.0" +opentelemetry-sdk = ">=1.26.0,<1.27.0" [[package]] name = "opentelemetry-instrumentation" -version = "0.46b0" +version = "0.47b0" description = "Instrumentation Tools & Auto Instrumentation for OpenTelemetry Python" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_instrumentation-0.46b0-py3-none-any.whl", hash = "sha256:89cd721b9c18c014ca848ccd11181e6b3fd3f6c7669e35d59c48dc527408c18b"}, - {file = "opentelemetry_instrumentation-0.46b0.tar.gz", hash = "sha256:974e0888fb2a1e01c38fbacc9483d024bb1132aad92d6d24e2e5543887a7adda"}, + {file = "opentelemetry_instrumentation-0.47b0-py3-none-any.whl", hash = "sha256:88974ee52b1db08fc298334b51c19d47e53099c33740e48c4f084bd1afd052d5"}, + {file = "opentelemetry_instrumentation-0.47b0.tar.gz", hash = "sha256:96f9885e450c35e3f16a4f33145f2ebf620aea910c9fd74a392bbc0f807a350f"}, ] [package.dependencies] @@ -5307,55 +5589,55 @@ wrapt = ">=1.0.0,<2.0.0" [[package]] name = "opentelemetry-instrumentation-asgi" -version = "0.46b0" +version = "0.47b0" description = "ASGI instrumentation for OpenTelemetry" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_instrumentation_asgi-0.46b0-py3-none-any.whl", hash = "sha256:f13c55c852689573057837a9500aeeffc010c4ba59933c322e8f866573374759"}, - {file = "opentelemetry_instrumentation_asgi-0.46b0.tar.gz", hash = "sha256:02559f30cf4b7e2a737ab17eb52aa0779bcf4cc06573064f3e2cb4dcc7d3040a"}, + {file = "opentelemetry_instrumentation_asgi-0.47b0-py3-none-any.whl", hash = "sha256:b798dc4957b3edc9dfecb47a4c05809036a4b762234c5071212fda39ead80ade"}, + {file = "opentelemetry_instrumentation_asgi-0.47b0.tar.gz", hash = "sha256:e78b7822c1bca0511e5e9610ec484b8994a81670375e570c76f06f69af7c506a"}, ] [package.dependencies] asgiref = ">=3.0,<4.0" opentelemetry-api = ">=1.12,<2.0" -opentelemetry-instrumentation = "0.46b0" -opentelemetry-semantic-conventions = "0.46b0" -opentelemetry-util-http = "0.46b0" +opentelemetry-instrumentation = "0.47b0" +opentelemetry-semantic-conventions = "0.47b0" +opentelemetry-util-http = "0.47b0" [package.extras] instruments = ["asgiref (>=3.0,<4.0)"] [[package]] name = "opentelemetry-instrumentation-fastapi" -version = "0.46b0" +version = "0.47b0" description = "OpenTelemetry FastAPI Instrumentation" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_instrumentation_fastapi-0.46b0-py3-none-any.whl", hash = "sha256:e0f5d150c6c36833dd011f0e6ef5ede6d7406c1aed0c7c98b2d3b38a018d1b33"}, - {file = "opentelemetry_instrumentation_fastapi-0.46b0.tar.gz", hash = "sha256:928a883a36fc89f9702f15edce43d1a7104da93d740281e32d50ffd03dbb4365"}, + {file = "opentelemetry_instrumentation_fastapi-0.47b0-py3-none-any.whl", hash = "sha256:5ac28dd401160b02e4f544a85a9e4f61a8cbe5b077ea0379d411615376a2bd21"}, + {file = "opentelemetry_instrumentation_fastapi-0.47b0.tar.gz", hash = "sha256:0c7c10b5d971e99a420678ffd16c5b1ea4f0db3b31b62faf305fbb03b4ebee36"}, ] [package.dependencies] opentelemetry-api = ">=1.12,<2.0" -opentelemetry-instrumentation = "0.46b0" -opentelemetry-instrumentation-asgi = "0.46b0" -opentelemetry-semantic-conventions = "0.46b0" -opentelemetry-util-http = "0.46b0" +opentelemetry-instrumentation = "0.47b0" +opentelemetry-instrumentation-asgi = "0.47b0" +opentelemetry-semantic-conventions = "0.47b0" +opentelemetry-util-http = "0.47b0" [package.extras] -instruments = ["fastapi (>=0.58,<1.0)"] +instruments = ["fastapi (>=0.58,<1.0)", "fastapi-slim (>=0.111.0,<0.112.0)"] [[package]] name = "opentelemetry-proto" -version = "1.25.0" +version = "1.26.0" description = "OpenTelemetry Python Proto" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_proto-1.25.0-py3-none-any.whl", hash = "sha256:f07e3341c78d835d9b86665903b199893befa5e98866f63d22b00d0b7ca4972f"}, - {file = "opentelemetry_proto-1.25.0.tar.gz", hash = "sha256:35b6ef9dc4a9f7853ecc5006738ad40443701e52c26099e197895cbda8b815a3"}, + {file = "opentelemetry_proto-1.26.0-py3-none-any.whl", hash = "sha256:6c4d7b4d4d9c88543bcf8c28ae3f8f0448a753dc291c18c5390444c90b76a725"}, + {file = "opentelemetry_proto-1.26.0.tar.gz", hash = "sha256:c5c18796c0cab3751fc3b98dee53855835e90c0422924b484432ac852d93dc1e"}, ] [package.dependencies] @@ -5363,43 +5645,44 @@ protobuf = ">=3.19,<5.0" [[package]] name = "opentelemetry-sdk" -version = "1.25.0" +version = "1.26.0" description = "OpenTelemetry Python SDK" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_sdk-1.25.0-py3-none-any.whl", hash = "sha256:d97ff7ec4b351692e9d5a15af570c693b8715ad78b8aafbec5c7100fe966b4c9"}, - {file = "opentelemetry_sdk-1.25.0.tar.gz", hash = "sha256:ce7fc319c57707ef5bf8b74fb9f8ebdb8bfafbe11898410e0d2a761d08a98ec7"}, + {file = "opentelemetry_sdk-1.26.0-py3-none-any.whl", hash = "sha256:feb5056a84a88670c041ea0ded9921fca559efec03905dddeb3885525e0af897"}, + {file = "opentelemetry_sdk-1.26.0.tar.gz", hash = "sha256:c90d2868f8805619535c05562d699e2f4fb1f00dbd55a86dcefca4da6fa02f85"}, ] [package.dependencies] -opentelemetry-api = "1.25.0" -opentelemetry-semantic-conventions = "0.46b0" +opentelemetry-api = "1.26.0" +opentelemetry-semantic-conventions = "0.47b0" typing-extensions = ">=3.7.4" [[package]] name = "opentelemetry-semantic-conventions" -version = "0.46b0" +version = "0.47b0" description = "OpenTelemetry Semantic Conventions" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_semantic_conventions-0.46b0-py3-none-any.whl", hash = "sha256:6daef4ef9fa51d51855d9f8e0ccd3a1bd59e0e545abe99ac6203804e36ab3e07"}, - {file = "opentelemetry_semantic_conventions-0.46b0.tar.gz", hash = "sha256:fbc982ecbb6a6e90869b15c1673be90bd18c8a56ff1cffc0864e38e2edffaefa"}, + {file = "opentelemetry_semantic_conventions-0.47b0-py3-none-any.whl", hash = "sha256:4ff9d595b85a59c1c1413f02bba320ce7ea6bf9e2ead2b0913c4395c7bbc1063"}, + {file = "opentelemetry_semantic_conventions-0.47b0.tar.gz", hash = "sha256:a8d57999bbe3495ffd4d510de26a97dadc1dace53e0275001b2c1b2f67992a7e"}, ] [package.dependencies] -opentelemetry-api = "1.25.0" +deprecated = ">=1.2.6" +opentelemetry-api = "1.26.0" [[package]] name = "opentelemetry-util-http" -version = "0.46b0" +version = "0.47b0" description = "Web util for OpenTelemetry" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_util_http-0.46b0-py3-none-any.whl", hash = "sha256:8dc1949ce63caef08db84ae977fdc1848fe6dc38e6bbaad0ae3e6ecd0d451629"}, - {file = "opentelemetry_util_http-0.46b0.tar.gz", hash = "sha256:03b6e222642f9c7eae58d9132343e045b50aca9761fcb53709bd2b663571fdf6"}, + {file = "opentelemetry_util_http-0.47b0-py3-none-any.whl", hash = "sha256:3d3215e09c4a723b12da6d0233a31395aeb2bb33a64d7b15a1500690ba250f19"}, + {file = "opentelemetry_util_http-0.47b0.tar.gz", hash = "sha256:352a07664c18eef827eb8ddcbd64c64a7284a39dd1655e2f16f577eb046ccb32"}, ] [[package]] @@ -5447,62 +5730,68 @@ cryptography = ">=3.2.1" [[package]] name = "orjson" -version = "3.10.6" +version = "3.10.7" description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy" optional = false python-versions = ">=3.8" files = [ - {file = "orjson-3.10.6-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:fb0ee33124db6eaa517d00890fc1a55c3bfe1cf78ba4a8899d71a06f2d6ff5c7"}, - {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c1c4b53b24a4c06547ce43e5fee6ec4e0d8fe2d597f4647fc033fd205707365"}, - {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eadc8fd310edb4bdbd333374f2c8fec6794bbbae99b592f448d8214a5e4050c0"}, - {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:61272a5aec2b2661f4fa2b37c907ce9701e821b2c1285d5c3ab0207ebd358d38"}, - {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57985ee7e91d6214c837936dc1608f40f330a6b88bb13f5a57ce5257807da143"}, - {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:633a3b31d9d7c9f02d49c4ab4d0a86065c4a6f6adc297d63d272e043472acab5"}, - {file = "orjson-3.10.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:1c680b269d33ec444afe2bdc647c9eb73166fa47a16d9a75ee56a374f4a45f43"}, - {file = "orjson-3.10.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f759503a97a6ace19e55461395ab0d618b5a117e8d0fbb20e70cfd68a47327f2"}, - {file = "orjson-3.10.6-cp310-none-win32.whl", hash = "sha256:95a0cce17f969fb5391762e5719575217bd10ac5a189d1979442ee54456393f3"}, - {file = "orjson-3.10.6-cp310-none-win_amd64.whl", hash = "sha256:df25d9271270ba2133cc88ee83c318372bdc0f2cd6f32e7a450809a111efc45c"}, - {file = "orjson-3.10.6-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:b1ec490e10d2a77c345def52599311849fc063ae0e67cf4f84528073152bb2ba"}, - {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55d43d3feb8f19d07e9f01e5b9be4f28801cf7c60d0fa0d279951b18fae1932b"}, - {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac3045267e98fe749408eee1593a142e02357c5c99be0802185ef2170086a863"}, - {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c27bc6a28ae95923350ab382c57113abd38f3928af3c80be6f2ba7eb8d8db0b0"}, - {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d27456491ca79532d11e507cadca37fb8c9324a3976294f68fb1eff2dc6ced5a"}, - {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:05ac3d3916023745aa3b3b388e91b9166be1ca02b7c7e41045da6d12985685f0"}, - {file = "orjson-3.10.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1335d4ef59ab85cab66fe73fd7a4e881c298ee7f63ede918b7faa1b27cbe5212"}, - {file = "orjson-3.10.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4bbc6d0af24c1575edc79994c20e1b29e6fb3c6a570371306db0993ecf144dc5"}, - {file = "orjson-3.10.6-cp311-none-win32.whl", hash = "sha256:450e39ab1f7694465060a0550b3f6d328d20297bf2e06aa947b97c21e5241fbd"}, - {file = "orjson-3.10.6-cp311-none-win_amd64.whl", hash = "sha256:227df19441372610b20e05bdb906e1742ec2ad7a66ac8350dcfd29a63014a83b"}, - {file = "orjson-3.10.6-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:ea2977b21f8d5d9b758bb3f344a75e55ca78e3ff85595d248eee813ae23ecdfb"}, - {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b6f3d167d13a16ed263b52dbfedff52c962bfd3d270b46b7518365bcc2121eed"}, - {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f710f346e4c44a4e8bdf23daa974faede58f83334289df80bc9cd12fe82573c7"}, - {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7275664f84e027dcb1ad5200b8b18373e9c669b2a9ec33d410c40f5ccf4b257e"}, - {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0943e4c701196b23c240b3d10ed8ecd674f03089198cf503105b474a4f77f21f"}, - {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:446dee5a491b5bc7d8f825d80d9637e7af43f86a331207b9c9610e2f93fee22a"}, - {file = "orjson-3.10.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:64c81456d2a050d380786413786b057983892db105516639cb5d3ee3c7fd5148"}, - {file = "orjson-3.10.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:960db0e31c4e52fa0fc3ecbaea5b2d3b58f379e32a95ae6b0ebeaa25b93dfd34"}, - {file = "orjson-3.10.6-cp312-none-win32.whl", hash = "sha256:a6ea7afb5b30b2317e0bee03c8d34c8181bc5a36f2afd4d0952f378972c4efd5"}, - {file = "orjson-3.10.6-cp312-none-win_amd64.whl", hash = "sha256:874ce88264b7e655dde4aeaacdc8fd772a7962faadfb41abe63e2a4861abc3dc"}, - {file = "orjson-3.10.6-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:66680eae4c4e7fc193d91cfc1353ad6d01b4801ae9b5314f17e11ba55e934183"}, - {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:caff75b425db5ef8e8f23af93c80f072f97b4fb3afd4af44482905c9f588da28"}, - {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3722fddb821b6036fd2a3c814f6bd9b57a89dc6337b9924ecd614ebce3271394"}, - {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c2c116072a8533f2fec435fde4d134610f806bdac20188c7bd2081f3e9e0133f"}, - {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6eeb13218c8cf34c61912e9df2de2853f1d009de0e46ea09ccdf3d757896af0a"}, - {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:965a916373382674e323c957d560b953d81d7a8603fbeee26f7b8248638bd48b"}, - {file = "orjson-3.10.6-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:03c95484d53ed8e479cade8628c9cea00fd9d67f5554764a1110e0d5aa2de96e"}, - {file = "orjson-3.10.6-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:e060748a04cccf1e0a6f2358dffea9c080b849a4a68c28b1b907f272b5127e9b"}, - {file = "orjson-3.10.6-cp38-none-win32.whl", hash = "sha256:738dbe3ef909c4b019d69afc19caf6b5ed0e2f1c786b5d6215fbb7539246e4c6"}, - {file = "orjson-3.10.6-cp38-none-win_amd64.whl", hash = "sha256:d40f839dddf6a7d77114fe6b8a70218556408c71d4d6e29413bb5f150a692ff7"}, - {file = "orjson-3.10.6-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:697a35a083c4f834807a6232b3e62c8b280f7a44ad0b759fd4dce748951e70db"}, - {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd502f96bf5ea9a61cbc0b2b5900d0dd68aa0da197179042bdd2be67e51a1e4b"}, - {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f215789fb1667cdc874c1b8af6a84dc939fd802bf293a8334fce185c79cd359b"}, - {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a2debd8ddce948a8c0938c8c93ade191d2f4ba4649a54302a7da905a81f00b56"}, - {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5410111d7b6681d4b0d65e0f58a13be588d01b473822483f77f513c7f93bd3b2"}, - {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb1f28a137337fdc18384079fa5726810681055b32b92253fa15ae5656e1dddb"}, - {file = "orjson-3.10.6-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:bf2fbbce5fe7cd1aa177ea3eab2b8e6a6bc6e8592e4279ed3db2d62e57c0e1b2"}, - {file = "orjson-3.10.6-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:79b9b9e33bd4c517445a62b90ca0cc279b0f1f3970655c3df9e608bc3f91741a"}, - {file = "orjson-3.10.6-cp39-none-win32.whl", hash = "sha256:30b0a09a2014e621b1adf66a4f705f0809358350a757508ee80209b2d8dae219"}, - {file = "orjson-3.10.6-cp39-none-win_amd64.whl", hash = "sha256:49e3bc615652617d463069f91b867a4458114c5b104e13b7ae6872e5f79d0844"}, - {file = "orjson-3.10.6.tar.gz", hash = "sha256:e54b63d0a7c6c54a5f5f726bc93a2078111ef060fec4ecbf34c5db800ca3b3a7"}, + {file = "orjson-3.10.7-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:74f4544f5a6405b90da8ea724d15ac9c36da4d72a738c64685003337401f5c12"}, + {file = "orjson-3.10.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:34a566f22c28222b08875b18b0dfbf8a947e69df21a9ed5c51a6bf91cfb944ac"}, + {file = "orjson-3.10.7-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bf6ba8ebc8ef5792e2337fb0419f8009729335bb400ece005606336b7fd7bab7"}, + {file = "orjson-3.10.7-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ac7cf6222b29fbda9e3a472b41e6a5538b48f2c8f99261eecd60aafbdb60690c"}, + {file = "orjson-3.10.7-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:de817e2f5fc75a9e7dd350c4b0f54617b280e26d1631811a43e7e968fa71e3e9"}, + {file = "orjson-3.10.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:348bdd16b32556cf8d7257b17cf2bdb7ab7976af4af41ebe79f9796c218f7e91"}, + {file = "orjson-3.10.7-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:479fd0844ddc3ca77e0fd99644c7fe2de8e8be1efcd57705b5c92e5186e8a250"}, + {file = "orjson-3.10.7-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:fdf5197a21dd660cf19dfd2a3ce79574588f8f5e2dbf21bda9ee2d2b46924d84"}, + {file = "orjson-3.10.7-cp310-none-win32.whl", hash = "sha256:d374d36726746c81a49f3ff8daa2898dccab6596864ebe43d50733275c629175"}, + {file = "orjson-3.10.7-cp310-none-win_amd64.whl", hash = "sha256:cb61938aec8b0ffb6eef484d480188a1777e67b05d58e41b435c74b9d84e0b9c"}, + {file = "orjson-3.10.7-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:7db8539039698ddfb9a524b4dd19508256107568cdad24f3682d5773e60504a2"}, + {file = "orjson-3.10.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:480f455222cb7a1dea35c57a67578848537d2602b46c464472c995297117fa09"}, + {file = "orjson-3.10.7-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8a9c9b168b3a19e37fe2778c0003359f07822c90fdff8f98d9d2a91b3144d8e0"}, + {file = "orjson-3.10.7-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8de062de550f63185e4c1c54151bdddfc5625e37daf0aa1e75d2a1293e3b7d9a"}, + {file = "orjson-3.10.7-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6b0dd04483499d1de9c8f6203f8975caf17a6000b9c0c54630cef02e44ee624e"}, + {file = "orjson-3.10.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b58d3795dafa334fc8fd46f7c5dc013e6ad06fd5b9a4cc98cb1456e7d3558bd6"}, + {file = "orjson-3.10.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:33cfb96c24034a878d83d1a9415799a73dc77480e6c40417e5dda0710d559ee6"}, + {file = "orjson-3.10.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:e724cebe1fadc2b23c6f7415bad5ee6239e00a69f30ee423f319c6af70e2a5c0"}, + {file = "orjson-3.10.7-cp311-none-win32.whl", hash = "sha256:82763b46053727a7168d29c772ed5c870fdae2f61aa8a25994c7984a19b1021f"}, + {file = "orjson-3.10.7-cp311-none-win_amd64.whl", hash = "sha256:eb8d384a24778abf29afb8e41d68fdd9a156cf6e5390c04cc07bbc24b89e98b5"}, + {file = "orjson-3.10.7-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:44a96f2d4c3af51bfac6bc4ef7b182aa33f2f054fd7f34cc0ee9a320d051d41f"}, + {file = "orjson-3.10.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76ac14cd57df0572453543f8f2575e2d01ae9e790c21f57627803f5e79b0d3c3"}, + {file = "orjson-3.10.7-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bdbb61dcc365dd9be94e8f7df91975edc9364d6a78c8f7adb69c1cdff318ec93"}, + {file = "orjson-3.10.7-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b48b3db6bb6e0a08fa8c83b47bc169623f801e5cc4f24442ab2b6617da3b5313"}, + {file = "orjson-3.10.7-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:23820a1563a1d386414fef15c249040042b8e5d07b40ab3fe3efbfbbcbcb8864"}, + {file = "orjson-3.10.7-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0c6a008e91d10a2564edbb6ee5069a9e66df3fbe11c9a005cb411f441fd2c09"}, + {file = "orjson-3.10.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d352ee8ac1926d6193f602cbe36b1643bbd1bbcb25e3c1a657a4390f3000c9a5"}, + {file = "orjson-3.10.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d2d9f990623f15c0ae7ac608103c33dfe1486d2ed974ac3f40b693bad1a22a7b"}, + {file = "orjson-3.10.7-cp312-none-win32.whl", hash = "sha256:7c4c17f8157bd520cdb7195f75ddbd31671997cbe10aee559c2d613592e7d7eb"}, + {file = "orjson-3.10.7-cp312-none-win_amd64.whl", hash = "sha256:1d9c0e733e02ada3ed6098a10a8ee0052dd55774de3d9110d29868d24b17faa1"}, + {file = "orjson-3.10.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:77d325ed866876c0fa6492598ec01fe30e803272a6e8b10e992288b009cbe149"}, + {file = "orjson-3.10.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ea2c232deedcb605e853ae1db2cc94f7390ac776743b699b50b071b02bea6fe"}, + {file = "orjson-3.10.7-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:3dcfbede6737fdbef3ce9c37af3fb6142e8e1ebc10336daa05872bfb1d87839c"}, + {file = "orjson-3.10.7-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:11748c135f281203f4ee695b7f80bb1358a82a63905f9f0b794769483ea854ad"}, + {file = "orjson-3.10.7-cp313-none-win32.whl", hash = "sha256:a7e19150d215c7a13f39eb787d84db274298d3f83d85463e61d277bbd7f401d2"}, + {file = "orjson-3.10.7-cp313-none-win_amd64.whl", hash = "sha256:eef44224729e9525d5261cc8d28d6b11cafc90e6bd0be2157bde69a52ec83024"}, + {file = "orjson-3.10.7-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:6ea2b2258eff652c82652d5e0f02bd5e0463a6a52abb78e49ac288827aaa1469"}, + {file = "orjson-3.10.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:430ee4d85841e1483d487e7b81401785a5dfd69db5de01314538f31f8fbf7ee1"}, + {file = "orjson-3.10.7-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4b6146e439af4c2472c56f8540d799a67a81226e11992008cb47e1267a9b3225"}, + {file = "orjson-3.10.7-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:084e537806b458911137f76097e53ce7bf5806dda33ddf6aaa66a028f8d43a23"}, + {file = "orjson-3.10.7-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4829cf2195838e3f93b70fd3b4292156fc5e097aac3739859ac0dcc722b27ac0"}, + {file = "orjson-3.10.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1193b2416cbad1a769f868b1749535d5da47626ac29445803dae7cc64b3f5c98"}, + {file = "orjson-3.10.7-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:4e6c3da13e5a57e4b3dca2de059f243ebec705857522f188f0180ae88badd354"}, + {file = "orjson-3.10.7-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c31008598424dfbe52ce8c5b47e0752dca918a4fdc4a2a32004efd9fab41d866"}, + {file = "orjson-3.10.7-cp38-none-win32.whl", hash = "sha256:7122a99831f9e7fe977dc45784d3b2edc821c172d545e6420c375e5a935f5a1c"}, + {file = "orjson-3.10.7-cp38-none-win_amd64.whl", hash = "sha256:a763bc0e58504cc803739e7df040685816145a6f3c8a589787084b54ebc9f16e"}, + {file = "orjson-3.10.7-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:e76be12658a6fa376fcd331b1ea4e58f5a06fd0220653450f0d415b8fd0fbe20"}, + {file = "orjson-3.10.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ed350d6978d28b92939bfeb1a0570c523f6170efc3f0a0ef1f1df287cd4f4960"}, + {file = "orjson-3.10.7-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:144888c76f8520e39bfa121b31fd637e18d4cc2f115727865fdf9fa325b10412"}, + {file = "orjson-3.10.7-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09b2d92fd95ad2402188cf51573acde57eb269eddabaa60f69ea0d733e789fe9"}, + {file = "orjson-3.10.7-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5b24a579123fa884f3a3caadaed7b75eb5715ee2b17ab5c66ac97d29b18fe57f"}, + {file = "orjson-3.10.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e72591bcfe7512353bd609875ab38050efe3d55e18934e2f18950c108334b4ff"}, + {file = "orjson-3.10.7-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:f4db56635b58cd1a200b0a23744ff44206ee6aa428185e2b6c4a65b3197abdcd"}, + {file = "orjson-3.10.7-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:0fa5886854673222618638c6df7718ea7fe2f3f2384c452c9ccedc70b4a510a5"}, + {file = "orjson-3.10.7-cp39-none-win32.whl", hash = "sha256:8272527d08450ab16eb405f47e0f4ef0e5ff5981c3d82afe0efd25dcbef2bcd2"}, + {file = "orjson-3.10.7-cp39-none-win_amd64.whl", hash = "sha256:974683d4618c0c7dbf4f69c95a979734bf183d0658611760017f6e70a145af58"}, + {file = "orjson-3.10.7.tar.gz", hash = "sha256:75ef0640403f945f3a1f9f6400686560dbfb0fb5b16589ad62cd477043c4eee3"}, ] [[package]] @@ -5847,13 +6136,13 @@ tests = ["pytest (>=5.4.1)", "pytest-cov (>=2.8.1)", "pytest-mypy (>=0.8.0)", "p [[package]] name = "posthog" -version = "3.5.0" +version = "3.5.2" description = "Integrate PostHog into any python application." optional = false python-versions = "*" files = [ - {file = "posthog-3.5.0-py2.py3-none-any.whl", hash = "sha256:3c672be7ba6f95d555ea207d4486c171d06657eb34b3ce25eb043bfe7b6b5b76"}, - {file = "posthog-3.5.0.tar.gz", hash = "sha256:8f7e3b2c6e8714d0c0c542a2109b83a7549f63b7113a133ab2763a89245ef2ef"}, + {file = "posthog-3.5.2-py2.py3-none-any.whl", hash = "sha256:605b3d92369971cc99290b1fcc8534cbddac3726ef7972caa993454a5ecfb644"}, + {file = "posthog-3.5.2.tar.gz", hash = "sha256:a383a80c1f47e0243f5ce359e81e06e2e7b37eb39d1d6f8d01c3e64ed29df2ee"}, ] [package.dependencies] @@ -5868,6 +6157,26 @@ dev = ["black", "flake8", "flake8-print", "isort", "pre-commit"] sentry = ["django", "sentry-sdk"] test = ["coverage", "flake8", "freezegun (==0.3.15)", "mock (>=2.0.0)", "pylint", "pytest", "pytest-timeout"] +[[package]] +name = "primp" +version = "0.6.1" +description = "HTTP client that can impersonate web browsers, mimicking their headers and `TLS/JA3/JA4/HTTP2` fingerprints" +optional = false +python-versions = ">=3.8" +files = [ + {file = "primp-0.6.1-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:60cfe95e0bdf154b0f9036d38acaddc9aef02d6723ed125839b01449672d3946"}, + {file = "primp-0.6.1-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:e1e92433ecf32639f9e800bc3a5d58b03792bdec99421b7fb06500e2fae63c85"}, + {file = "primp-0.6.1-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e02353f13f07fb5a6f91df9e2f4d8ec9f41312de95088744dce1c9729a3865d"}, + {file = "primp-0.6.1-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:c5a2ccfdf488b17be225a529a31e2b22724b2e22fba8e1ae168a222f857c2dc0"}, + {file = "primp-0.6.1-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:f335c2ace907800a23bbb7bc6e15acc7fff659b86a2d5858817f6ed79cea07cf"}, + {file = "primp-0.6.1-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5dc15bd9d47ded7bc356fcb5d8321972dcbeba18e7d3b7250e12bb7365447b2b"}, + {file = "primp-0.6.1-cp38-abi3-win_amd64.whl", hash = "sha256:eebf0412ebba4089547b16b97b765d83f69f1433d811bb02b02cdcdbca20f672"}, + {file = "primp-0.6.1.tar.gz", hash = "sha256:64b3c12e3d463a887518811c46f3ec37cca02e6af1ddf1287e548342de436301"}, +] + +[package.extras] +dev = ["certifi", "pytest (>=8.1.1)"] + [[package]] name = "prompt-toolkit" version = "3.0.47" @@ -5901,24 +6210,53 @@ testing = ["google-api-core (>=1.31.5)"] [[package]] name = "protobuf" -version = "4.25.3" +version = "4.25.4" description = "" optional = false python-versions = ">=3.8" files = [ - {file = "protobuf-4.25.3-cp310-abi3-win32.whl", hash = "sha256:d4198877797a83cbfe9bffa3803602bbe1625dc30d8a097365dbc762e5790faa"}, - {file = "protobuf-4.25.3-cp310-abi3-win_amd64.whl", hash = "sha256:209ba4cc916bab46f64e56b85b090607a676f66b473e6b762e6f1d9d591eb2e8"}, - {file = "protobuf-4.25.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:f1279ab38ecbfae7e456a108c5c0681e4956d5b1090027c1de0f934dfdb4b35c"}, - {file = "protobuf-4.25.3-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:e7cb0ae90dd83727f0c0718634ed56837bfeeee29a5f82a7514c03ee1364c019"}, - {file = "protobuf-4.25.3-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:7c8daa26095f82482307bc717364e7c13f4f1c99659be82890dcfc215194554d"}, - {file = "protobuf-4.25.3-cp38-cp38-win32.whl", hash = "sha256:f4f118245c4a087776e0a8408be33cf09f6c547442c00395fbfb116fac2f8ac2"}, - {file = "protobuf-4.25.3-cp38-cp38-win_amd64.whl", hash = "sha256:c053062984e61144385022e53678fbded7aea14ebb3e0305ae3592fb219ccfa4"}, - {file = "protobuf-4.25.3-cp39-cp39-win32.whl", hash = "sha256:19b270aeaa0099f16d3ca02628546b8baefe2955bbe23224aaf856134eccf1e4"}, - {file = "protobuf-4.25.3-cp39-cp39-win_amd64.whl", hash = "sha256:e3c97a1555fd6388f857770ff8b9703083de6bf1f9274a002a332d65fbb56c8c"}, - {file = "protobuf-4.25.3-py3-none-any.whl", hash = "sha256:f0700d54bcf45424477e46a9f0944155b46fb0639d69728739c0e47bab83f2b9"}, - {file = "protobuf-4.25.3.tar.gz", hash = "sha256:25b5d0b42fd000320bd7830b349e3b696435f3b329810427a6bcce6a5492cc5c"}, + {file = "protobuf-4.25.4-cp310-abi3-win32.whl", hash = "sha256:db9fd45183e1a67722cafa5c1da3e85c6492a5383f127c86c4c4aa4845867dc4"}, + {file = "protobuf-4.25.4-cp310-abi3-win_amd64.whl", hash = "sha256:ba3d8504116a921af46499471c63a85260c1a5fc23333154a427a310e015d26d"}, + {file = "protobuf-4.25.4-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:eecd41bfc0e4b1bd3fa7909ed93dd14dd5567b98c941d6c1ad08fdcab3d6884b"}, + {file = "protobuf-4.25.4-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:4c8a70fdcb995dcf6c8966cfa3a29101916f7225e9afe3ced4395359955d3835"}, + {file = "protobuf-4.25.4-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:3319e073562e2515c6ddc643eb92ce20809f5d8f10fead3332f71c63be6a7040"}, + {file = "protobuf-4.25.4-cp38-cp38-win32.whl", hash = "sha256:7e372cbbda66a63ebca18f8ffaa6948455dfecc4e9c1029312f6c2edcd86c4e1"}, + {file = "protobuf-4.25.4-cp38-cp38-win_amd64.whl", hash = "sha256:051e97ce9fa6067a4546e75cb14f90cf0232dcb3e3d508c448b8d0e4265b61c1"}, + {file = "protobuf-4.25.4-cp39-cp39-win32.whl", hash = "sha256:90bf6fd378494eb698805bbbe7afe6c5d12c8e17fca817a646cd6a1818c696ca"}, + {file = "protobuf-4.25.4-cp39-cp39-win_amd64.whl", hash = "sha256:ac79a48d6b99dfed2729ccccee547b34a1d3d63289c71cef056653a846a2240f"}, + {file = "protobuf-4.25.4-py3-none-any.whl", hash = "sha256:bfbebc1c8e4793cfd58589acfb8a1026be0003e852b9da7db5a4285bde996978"}, + {file = "protobuf-4.25.4.tar.gz", hash = "sha256:0dc4a62cc4052a036ee2204d26fe4d835c62827c855c8a03f29fe6da146b380d"}, ] +[[package]] +name = "psutil" +version = "6.0.0" +description = "Cross-platform lib for process and system monitoring in Python." +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +files = [ + {file = "psutil-6.0.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:a021da3e881cd935e64a3d0a20983bda0bb4cf80e4f74fa9bfcb1bc5785360c6"}, + {file = "psutil-6.0.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:1287c2b95f1c0a364d23bc6f2ea2365a8d4d9b726a3be7294296ff7ba97c17f0"}, + {file = "psutil-6.0.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:a9a3dbfb4de4f18174528d87cc352d1f788b7496991cca33c6996f40c9e3c92c"}, + {file = "psutil-6.0.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:6ec7588fb3ddaec7344a825afe298db83fe01bfaaab39155fa84cf1c0d6b13c3"}, + {file = "psutil-6.0.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:1e7c870afcb7d91fdea2b37c24aeb08f98b6d67257a5cb0a8bc3ac68d0f1a68c"}, + {file = "psutil-6.0.0-cp27-none-win32.whl", hash = "sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35"}, + {file = "psutil-6.0.0-cp27-none-win_amd64.whl", hash = "sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1"}, + {file = "psutil-6.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132"}, + {file = "psutil-6.0.0-cp36-cp36m-win32.whl", hash = "sha256:fc8c9510cde0146432bbdb433322861ee8c3efbf8589865c8bf8d21cb30c4d14"}, + {file = "psutil-6.0.0-cp36-cp36m-win_amd64.whl", hash = "sha256:34859b8d8f423b86e4385ff3665d3f4d94be3cdf48221fbe476e883514fdb71c"}, + {file = "psutil-6.0.0-cp37-abi3-win32.whl", hash = "sha256:a495580d6bae27291324fe60cea0b5a7c23fa36a7cd35035a16d93bdcf076b9d"}, + {file = "psutil-6.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:33ea5e1c975250a720b3a6609c490db40dae5d83a4eb315170c4fe0d8b1f34b3"}, + {file = "psutil-6.0.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0"}, + {file = "psutil-6.0.0.tar.gz", hash = "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2"}, +] + +[package.extras] +test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] + [[package]] name = "psycopg2-binary" version = "2.9.9" @@ -6283,13 +6621,13 @@ semver = ["semver (>=3.0.2)"] [[package]] name = "pydantic-settings" -version = "2.3.4" +version = "2.4.0" description = "Settings management using Pydantic" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic_settings-2.3.4-py3-none-any.whl", hash = "sha256:11ad8bacb68a045f00e4f862c7a718c8a9ec766aa8fd4c32e39a0594b207b53a"}, - {file = "pydantic_settings-2.3.4.tar.gz", hash = "sha256:c5802e3d62b78e82522319bbc9b8f8ffb28ad1c988a99311d04f2a6051fca0a7"}, + {file = "pydantic_settings-2.4.0-py3-none-any.whl", hash = "sha256:bb6849dc067f1687574c12a639e231f3a6feeed0a12d710c1382045c5db1c315"}, + {file = "pydantic_settings-2.4.0.tar.gz", hash = "sha256:ed81c3a0f46392b4d7c0a565c05884e6e54b3456e6f0fe4d8814981172dc9a88"}, ] [package.dependencies] @@ -6297,20 +6635,27 @@ pydantic = ">=2.7.0" python-dotenv = ">=0.21.0" [package.extras] +azure-key-vault = ["azure-identity (>=1.16.0)", "azure-keyvault-secrets (>=4.8.0)"] toml = ["tomli (>=2.0.1)"] yaml = ["pyyaml (>=6.0.1)"] [[package]] -name = "pydub" -version = "0.25.1" -description = "Manipulate audio with an simple and easy high level interface" +name = "pydash" +version = "8.0.3" +description = "The kitchen sink of Python utility libraries for doing \"stuff\" in a functional way. Based on the Lo-Dash Javascript library." optional = false -python-versions = "*" +python-versions = ">=3.8" files = [ - {file = "pydub-0.25.1-py2.py3-none-any.whl", hash = "sha256:65617e33033874b59d87db603aa1ed450633288aefead953b30bded59cb599a6"}, - {file = "pydub-0.25.1.tar.gz", hash = "sha256:980a33ce9949cab2a569606b65674d748ecbca4f0796887fd6f46173a7b0d30f"}, + {file = "pydash-8.0.3-py3-none-any.whl", hash = "sha256:c16871476822ee6b59b87e206dd27888240eff50a7b4cd72a4b80b43b6b994d7"}, + {file = "pydash-8.0.3.tar.gz", hash = "sha256:1b27cd3da05b72f0e5ff786c523afd82af796936462e631ffd1b228d91f8b9aa"}, ] +[package.dependencies] +typing-extensions = ">3.10,<4.6.0 || >4.6.0" + +[package.extras] +dev = ["build", "coverage", "furo", "invoke", "mypy", "pytest", "pytest-cov", "pytest-mypy-testing", "ruff", "sphinx", "sphinx-autodoc-typehints", "tox", "twine", "wheel"] + [[package]] name = "pygments" version = "2.18.0" @@ -6347,13 +6692,13 @@ tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] [[package]] name = "pymilvus" -version = "2.4.4" +version = "2.4.5" description = "Python Sdk for Milvus" optional = false python-versions = ">=3.8" files = [ - {file = "pymilvus-2.4.4-py3-none-any.whl", hash = "sha256:073b76bc36f6f4e70f0f0a0023a53324f0ba8ef9a60883f87cd30a44b6c6f2b5"}, - {file = "pymilvus-2.4.4.tar.gz", hash = "sha256:50c53eb103e034fbffe936fe942751ea3dbd2452e18cf79acc52360ed4987fb7"}, + {file = "pymilvus-2.4.5-py3-none-any.whl", hash = "sha256:dc4f2d1eac8db9cf3951de39566a1a244695760bb94d8310fbfc73d6d62bb267"}, + {file = "pymilvus-2.4.5.tar.gz", hash = "sha256:1a497fe9b41d6bf62b1d5e1c412960922dde1598576fcbb8818040c8af11149f"}, ] [package.dependencies] @@ -6362,7 +6707,7 @@ grpcio = ">=1.49.1,<=1.63.0" milvus-lite = {version = ">=2.4.0,<2.5.0", markers = "sys_platform != \"win32\""} pandas = ">=1.2.4" protobuf = ">=3.20.0" -setuptools = ">=67" +setuptools = ">69" ujson = ">=2.0.0" [package.extras] @@ -6474,35 +6819,15 @@ files = [ {file = "pyreadline3-3.4.1.tar.gz", hash = "sha256:6f3d1f7b8a31ba32b73917cefc1f28cc660562f39aea8646d30bd6eff21f7bae"}, ] -[[package]] -name = "pyreqwest-impersonate" -version = "0.5.3" -description = "HTTP client that can impersonate web browsers, mimicking their headers and `TLS/JA3/JA4/HTTP2` fingerprints" -optional = false -python-versions = ">=3.8" -files = [ - {file = "pyreqwest_impersonate-0.5.3-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:f15922496f728769fb9e1b116d5d9d7ba5525d0f2f7a76a41a1daef8b2e0c6c3"}, - {file = "pyreqwest_impersonate-0.5.3-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:77533133ae73020e59bc56d776eea3fe3af4ac41d763a89f39c495436da0f4cf"}, - {file = "pyreqwest_impersonate-0.5.3-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:436055fa3eeb3e01e2e8efd42a9f6c4ab62fd643eddc7c66d0e671b71605f273"}, - {file = "pyreqwest_impersonate-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e9d2e981a525fb72c1521f454e5581d2c7a3b1fcf1c97c0acfcb7a923d8cf3e"}, - {file = "pyreqwest_impersonate-0.5.3-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:a6bf986d4a165f6976b3e862111e2a46091883cb55e9e6325150f5aea2644229"}, - {file = "pyreqwest_impersonate-0.5.3-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b7397f6dad3d5ae158e0b272cb3eafe8382e71775d829b286ae9c21cb5a879ff"}, - {file = "pyreqwest_impersonate-0.5.3-cp38-abi3-win_amd64.whl", hash = "sha256:6026e4751b5912aec1e45238c07daf1e2c9126b3b32b33396b72885021e8990c"}, - {file = "pyreqwest_impersonate-0.5.3.tar.gz", hash = "sha256:f21c10609958ff5be18df0c329eed42d2b3ba8a339b65dc5f96ab74537231692"}, -] - -[package.extras] -dev = ["pytest (>=8.1.1)"] - [[package]] name = "pytest" -version = "8.1.2" +version = "8.3.2" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.8" files = [ - {file = "pytest-8.1.2-py3-none-any.whl", hash = "sha256:6c06dc309ff46a05721e6fd48e492a775ed8165d2ecdf57f156a80c7e95bb142"}, - {file = "pytest-8.1.2.tar.gz", hash = "sha256:f3c45d1d5eed96b01a2aea70dee6a4a366d51d38f9957768083e4fecfc77f3ef"}, + {file = "pytest-8.3.2-py3-none-any.whl", hash = "sha256:4ba08f9ae7dcf84ded419494d229b48d0903ea6407b030eaec46df5e6a73bba5"}, + {file = "pytest-8.3.2.tar.gz", hash = "sha256:c132345d12ce551242c87269de812483f5bcc87cdbb4722e48487ba194f9fdce"}, ] [package.dependencies] @@ -6510,11 +6835,11 @@ colorama = {version = "*", markers = "sys_platform == \"win32\""} exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} iniconfig = "*" packaging = "*" -pluggy = ">=1.4,<2.0" +pluggy = ">=1.5,<2" tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] -testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] [[package]] name = "pytest-benchmark" @@ -6748,20 +7073,6 @@ files = [ {file = "python_magic-0.4.27-py2.py3-none-any.whl", hash = "sha256:c212960ad306f700aa0d01e5d7a325d20548ff97eb9920dcd29513174f0294d3"}, ] -[[package]] -name = "python-multipart" -version = "0.0.9" -description = "A streaming multipart parser for Python" -optional = false -python-versions = ">=3.8" -files = [ - {file = "python_multipart-0.0.9-py3-none-any.whl", hash = "sha256:97ca7b8ea7b05f977dc3849c3ba99d51689822fab725c3703af7c866a0c2b215"}, - {file = "python_multipart-0.0.9.tar.gz", hash = "sha256:03f54688c663f1b7977105f021043b0793151e4cb1c1a9d4a11fc13d622c4026"}, -] - -[package.extras] -dev = ["atomicwrites (==1.4.1)", "attrs (==23.2.0)", "coverage (==7.4.1)", "hatch", "invoke (==2.2.0)", "more-itertools (==10.2.0)", "pbr (==6.0.0)", "pluggy (==1.4.0)", "py (==1.11.0)", "pytest (==8.0.0)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.2.0)", "pyyaml (==6.0.1)", "ruff (==0.2.1)"] - [[package]] name = "python-pptx" version = "0.6.23" @@ -6825,62 +7136,64 @@ files = [ [[package]] name = "pyyaml" -version = "6.0.1" +version = "6.0.2" description = "YAML parser and emitter for Python" optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" files = [ - {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, - {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, - {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, - {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, - {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, - {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, - {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, - {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, - {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, - {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, - {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, - {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, - {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, - {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, - {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, - {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, - {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, - {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, - {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, - {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, - {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, - {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, - {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, - {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, - {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, - {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, - {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, - {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, - {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, - {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, - {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, - {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, - {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, - {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, - {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, - {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, - {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, - {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, - {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, - {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, + {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"}, + {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68"}, + {file = "PyYAML-6.0.2-cp310-cp310-win32.whl", hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99"}, + {file = "PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e"}, + {file = "PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5"}, + {file = "PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b"}, + {file = "PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4"}, + {file = "PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652"}, + {file = "PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183"}, + {file = "PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563"}, + {file = "PyYAML-6.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083"}, + {file = "PyYAML-6.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706"}, + {file = "PyYAML-6.0.2-cp38-cp38-win32.whl", hash = "sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a"}, + {file = "PyYAML-6.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725"}, + {file = "PyYAML-6.0.2-cp39-cp39-win32.whl", hash = "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631"}, + {file = "PyYAML-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8"}, + {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"}, ] [[package]] @@ -6951,104 +7264,119 @@ dev = ["pytest"] [[package]] name = "rapidfuzz" -version = "3.9.4" +version = "3.9.6" description = "rapid fuzzy string matching" optional = false python-versions = ">=3.8" files = [ - {file = "rapidfuzz-3.9.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c9b9793c19bdf38656c8eaefbcf4549d798572dadd70581379e666035c9df781"}, - {file = "rapidfuzz-3.9.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:015b5080b999404fe06ec2cb4f40b0be62f0710c926ab41e82dfbc28e80675b4"}, - {file = "rapidfuzz-3.9.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:acc5ceca9c1e1663f3e6c23fb89a311f69b7615a40ddd7645e3435bf3082688a"}, - {file = "rapidfuzz-3.9.4-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1424e238bc3f20e1759db1e0afb48a988a9ece183724bef91ea2a291c0b92a95"}, - {file = "rapidfuzz-3.9.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ed01378f605aa1f449bee82cd9c83772883120d6483e90aa6c5a4ce95dc5c3aa"}, - {file = "rapidfuzz-3.9.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eb26d412271e5a76cdee1c2d6bf9881310665d3fe43b882d0ed24edfcb891a84"}, - {file = "rapidfuzz-3.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f37e9e1f17be193c41a31c864ad4cd3ebd2b40780db11cd5c04abf2bcf4201b"}, - {file = "rapidfuzz-3.9.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d070ec5cf96b927c4dc5133c598c7ff6db3b833b363b2919b13417f1002560bc"}, - {file = "rapidfuzz-3.9.4-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:10e61bb7bc807968cef09a0e32ce253711a2d450a4dce7841d21d45330ffdb24"}, - {file = "rapidfuzz-3.9.4-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:31a2fc60bb2c7face4140010a7aeeafed18b4f9cdfa495cc644a68a8c60d1ff7"}, - {file = "rapidfuzz-3.9.4-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:fbebf1791a71a2e89f5c12b78abddc018354d5859e305ec3372fdae14f80a826"}, - {file = "rapidfuzz-3.9.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:aee9fc9e3bb488d040afc590c0a7904597bf4ccd50d1491c3f4a5e7e67e6cd2c"}, - {file = "rapidfuzz-3.9.4-cp310-cp310-win32.whl", hash = "sha256:005a02688a51c7d2451a2d41c79d737aa326ff54167211b78a383fc2aace2c2c"}, - {file = "rapidfuzz-3.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:3a2e75e41ee3274754d3b2163cc6c82cd95b892a85ab031f57112e09da36455f"}, - {file = "rapidfuzz-3.9.4-cp310-cp310-win_arm64.whl", hash = "sha256:2c99d355f37f2b289e978e761f2f8efeedc2b14f4751d9ff7ee344a9a5ca98d9"}, - {file = "rapidfuzz-3.9.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:07141aa6099e39d48637ce72a25b893fc1e433c50b3e837c75d8edf99e0c63e1"}, - {file = "rapidfuzz-3.9.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:db1664eaff5d7d0f2542dd9c25d272478deaf2c8412e4ad93770e2e2d828e175"}, - {file = "rapidfuzz-3.9.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc01a223f6605737bec3202e94dcb1a449b6c76d46082cfc4aa980f2a60fd40e"}, - {file = "rapidfuzz-3.9.4-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1869c42e73e2a8910b479be204fa736418741b63ea2325f9cc583c30f2ded41a"}, - {file = "rapidfuzz-3.9.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:62ea7007941fb2795fff305ac858f3521ec694c829d5126e8f52a3e92ae75526"}, - {file = "rapidfuzz-3.9.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:698e992436bf7f0afc750690c301215a36ff952a6dcd62882ec13b9a1ebf7a39"}, - {file = "rapidfuzz-3.9.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b76f611935f15a209d3730c360c56b6df8911a9e81e6a38022efbfb96e433bab"}, - {file = "rapidfuzz-3.9.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:129627d730db2e11f76169344a032f4e3883d34f20829419916df31d6d1338b1"}, - {file = "rapidfuzz-3.9.4-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:90a82143c14e9a14b723a118c9ef8d1bbc0c5a16b1ac622a1e6c916caff44dd8"}, - {file = "rapidfuzz-3.9.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:ded58612fe3b0e0d06e935eaeaf5a9fd27da8ba9ed3e2596307f40351923bf72"}, - {file = "rapidfuzz-3.9.4-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:f16f5d1c4f02fab18366f2d703391fcdbd87c944ea10736ca1dc3d70d8bd2d8b"}, - {file = "rapidfuzz-3.9.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:26aa7eece23e0df55fb75fbc2a8fb678322e07c77d1fd0e9540496e6e2b5f03e"}, - {file = "rapidfuzz-3.9.4-cp311-cp311-win32.whl", hash = "sha256:f187a9c3b940ce1ee324710626daf72c05599946bd6748abe9e289f1daa9a077"}, - {file = "rapidfuzz-3.9.4-cp311-cp311-win_amd64.whl", hash = "sha256:d8e9130fe5d7c9182990b366ad78fd632f744097e753e08ace573877d67c32f8"}, - {file = "rapidfuzz-3.9.4-cp311-cp311-win_arm64.whl", hash = "sha256:40419e98b10cd6a00ce26e4837a67362f658fc3cd7a71bd8bd25c99f7ee8fea5"}, - {file = "rapidfuzz-3.9.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b5d5072b548db1b313a07d62d88fe0b037bd2783c16607c647e01b070f6cf9e5"}, - {file = "rapidfuzz-3.9.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:cf5bcf22e1f0fd273354462631d443ef78d677f7d2fc292de2aec72ae1473e66"}, - {file = "rapidfuzz-3.9.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0c8fc973adde8ed52810f590410e03fb6f0b541bbaeb04c38d77e63442b2df4c"}, - {file = "rapidfuzz-3.9.4-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2464bb120f135293e9a712e342c43695d3d83168907df05f8c4ead1612310c7"}, - {file = "rapidfuzz-3.9.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8d9d58689aca22057cf1a5851677b8a3ccc9b535ca008c7ed06dc6e1899f7844"}, - {file = "rapidfuzz-3.9.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:167e745f98baa0f3034c13583e6302fb69249a01239f1483d68c27abb841e0a1"}, - {file = "rapidfuzz-3.9.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db0bf0663b4b6da1507869722420ea9356b6195aa907228d6201303e69837af9"}, - {file = "rapidfuzz-3.9.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:cd6ac61b74fdb9e23f04d5f068e6cf554f47e77228ca28aa2347a6ca8903972f"}, - {file = "rapidfuzz-3.9.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:60ff67c690acecf381759c16cb06c878328fe2361ddf77b25d0e434ea48a29da"}, - {file = "rapidfuzz-3.9.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:cb934363380c60f3a57d14af94325125cd8cded9822611a9f78220444034e36e"}, - {file = "rapidfuzz-3.9.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:fe833493fb5cc5682c823ea3e2f7066b07612ee8f61ecdf03e1268f262106cdd"}, - {file = "rapidfuzz-3.9.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2797fb847d89e04040d281cb1902cbeffbc4b5131a5c53fc0db490fd76b2a547"}, - {file = "rapidfuzz-3.9.4-cp312-cp312-win32.whl", hash = "sha256:52e3d89377744dae68ed7c84ad0ddd3f5e891c82d48d26423b9e066fc835cc7c"}, - {file = "rapidfuzz-3.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:c76da20481c906e08400ee9be230f9e611d5931a33707d9df40337c2655c84b5"}, - {file = "rapidfuzz-3.9.4-cp312-cp312-win_arm64.whl", hash = "sha256:f2d2846f3980445864c7e8b8818a29707fcaff2f0261159ef6b7bd27ba139296"}, - {file = "rapidfuzz-3.9.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:355fc4a268ffa07bab88d9adee173783ec8d20136059e028d2a9135c623c44e6"}, - {file = "rapidfuzz-3.9.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4d81a78f90269190b568a8353d4ea86015289c36d7e525cd4d43176c88eff429"}, - {file = "rapidfuzz-3.9.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9e618625ffc4660b26dc8e56225f8b966d5842fa190e70c60db6cd393e25b86e"}, - {file = "rapidfuzz-3.9.4-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b712336ad6f2bacdbc9f1452556e8942269ef71f60a9e6883ef1726b52d9228a"}, - {file = "rapidfuzz-3.9.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84fc1ee19fdad05770c897e793836c002344524301501d71ef2e832847425707"}, - {file = "rapidfuzz-3.9.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1950f8597890c0c707cb7e0416c62a1cf03dcdb0384bc0b2dbda7e05efe738ec"}, - {file = "rapidfuzz-3.9.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a6c35f272ec9c430568dc8c1c30cb873f6bc96be2c79795e0bce6db4e0e101d"}, - {file = "rapidfuzz-3.9.4-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:1df0f9e9239132a231c86ae4f545ec2b55409fa44470692fcfb36b1bd00157ad"}, - {file = "rapidfuzz-3.9.4-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:d2c51955329bfccf99ae26f63d5928bf5be9fcfcd9f458f6847fd4b7e2b8986c"}, - {file = "rapidfuzz-3.9.4-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:3c522f462d9fc504f2ea8d82e44aa580e60566acc754422c829ad75c752fbf8d"}, - {file = "rapidfuzz-3.9.4-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:d8a52fc50ded60d81117d7647f262c529659fb21d23e14ebfd0b35efa4f1b83d"}, - {file = "rapidfuzz-3.9.4-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:04dbdfb0f0bfd3f99cf1e9e24fadc6ded2736d7933f32f1151b0f2abb38f9a25"}, - {file = "rapidfuzz-3.9.4-cp38-cp38-win32.whl", hash = "sha256:4968c8bd1df84b42f382549e6226710ad3476f976389839168db3e68fd373298"}, - {file = "rapidfuzz-3.9.4-cp38-cp38-win_amd64.whl", hash = "sha256:3fe4545f89f8d6c27b6bbbabfe40839624873c08bd6700f63ac36970a179f8f5"}, - {file = "rapidfuzz-3.9.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9f256c8fb8f3125574c8c0c919ab0a1f75d7cba4d053dda2e762dcc36357969d"}, - {file = "rapidfuzz-3.9.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f5fdc09cf6e9d8eac3ce48a4615b3a3ee332ea84ac9657dbbefef913b13e632f"}, - {file = "rapidfuzz-3.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d395d46b80063d3b5d13c0af43d2c2cedf3ab48c6a0c2aeec715aa5455b0c632"}, - {file = "rapidfuzz-3.9.4-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7fa714fb96ce9e70c37e64c83b62fe8307030081a0bfae74a76fac7ba0f91715"}, - {file = "rapidfuzz-3.9.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1bc1a0f29f9119be7a8d3c720f1d2068317ae532e39e4f7f948607c3a6de8396"}, - {file = "rapidfuzz-3.9.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6022674aa1747d6300f699cd7c54d7dae89bfe1f84556de699c4ac5df0838082"}, - {file = "rapidfuzz-3.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dcb72e5f9762fd469701a7e12e94b924af9004954f8c739f925cb19c00862e38"}, - {file = "rapidfuzz-3.9.4-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:ad04ae301129f0eb5b350a333accd375ce155a0c1cec85ab0ec01f770214e2e4"}, - {file = "rapidfuzz-3.9.4-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:f46a22506f17c0433e349f2d1dc11907c393d9b3601b91d4e334fa9a439a6a4d"}, - {file = "rapidfuzz-3.9.4-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:01b42a8728c36011718da409aa86b84984396bf0ca3bfb6e62624f2014f6022c"}, - {file = "rapidfuzz-3.9.4-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:e590d5d5443cf56f83a51d3c4867bd1f6be8ef8cfcc44279522bcef3845b2a51"}, - {file = "rapidfuzz-3.9.4-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:4c72078b5fdce34ba5753f9299ae304e282420e6455e043ad08e4488ca13a2b0"}, - {file = "rapidfuzz-3.9.4-cp39-cp39-win32.whl", hash = "sha256:f75639277304e9b75e6a7b3c07042d2264e16740a11e449645689ed28e9c2124"}, - {file = "rapidfuzz-3.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:e81e27e8c32a1e1278a4bb1ce31401bfaa8c2cc697a053b985a6f8d013df83ec"}, - {file = "rapidfuzz-3.9.4-cp39-cp39-win_arm64.whl", hash = "sha256:15bc397ee9a3ed1210b629b9f5f1da809244adc51ce620c504138c6e7095b7bd"}, - {file = "rapidfuzz-3.9.4-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:20488ade4e1ddba3cfad04f400da7a9c1b91eff5b7bd3d1c50b385d78b587f4f"}, - {file = "rapidfuzz-3.9.4-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:e61b03509b1a6eb31bc5582694f6df837d340535da7eba7bedb8ae42a2fcd0b9"}, - {file = "rapidfuzz-3.9.4-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:098d231d4e51644d421a641f4a5f2f151f856f53c252b03516e01389b2bfef99"}, - {file = "rapidfuzz-3.9.4-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:17ab8b7d10fde8dd763ad428aa961c0f30a1b44426e675186af8903b5d134fb0"}, - {file = "rapidfuzz-3.9.4-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e272df61bee0a056a3daf99f9b1bd82cf73ace7d668894788139c868fdf37d6f"}, - {file = "rapidfuzz-3.9.4-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:d6481e099ff8c4edda85b8b9b5174c200540fd23c8f38120016c765a86fa01f5"}, - {file = "rapidfuzz-3.9.4-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:ad61676e9bdae677d577fe80ec1c2cea1d150c86be647e652551dcfe505b1113"}, - {file = "rapidfuzz-3.9.4-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:af65020c0dd48d0d8ae405e7e69b9d8ae306eb9b6249ca8bf511a13f465fad85"}, - {file = "rapidfuzz-3.9.4-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4d38b4e026fcd580e0bda6c0ae941e0e9a52c6bc66cdce0b8b0da61e1959f5f8"}, - {file = "rapidfuzz-3.9.4-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f74ed072c2b9dc6743fb19994319d443a4330b0e64aeba0aa9105406c7c5b9c2"}, - {file = "rapidfuzz-3.9.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aee5f6b8321f90615c184bd8a4c676e9becda69b8e4e451a90923db719d6857c"}, - {file = "rapidfuzz-3.9.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:3a555e3c841d6efa350f862204bb0a3fea0c006b8acc9b152b374fa36518a1c6"}, - {file = "rapidfuzz-3.9.4-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:0772150d37bf018110351c01d032bf9ab25127b966a29830faa8ad69b7e2f651"}, - {file = "rapidfuzz-3.9.4-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:addcdd3c3deef1bd54075bd7aba0a6ea9f1d01764a08620074b7a7b1e5447cb9"}, - {file = "rapidfuzz-3.9.4-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3fe86b82b776554add8f900b6af202b74eb5efe8f25acdb8680a5c977608727f"}, - {file = "rapidfuzz-3.9.4-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b0fc91ac59f4414d8542454dfd6287a154b8e6f1256718c898f695bdbb993467"}, - {file = "rapidfuzz-3.9.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a944e546a296a5fdcaabb537b01459f1b14d66f74e584cb2a91448bffadc3c1"}, - {file = "rapidfuzz-3.9.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:4fb96ba96d58c668a17a06b5b5e8340fedc26188e87b0d229d38104556f30cd8"}, - {file = "rapidfuzz-3.9.4.tar.gz", hash = "sha256:366bf8947b84e37f2f4cf31aaf5f37c39f620d8c0eddb8b633e6ba0129ca4a0a"}, + {file = "rapidfuzz-3.9.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a7ed0d0b9c85720f0ae33ac5efc8dc3f60c1489dad5c29d735fbdf2f66f0431f"}, + {file = "rapidfuzz-3.9.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f3deff6ab7017ed21b9aec5874a07ad13e6b2a688af055837f88b743c7bfd947"}, + {file = "rapidfuzz-3.9.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3f9fc060160507b2704f7d1491bd58453d69689b580cbc85289335b14fe8ca"}, + {file = "rapidfuzz-3.9.6-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c4e86c2b3827fa6169ad6e7d4b790ce02a20acefb8b78d92fa4249589bbc7a2c"}, + {file = "rapidfuzz-3.9.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f982e1aafb4bd8207a5e073b1efef9e68a984e91330e1bbf364f9ed157ed83f0"}, + {file = "rapidfuzz-3.9.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9196a51d0ec5eaaaf5bca54a85b7b1e666fc944c332f68e6427503af9fb8c49e"}, + {file = "rapidfuzz-3.9.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb5a514064e02585b1cc09da2fe406a6dc1a7e5f3e92dd4f27c53e5f1465ec81"}, + {file = "rapidfuzz-3.9.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e3a4244f65dbc3580b1275480118c3763f9dc29fc3dd96610560cb5e140a4d4a"}, + {file = "rapidfuzz-3.9.6-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:f6ebb910a702e41641e1e1dada3843bc11ba9107a33c98daef6945a885a40a07"}, + {file = "rapidfuzz-3.9.6-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:624fbe96115fb39addafa288d583b5493bc76dab1d34d0ebba9987d6871afdf9"}, + {file = "rapidfuzz-3.9.6-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:1c59f1c1507b7a557cf3c410c76e91f097460da7d97e51c985343798e9df7a3c"}, + {file = "rapidfuzz-3.9.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f6f0256cb27b6a0fb2e1918477d1b56473cd04acfa245376a342e7c15806a396"}, + {file = "rapidfuzz-3.9.6-cp310-cp310-win32.whl", hash = "sha256:24d473d00d23a30a85802b502b417a7f5126019c3beec91a6739fe7b95388b24"}, + {file = "rapidfuzz-3.9.6-cp310-cp310-win_amd64.whl", hash = "sha256:248f6d2612e661e2b5f9a22bbd5862a1600e720da7bb6ad8a55bb1548cdfa423"}, + {file = "rapidfuzz-3.9.6-cp310-cp310-win_arm64.whl", hash = "sha256:e03fdf0e74f346ed7e798135df5f2a0fb8d6b96582b00ebef202dcf2171e1d1d"}, + {file = "rapidfuzz-3.9.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:52e4675f642fbc85632f691b67115a243cd4d2a47bdcc4a3d9a79e784518ff97"}, + {file = "rapidfuzz-3.9.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1f93a2f13038700bd245b927c46a2017db3dcd4d4ff94687d74b5123689b873b"}, + {file = "rapidfuzz-3.9.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42b70500bca460264b8141d8040caee22e9cf0418c5388104ff0c73fb69ee28f"}, + {file = "rapidfuzz-3.9.6-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a1e037fb89f714a220f68f902fc6300ab7a33349f3ce8ffae668c3b3a40b0b06"}, + {file = "rapidfuzz-3.9.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6792f66d59b86ccfad5e247f2912e255c85c575789acdbad8e7f561412ffed8a"}, + {file = "rapidfuzz-3.9.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:68d9cffe710b67f1969cf996983608cee4490521d96ea91d16bd7ea5dc80ea98"}, + {file = "rapidfuzz-3.9.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63daaeeea76da17fa0bbe7fb05cba8ed8064bb1a0edf8360636557f8b6511961"}, + {file = "rapidfuzz-3.9.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d214e063bffa13e3b771520b74f674b22d309b5720d4df9918ff3e0c0f037720"}, + {file = "rapidfuzz-3.9.6-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:ed443a2062460f44c0346cb9d269b586496b808c2419bbd6057f54061c9b9c75"}, + {file = "rapidfuzz-3.9.6-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:5b0c9b227ee0076fb2d58301c505bb837a290ae99ee628beacdb719f0626d749"}, + {file = "rapidfuzz-3.9.6-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:82c9722b7dfaa71e8b61f8c89fed0482567fb69178e139fe4151fc71ed7df782"}, + {file = "rapidfuzz-3.9.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c18897c95c0a288347e29537b63608a8f63a5c3cb6da258ac46fcf89155e723e"}, + {file = "rapidfuzz-3.9.6-cp311-cp311-win32.whl", hash = "sha256:3e910cf08944da381159587709daaad9e59d8ff7bca1f788d15928f3c3d49c2a"}, + {file = "rapidfuzz-3.9.6-cp311-cp311-win_amd64.whl", hash = "sha256:59c4a61fab676d37329fc3a671618a461bfeef53a4d0b8b12e3bc24a14e166f8"}, + {file = "rapidfuzz-3.9.6-cp311-cp311-win_arm64.whl", hash = "sha256:8b4afea244102332973377fddbe54ce844d0916e1c67a5123432291717f32ffa"}, + {file = "rapidfuzz-3.9.6-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:70591b28b218fff351b88cdd7f2359a01a71f9f7f5a2e465ce3715ed4b3c422b"}, + {file = "rapidfuzz-3.9.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ee2d8355c7343c631a03e57540ea06e8717c19ecf5ff64ea07e0498f7f161457"}, + {file = "rapidfuzz-3.9.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:708fb675de0f47b9635d1cc6fbbf80d52cb710d0a1abbfae5c84c46e3abbddc3"}, + {file = "rapidfuzz-3.9.6-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1d66c247c2d3bb7a9b60567c395a15a929d0ebcc5f4ceedb55bfa202c38c6e0c"}, + {file = "rapidfuzz-3.9.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:15146301b32e6e3d2b7e8146db1a26747919d8b13690c7f83a4cb5dc111b3a08"}, + {file = "rapidfuzz-3.9.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7a03da59b6c7c97e657dd5cd4bcaab5fe4a2affd8193958d6f4d938bee36679"}, + {file = "rapidfuzz-3.9.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d2c2fe19e392dbc22695b6c3b2510527e2b774647e79936bbde49db7742d6f1"}, + {file = "rapidfuzz-3.9.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:91aaee4c94cb45930684f583ffc4e7c01a52b46610971cede33586cf8a04a12e"}, + {file = "rapidfuzz-3.9.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3f5702828c10768f9281180a7ff8597da1e5002803e1304e9519dd0f06d79a85"}, + {file = "rapidfuzz-3.9.6-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ccd1763b608fb4629a0b08f00b3c099d6395e67c14e619f6341b2c8429c2f310"}, + {file = "rapidfuzz-3.9.6-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:cc7a0d4b2cb166bc46d02c8c9f7551cde8e2f3c9789df3827309433ee9771163"}, + {file = "rapidfuzz-3.9.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7496f53d40560a58964207b52586783633f371683834a8f719d6d965d223a2eb"}, + {file = "rapidfuzz-3.9.6-cp312-cp312-win32.whl", hash = "sha256:5eb1a9272ca71bc72be5415c2fa8448a6302ea4578e181bb7da9db855b367df0"}, + {file = "rapidfuzz-3.9.6-cp312-cp312-win_amd64.whl", hash = "sha256:0d21fc3c0ca507a1180152a6dbd129ebaef48facde3f943db5c1055b6e6be56a"}, + {file = "rapidfuzz-3.9.6-cp312-cp312-win_arm64.whl", hash = "sha256:43bb27a57c29dc5fa754496ba6a1a508480d21ae99ac0d19597646c16407e9f3"}, + {file = "rapidfuzz-3.9.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:83a5ac6547a9d6eedaa212975cb8f2ce2aa07e6e30833b40e54a52b9f9999aa4"}, + {file = "rapidfuzz-3.9.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:10f06139142ecde67078ebc9a745965446132b998f9feebffd71acdf218acfcc"}, + {file = "rapidfuzz-3.9.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:74720c3f24597f76c7c3e2c4abdff55f1664f4766ff5b28aeaa689f8ffba5fab"}, + {file = "rapidfuzz-3.9.6-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce2bce52b5c150878e558a0418c2b637fb3dbb6eb38e4eb27d24aa839920483e"}, + {file = "rapidfuzz-3.9.6-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1611199f178793ca9a060c99b284e11f6d7d124998191f1cace9a0245334d219"}, + {file = "rapidfuzz-3.9.6-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0308b2ad161daf502908a6e21a57c78ded0258eba9a8f5e2545e2dafca312507"}, + {file = "rapidfuzz-3.9.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3eda91832201b86e3b70835f91522587725bec329ec68f2f7faf5124091e5ca7"}, + {file = "rapidfuzz-3.9.6-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ece873c093aedd87fc07c2a7e333d52e458dc177016afa1edaf157e82b6914d8"}, + {file = "rapidfuzz-3.9.6-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:d97d3c9d209d5c30172baea5966f2129e8a198fec4a1aeb2f92abb6e82a2edb1"}, + {file = "rapidfuzz-3.9.6-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:6c4550d0db4931f5ebe9f0678916d1b06f06f5a99ba0b8a48b9457fd8959a7d4"}, + {file = "rapidfuzz-3.9.6-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:b6b8dd4af6324fc325d9483bec75ecf9be33e590928c9202d408e4eafff6a0a6"}, + {file = "rapidfuzz-3.9.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:16122ae448bc89e2bea9d81ce6cb0f751e4e07da39bd1e70b95cae2493857853"}, + {file = "rapidfuzz-3.9.6-cp313-cp313-win32.whl", hash = "sha256:71cc168c305a4445109cd0d4925406f6e66bcb48fde99a1835387c58af4ecfe9"}, + {file = "rapidfuzz-3.9.6-cp313-cp313-win_amd64.whl", hash = "sha256:59ee78f2ecd53fef8454909cda7400fe2cfcd820f62b8a5d4dfe930102268054"}, + {file = "rapidfuzz-3.9.6-cp313-cp313-win_arm64.whl", hash = "sha256:58b4ce83f223605c358ae37e7a2d19a41b96aa65b1fede99cc664c9053af89ac"}, + {file = "rapidfuzz-3.9.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9f469dbc9c4aeaac7dd005992af74b7dff94aa56a3ea063ce64e4b3e6736dd2f"}, + {file = "rapidfuzz-3.9.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a9ed7ad9adb68d0fe63a156fe752bbf5f1403ed66961551e749641af2874da92"}, + {file = "rapidfuzz-3.9.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39ffe48ffbeedf78d120ddfb9d583f2ca906712159a4e9c3c743c9f33e7b1775"}, + {file = "rapidfuzz-3.9.6-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8502ccdea9084d54b6f737d96a3b60a84e3afed9d016686dc979b49cdac71613"}, + {file = "rapidfuzz-3.9.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6a4bec4956e06b170ca896ba055d08d4c457dac745548172443982956a80e118"}, + {file = "rapidfuzz-3.9.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2c0488b1c273be39e109ff885ccac0448b2fa74dea4c4dc676bcf756c15f16d6"}, + {file = "rapidfuzz-3.9.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0542c036cb6acf24edd2c9e0411a67d7ba71e29e4d3001a082466b86fc34ff30"}, + {file = "rapidfuzz-3.9.6-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:0a96b52c9f26857bf009e270dcd829381e7a634f7ddd585fa29b87d4c82146d9"}, + {file = "rapidfuzz-3.9.6-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:6edd3cd7c4aa8c68c716d349f531bd5011f2ca49ddade216bb4429460151559f"}, + {file = "rapidfuzz-3.9.6-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:50b2fb55d7ed58c66d49c9f954acd8fc4a3f0e9fd0ff708299bd8abb68238d0e"}, + {file = "rapidfuzz-3.9.6-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:32848dfe54391636b84cda1823fd23e5a6b1dbb8be0e9a1d80e4ee9903820994"}, + {file = "rapidfuzz-3.9.6-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:29146cb7a1bf69c87e928b31bffa54f066cb65639d073b36e1425f98cccdebc6"}, + {file = "rapidfuzz-3.9.6-cp38-cp38-win32.whl", hash = "sha256:aed13e5edacb0ecadcc304cc66e93e7e77ff24f059c9792ee602c0381808e10c"}, + {file = "rapidfuzz-3.9.6-cp38-cp38-win_amd64.whl", hash = "sha256:af440e36b828922256d0b4d79443bf2cbe5515fc4b0e9e96017ec789b36bb9fc"}, + {file = "rapidfuzz-3.9.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:efa674b407424553024522159296690d99d6e6b1192cafe99ca84592faff16b4"}, + {file = "rapidfuzz-3.9.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0b40ff76ee19b03ebf10a0a87938f86814996a822786c41c3312d251b7927849"}, + {file = "rapidfuzz-3.9.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16a6c7997cb5927ced6f617122eb116ba514ec6b6f60f4803e7925ef55158891"}, + {file = "rapidfuzz-3.9.6-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f3f42504bdc8d770987fc3d99964766d42b2a03e4d5b0f891decdd256236bae0"}, + {file = "rapidfuzz-3.9.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ad9462aa2be9f60b540c19a083471fdf28e7cf6434f068b631525b5e6251b35e"}, + {file = "rapidfuzz-3.9.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1629698e68f47609a73bf9e73a6da3a4cac20bc710529215cbdf111ab603665b"}, + {file = "rapidfuzz-3.9.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:68bc7621843d8e9a7fd1b1a32729465bf94b47b6fb307d906da168413331f8d6"}, + {file = "rapidfuzz-3.9.6-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:c6254c50f15bc2fcc33cb93a95a81b702d9e6590f432a7f7822b8c7aba9ae288"}, + {file = "rapidfuzz-3.9.6-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:7e535a114fa575bc143e175e4ca386a467ec8c42909eff500f5f0f13dc84e3e0"}, + {file = "rapidfuzz-3.9.6-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:d50acc0e9d67e4ba7a004a14c42d1b1e8b6ca1c515692746f4f8e7948c673167"}, + {file = "rapidfuzz-3.9.6-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:fa742ec60bec53c5a211632cf1d31b9eb5a3c80f1371a46a23ac25a1fa2ab209"}, + {file = "rapidfuzz-3.9.6-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:c256fa95d29cbe5aa717db790b231a9a5b49e5983d50dc9df29d364a1db5e35b"}, + {file = "rapidfuzz-3.9.6-cp39-cp39-win32.whl", hash = "sha256:89acbf728b764421036c173a10ada436ecca22999851cdc01d0aa904c70d362d"}, + {file = "rapidfuzz-3.9.6-cp39-cp39-win_amd64.whl", hash = "sha256:c608fcba8b14d86c04cb56b203fed31a96e8a1ebb4ce99e7b70313c5bf8cf497"}, + {file = "rapidfuzz-3.9.6-cp39-cp39-win_arm64.whl", hash = "sha256:d41c00ded0e22e9dba88ff23ebe0dc9d2a5f21ba2f88e185ea7374461e61daa9"}, + {file = "rapidfuzz-3.9.6-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:a65c2f63218ea2dedd56fc56361035e189ca123bd9c9ce63a9bef6f99540d681"}, + {file = "rapidfuzz-3.9.6-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:680dc78a5f889d3b89f74824b89fe357f49f88ad10d2c121e9c3ad37bac1e4eb"}, + {file = "rapidfuzz-3.9.6-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b8ca862927a0b05bd825e46ddf82d0724ea44b07d898ef639386530bf9b40f15"}, + {file = "rapidfuzz-3.9.6-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2116fa1fbff21fa52cd46f3cfcb1e193ba1d65d81f8b6e123193451cd3d6c15e"}, + {file = "rapidfuzz-3.9.6-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4dcb7d9afd740370a897c15da61d3d57a8d54738d7c764a99cedb5f746d6a003"}, + {file = "rapidfuzz-3.9.6-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:1a5bd6401bb489e14cbb5981c378d53ede850b7cc84b2464cad606149cc4e17d"}, + {file = "rapidfuzz-3.9.6-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:29fda70b9d03e29df6fc45cc27cbcc235534b1b0b2900e0a3ae0b43022aaeef5"}, + {file = "rapidfuzz-3.9.6-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:88144f5f52ae977df9352029488326afadd7a7f42c6779d486d1f82d43b2b1f2"}, + {file = "rapidfuzz-3.9.6-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:715aeaabafba2709b9dd91acb2a44bad59d60b4616ef90c08f4d4402a3bbca60"}, + {file = "rapidfuzz-3.9.6-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:af26ebd3714224fbf9bebbc27bdbac14f334c15f5d7043699cd694635050d6ca"}, + {file = "rapidfuzz-3.9.6-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:101bd2df438861a005ed47c032631b7857dfcdb17b82beeeb410307983aac61d"}, + {file = "rapidfuzz-3.9.6-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:2185e8e29809b97ad22a7f99281d1669a89bdf5fa1ef4ef1feca36924e675367"}, + {file = "rapidfuzz-3.9.6-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:9e53c72d08f0e9c6e4a369e52df5971f311305b4487690c62e8dd0846770260c"}, + {file = "rapidfuzz-3.9.6-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:a0cb157162f0cdd62e538c7bd298ff669847fc43a96422811d5ab933f4c16c3a"}, + {file = "rapidfuzz-3.9.6-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4bb5ff2bd48132ed5e7fbb8f619885facb2e023759f2519a448b2c18afe07e5d"}, + {file = "rapidfuzz-3.9.6-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6dc37f601865e8407e3a8037ffbc3afe0b0f837b2146f7632bd29d087385babe"}, + {file = "rapidfuzz-3.9.6-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a657eee4b94668faf1fa2703bdd803654303f7e468eb9ba10a664d867ed9e779"}, + {file = "rapidfuzz-3.9.6-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:51be6ab5b1d5bb32abd39718f2a5e3835502e026a8272d139ead295c224a6f5e"}, + {file = "rapidfuzz-3.9.6.tar.gz", hash = "sha256:5cf2a7d621e4515fee84722e93563bf77ff2cbe832a77a48b81f88f9e23b9e8d"}, ] [package.extras] @@ -7078,109 +7406,124 @@ test = ["coveralls", "pycodestyle", "pyflakes", "pylint", "pytest", "pytest-benc [[package]] name = "redis" -version = "5.0.7" +version = "5.0.8" description = "Python client for Redis database and key-value store" optional = false python-versions = ">=3.7" files = [ - {file = "redis-5.0.7-py3-none-any.whl", hash = "sha256:0e479e24da960c690be5d9b96d21f7b918a98c0cf49af3b6fafaa0753f93a0db"}, - {file = "redis-5.0.7.tar.gz", hash = "sha256:8f611490b93c8109b50adc317b31bfd84fff31def3475b92e7e80bf39f48175b"}, + {file = "redis-5.0.8-py3-none-any.whl", hash = "sha256:56134ee08ea909106090934adc36f65c9bcbbaecea5b21ba704ba6fb561f8eb4"}, + {file = "redis-5.0.8.tar.gz", hash = "sha256:0c5b10d387568dfe0698c6fad6615750c24170e548ca2deac10c649d463e9870"}, ] [package.dependencies] async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""} -hiredis = {version = ">=1.0.0", optional = true, markers = "extra == \"hiredis\""} +hiredis = {version = ">1.0.0", optional = true, markers = "extra == \"hiredis\""} [package.extras] -hiredis = ["hiredis (>=1.0.0)"] +hiredis = ["hiredis (>1.0.0)"] ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"] +[[package]] +name = "referencing" +version = "0.35.1" +description = "JSON Referencing + Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "referencing-0.35.1-py3-none-any.whl", hash = "sha256:eda6d3234d62814d1c64e305c1331c9a3a6132da475ab6382eaa997b21ee75de"}, + {file = "referencing-0.35.1.tar.gz", hash = "sha256:25b42124a6c8b632a425174f24087783efb348a6f1e0008e63cd4466fedf703c"}, +] + +[package.dependencies] +attrs = ">=22.2.0" +rpds-py = ">=0.7.0" + [[package]] name = "regex" -version = "2024.5.15" +version = "2024.7.24" description = "Alternative regular expression module, to replace re." optional = false python-versions = ">=3.8" files = [ - {file = "regex-2024.5.15-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a81e3cfbae20378d75185171587cbf756015ccb14840702944f014e0d93ea09f"}, - {file = "regex-2024.5.15-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7b59138b219ffa8979013be7bc85bb60c6f7b7575df3d56dc1e403a438c7a3f6"}, - {file = "regex-2024.5.15-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a0bd000c6e266927cb7a1bc39d55be95c4b4f65c5be53e659537537e019232b1"}, - {file = "regex-2024.5.15-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5eaa7ddaf517aa095fa8da0b5015c44d03da83f5bd49c87961e3c997daed0de7"}, - {file = "regex-2024.5.15-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba68168daedb2c0bab7fd7e00ced5ba90aebf91024dea3c88ad5063c2a562cca"}, - {file = "regex-2024.5.15-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6e8d717bca3a6e2064fc3a08df5cbe366369f4b052dcd21b7416e6d71620dca1"}, - {file = "regex-2024.5.15-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1337b7dbef9b2f71121cdbf1e97e40de33ff114801263b275aafd75303bd62b5"}, - {file = "regex-2024.5.15-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f9ebd0a36102fcad2f03696e8af4ae682793a5d30b46c647eaf280d6cfb32796"}, - {file = "regex-2024.5.15-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:9efa1a32ad3a3ea112224897cdaeb6aa00381627f567179c0314f7b65d354c62"}, - {file = "regex-2024.5.15-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:1595f2d10dff3d805e054ebdc41c124753631b6a471b976963c7b28543cf13b0"}, - {file = "regex-2024.5.15-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:b802512f3e1f480f41ab5f2cfc0e2f761f08a1f41092d6718868082fc0d27143"}, - {file = "regex-2024.5.15-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:a0981022dccabca811e8171f913de05720590c915b033b7e601f35ce4ea7019f"}, - {file = "regex-2024.5.15-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:19068a6a79cf99a19ccefa44610491e9ca02c2be3305c7760d3831d38a467a6f"}, - {file = "regex-2024.5.15-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:1b5269484f6126eee5e687785e83c6b60aad7663dafe842b34691157e5083e53"}, - {file = "regex-2024.5.15-cp310-cp310-win32.whl", hash = "sha256:ada150c5adfa8fbcbf321c30c751dc67d2f12f15bd183ffe4ec7cde351d945b3"}, - {file = "regex-2024.5.15-cp310-cp310-win_amd64.whl", hash = "sha256:ac394ff680fc46b97487941f5e6ae49a9f30ea41c6c6804832063f14b2a5a145"}, - {file = "regex-2024.5.15-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f5b1dff3ad008dccf18e652283f5e5339d70bf8ba7c98bf848ac33db10f7bc7a"}, - {file = "regex-2024.5.15-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c6a2b494a76983df8e3d3feea9b9ffdd558b247e60b92f877f93a1ff43d26656"}, - {file = "regex-2024.5.15-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a32b96f15c8ab2e7d27655969a23895eb799de3665fa94349f3b2fbfd547236f"}, - {file = "regex-2024.5.15-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:10002e86e6068d9e1c91eae8295ef690f02f913c57db120b58fdd35a6bb1af35"}, - {file = "regex-2024.5.15-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ec54d5afa89c19c6dd8541a133be51ee1017a38b412b1321ccb8d6ddbeb4cf7d"}, - {file = "regex-2024.5.15-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:10e4ce0dca9ae7a66e6089bb29355d4432caed736acae36fef0fdd7879f0b0cb"}, - {file = "regex-2024.5.15-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e507ff1e74373c4d3038195fdd2af30d297b4f0950eeda6f515ae3d84a1770f"}, - {file = "regex-2024.5.15-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1f059a4d795e646e1c37665b9d06062c62d0e8cc3c511fe01315973a6542e40"}, - {file = "regex-2024.5.15-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0721931ad5fe0dda45d07f9820b90b2148ccdd8e45bb9e9b42a146cb4f695649"}, - {file = "regex-2024.5.15-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:833616ddc75ad595dee848ad984d067f2f31be645d603e4d158bba656bbf516c"}, - {file = "regex-2024.5.15-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:287eb7f54fc81546346207c533ad3c2c51a8d61075127d7f6d79aaf96cdee890"}, - {file = "regex-2024.5.15-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:19dfb1c504781a136a80ecd1fff9f16dddf5bb43cec6871778c8a907a085bb3d"}, - {file = "regex-2024.5.15-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:119af6e56dce35e8dfb5222573b50c89e5508d94d55713c75126b753f834de68"}, - {file = "regex-2024.5.15-cp311-cp311-win32.whl", hash = "sha256:1c1c174d6ec38d6c8a7504087358ce9213d4332f6293a94fbf5249992ba54efa"}, - {file = "regex-2024.5.15-cp311-cp311-win_amd64.whl", hash = "sha256:9e717956dcfd656f5055cc70996ee2cc82ac5149517fc8e1b60261b907740201"}, - {file = "regex-2024.5.15-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:632b01153e5248c134007209b5c6348a544ce96c46005d8456de1d552455b014"}, - {file = "regex-2024.5.15-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e64198f6b856d48192bf921421fdd8ad8eb35e179086e99e99f711957ffedd6e"}, - {file = "regex-2024.5.15-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:68811ab14087b2f6e0fc0c2bae9ad689ea3584cad6917fc57be6a48bbd012c49"}, - {file = "regex-2024.5.15-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f8ec0c2fea1e886a19c3bee0cd19d862b3aa75dcdfb42ebe8ed30708df64687a"}, - {file = "regex-2024.5.15-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d0c0c0003c10f54a591d220997dd27d953cd9ccc1a7294b40a4be5312be8797b"}, - {file = "regex-2024.5.15-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2431b9e263af1953c55abbd3e2efca67ca80a3de8a0437cb58e2421f8184717a"}, - {file = "regex-2024.5.15-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a605586358893b483976cffc1723fb0f83e526e8f14c6e6614e75919d9862cf"}, - {file = "regex-2024.5.15-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:391d7f7f1e409d192dba8bcd42d3e4cf9e598f3979cdaed6ab11288da88cb9f2"}, - {file = "regex-2024.5.15-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:9ff11639a8d98969c863d4617595eb5425fd12f7c5ef6621a4b74b71ed8726d5"}, - {file = "regex-2024.5.15-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:4eee78a04e6c67e8391edd4dad3279828dd66ac4b79570ec998e2155d2e59fd5"}, - {file = "regex-2024.5.15-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:8fe45aa3f4aa57faabbc9cb46a93363edd6197cbc43523daea044e9ff2fea83e"}, - {file = "regex-2024.5.15-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:d0a3d8d6acf0c78a1fff0e210d224b821081330b8524e3e2bc5a68ef6ab5803d"}, - {file = "regex-2024.5.15-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c486b4106066d502495b3025a0a7251bf37ea9540433940a23419461ab9f2a80"}, - {file = "regex-2024.5.15-cp312-cp312-win32.whl", hash = "sha256:c49e15eac7c149f3670b3e27f1f28a2c1ddeccd3a2812cba953e01be2ab9b5fe"}, - {file = "regex-2024.5.15-cp312-cp312-win_amd64.whl", hash = "sha256:673b5a6da4557b975c6c90198588181029c60793835ce02f497ea817ff647cb2"}, - {file = "regex-2024.5.15-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:87e2a9c29e672fc65523fb47a90d429b70ef72b901b4e4b1bd42387caf0d6835"}, - {file = "regex-2024.5.15-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c3bea0ba8b73b71b37ac833a7f3fd53825924165da6a924aec78c13032f20850"}, - {file = "regex-2024.5.15-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:bfc4f82cabe54f1e7f206fd3d30fda143f84a63fe7d64a81558d6e5f2e5aaba9"}, - {file = "regex-2024.5.15-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5bb9425fe881d578aeca0b2b4b3d314ec88738706f66f219c194d67179337cb"}, - {file = "regex-2024.5.15-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:64c65783e96e563103d641760664125e91bd85d8e49566ee560ded4da0d3e704"}, - {file = "regex-2024.5.15-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cf2430df4148b08fb4324b848672514b1385ae3807651f3567871f130a728cc3"}, - {file = "regex-2024.5.15-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5397de3219a8b08ae9540c48f602996aa6b0b65d5a61683e233af8605c42b0f2"}, - {file = "regex-2024.5.15-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:455705d34b4154a80ead722f4f185b04c4237e8e8e33f265cd0798d0e44825fa"}, - {file = "regex-2024.5.15-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:b2b6f1b3bb6f640c1a92be3bbfbcb18657b125b99ecf141fb3310b5282c7d4ed"}, - {file = "regex-2024.5.15-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:3ad070b823ca5890cab606c940522d05d3d22395d432f4aaaf9d5b1653e47ced"}, - {file = "regex-2024.5.15-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:5b5467acbfc153847d5adb21e21e29847bcb5870e65c94c9206d20eb4e99a384"}, - {file = "regex-2024.5.15-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:e6662686aeb633ad65be2a42b4cb00178b3fbf7b91878f9446075c404ada552f"}, - {file = "regex-2024.5.15-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:2b4c884767504c0e2401babe8b5b7aea9148680d2e157fa28f01529d1f7fcf67"}, - {file = "regex-2024.5.15-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:3cd7874d57f13bf70078f1ff02b8b0aa48d5b9ed25fc48547516c6aba36f5741"}, - {file = "regex-2024.5.15-cp38-cp38-win32.whl", hash = "sha256:e4682f5ba31f475d58884045c1a97a860a007d44938c4c0895f41d64481edbc9"}, - {file = "regex-2024.5.15-cp38-cp38-win_amd64.whl", hash = "sha256:d99ceffa25ac45d150e30bd9ed14ec6039f2aad0ffa6bb87a5936f5782fc1569"}, - {file = "regex-2024.5.15-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:13cdaf31bed30a1e1c2453ef6015aa0983e1366fad2667657dbcac7b02f67133"}, - {file = "regex-2024.5.15-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cac27dcaa821ca271855a32188aa61d12decb6fe45ffe3e722401fe61e323cd1"}, - {file = "regex-2024.5.15-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7dbe2467273b875ea2de38ded4eba86cbcbc9a1a6d0aa11dcf7bd2e67859c435"}, - {file = "regex-2024.5.15-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:64f18a9a3513a99c4bef0e3efd4c4a5b11228b48aa80743be822b71e132ae4f5"}, - {file = "regex-2024.5.15-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d347a741ea871c2e278fde6c48f85136c96b8659b632fb57a7d1ce1872547600"}, - {file = "regex-2024.5.15-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1878b8301ed011704aea4c806a3cadbd76f84dece1ec09cc9e4dc934cfa5d4da"}, - {file = "regex-2024.5.15-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4babf07ad476aaf7830d77000874d7611704a7fcf68c9c2ad151f5d94ae4bfc4"}, - {file = "regex-2024.5.15-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:35cb514e137cb3488bce23352af3e12fb0dbedd1ee6e60da053c69fb1b29cc6c"}, - {file = "regex-2024.5.15-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:cdd09d47c0b2efee9378679f8510ee6955d329424c659ab3c5e3a6edea696294"}, - {file = "regex-2024.5.15-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:72d7a99cd6b8f958e85fc6ca5b37c4303294954eac1376535b03c2a43eb72629"}, - {file = "regex-2024.5.15-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:a094801d379ab20c2135529948cb84d417a2169b9bdceda2a36f5f10977ebc16"}, - {file = "regex-2024.5.15-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:c0c18345010870e58238790a6779a1219b4d97bd2e77e1140e8ee5d14df071aa"}, - {file = "regex-2024.5.15-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:16093f563098448ff6b1fa68170e4acbef94e6b6a4e25e10eae8598bb1694b5d"}, - {file = "regex-2024.5.15-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:e38a7d4e8f633a33b4c7350fbd8bad3b70bf81439ac67ac38916c4a86b465456"}, - {file = "regex-2024.5.15-cp39-cp39-win32.whl", hash = "sha256:71a455a3c584a88f654b64feccc1e25876066c4f5ef26cd6dd711308aa538694"}, - {file = "regex-2024.5.15-cp39-cp39-win_amd64.whl", hash = "sha256:cab12877a9bdafde5500206d1020a584355a97884dfd388af3699e9137bf7388"}, - {file = "regex-2024.5.15.tar.gz", hash = "sha256:d3ee02d9e5f482cc8309134a91eeaacbdd2261ba111b0fef3748eeb4913e6a2c"}, + {file = "regex-2024.7.24-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b0d3f567fafa0633aee87f08b9276c7062da9616931382993c03808bb68ce"}, + {file = "regex-2024.7.24-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3426de3b91d1bc73249042742f45c2148803c111d1175b283270177fdf669024"}, + {file = "regex-2024.7.24-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f273674b445bcb6e4409bf8d1be67bc4b58e8b46fd0d560055d515b8830063cd"}, + {file = "regex-2024.7.24-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23acc72f0f4e1a9e6e9843d6328177ae3074b4182167e34119ec7233dfeccf53"}, + {file = "regex-2024.7.24-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65fd3d2e228cae024c411c5ccdffae4c315271eee4a8b839291f84f796b34eca"}, + {file = "regex-2024.7.24-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c414cbda77dbf13c3bc88b073a1a9f375c7b0cb5e115e15d4b73ec3a2fbc6f59"}, + {file = "regex-2024.7.24-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf7a89eef64b5455835f5ed30254ec19bf41f7541cd94f266ab7cbd463f00c41"}, + {file = "regex-2024.7.24-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:19c65b00d42804e3fbea9708f0937d157e53429a39b7c61253ff15670ff62cb5"}, + {file = "regex-2024.7.24-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7a5486ca56c8869070a966321d5ab416ff0f83f30e0e2da1ab48815c8d165d46"}, + {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6f51f9556785e5a203713f5efd9c085b4a45aecd2a42573e2b5041881b588d1f"}, + {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:a4997716674d36a82eab3e86f8fa77080a5d8d96a389a61ea1d0e3a94a582cf7"}, + {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:c0abb5e4e8ce71a61d9446040c1e86d4e6d23f9097275c5bd49ed978755ff0fe"}, + {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:18300a1d78cf1290fa583cd8b7cde26ecb73e9f5916690cf9d42de569c89b1ce"}, + {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:416c0e4f56308f34cdb18c3f59849479dde5b19febdcd6e6fa4d04b6c31c9faa"}, + {file = "regex-2024.7.24-cp310-cp310-win32.whl", hash = "sha256:fb168b5924bef397b5ba13aabd8cf5df7d3d93f10218d7b925e360d436863f66"}, + {file = "regex-2024.7.24-cp310-cp310-win_amd64.whl", hash = "sha256:6b9fc7e9cc983e75e2518496ba1afc524227c163e43d706688a6bb9eca41617e"}, + {file = "regex-2024.7.24-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:382281306e3adaaa7b8b9ebbb3ffb43358a7bbf585fa93821300a418bb975281"}, + {file = "regex-2024.7.24-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4fdd1384619f406ad9037fe6b6eaa3de2749e2e12084abc80169e8e075377d3b"}, + {file = "regex-2024.7.24-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3d974d24edb231446f708c455fd08f94c41c1ff4f04bcf06e5f36df5ef50b95a"}, + {file = "regex-2024.7.24-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a2ec4419a3fe6cf8a4795752596dfe0adb4aea40d3683a132bae9c30b81e8d73"}, + {file = "regex-2024.7.24-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eb563dd3aea54c797adf513eeec819c4213d7dbfc311874eb4fd28d10f2ff0f2"}, + {file = "regex-2024.7.24-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:45104baae8b9f67569f0f1dca5e1f1ed77a54ae1cd8b0b07aba89272710db61e"}, + {file = "regex-2024.7.24-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:994448ee01864501912abf2bad9203bffc34158e80fe8bfb5b031f4f8e16da51"}, + {file = "regex-2024.7.24-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3fac296f99283ac232d8125be932c5cd7644084a30748fda013028c815ba3364"}, + {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7e37e809b9303ec3a179085415cb5f418ecf65ec98cdfe34f6a078b46ef823ee"}, + {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:01b689e887f612610c869421241e075c02f2e3d1ae93a037cb14f88ab6a8934c"}, + {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:f6442f0f0ff81775eaa5b05af8a0ffa1dda36e9cf6ec1e0d3d245e8564b684ce"}, + {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:871e3ab2838fbcb4e0865a6e01233975df3a15e6fce93b6f99d75cacbd9862d1"}, + {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c918b7a1e26b4ab40409820ddccc5d49871a82329640f5005f73572d5eaa9b5e"}, + {file = "regex-2024.7.24-cp311-cp311-win32.whl", hash = "sha256:2dfbb8baf8ba2c2b9aa2807f44ed272f0913eeeba002478c4577b8d29cde215c"}, + {file = "regex-2024.7.24-cp311-cp311-win_amd64.whl", hash = "sha256:538d30cd96ed7d1416d3956f94d54e426a8daf7c14527f6e0d6d425fcb4cca52"}, + {file = "regex-2024.7.24-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:fe4ebef608553aff8deb845c7f4f1d0740ff76fa672c011cc0bacb2a00fbde86"}, + {file = "regex-2024.7.24-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:74007a5b25b7a678459f06559504f1eec2f0f17bca218c9d56f6a0a12bfffdad"}, + {file = "regex-2024.7.24-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7df9ea48641da022c2a3c9c641650cd09f0cd15e8908bf931ad538f5ca7919c9"}, + {file = "regex-2024.7.24-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a1141a1dcc32904c47f6846b040275c6e5de0bf73f17d7a409035d55b76f289"}, + {file = "regex-2024.7.24-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:80c811cfcb5c331237d9bad3bea2c391114588cf4131707e84d9493064d267f9"}, + {file = "regex-2024.7.24-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7214477bf9bd195894cf24005b1e7b496f46833337b5dedb7b2a6e33f66d962c"}, + {file = "regex-2024.7.24-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d55588cba7553f0b6ec33130bc3e114b355570b45785cebdc9daed8c637dd440"}, + {file = "regex-2024.7.24-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:558a57cfc32adcf19d3f791f62b5ff564922942e389e3cfdb538a23d65a6b610"}, + {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a512eed9dfd4117110b1881ba9a59b31433caed0c4101b361f768e7bcbaf93c5"}, + {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:86b17ba823ea76256b1885652e3a141a99a5c4422f4a869189db328321b73799"}, + {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:5eefee9bfe23f6df09ffb6dfb23809f4d74a78acef004aa904dc7c88b9944b05"}, + {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:731fcd76bbdbf225e2eb85b7c38da9633ad3073822f5ab32379381e8c3c12e94"}, + {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:eaef80eac3b4cfbdd6de53c6e108b4c534c21ae055d1dbea2de6b3b8ff3def38"}, + {file = "regex-2024.7.24-cp312-cp312-win32.whl", hash = "sha256:185e029368d6f89f36e526764cf12bf8d6f0e3a2a7737da625a76f594bdfcbfc"}, + {file = "regex-2024.7.24-cp312-cp312-win_amd64.whl", hash = "sha256:2f1baff13cc2521bea83ab2528e7a80cbe0ebb2c6f0bfad15be7da3aed443908"}, + {file = "regex-2024.7.24-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:66b4c0731a5c81921e938dcf1a88e978264e26e6ac4ec96a4d21ae0354581ae0"}, + {file = "regex-2024.7.24-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:88ecc3afd7e776967fa16c80f974cb79399ee8dc6c96423321d6f7d4b881c92b"}, + {file = "regex-2024.7.24-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:64bd50cf16bcc54b274e20235bf8edbb64184a30e1e53873ff8d444e7ac656b2"}, + {file = "regex-2024.7.24-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb462f0e346fcf41a901a126b50f8781e9a474d3927930f3490f38a6e73b6950"}, + {file = "regex-2024.7.24-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a82465ebbc9b1c5c50738536fdfa7cab639a261a99b469c9d4c7dcbb2b3f1e57"}, + {file = "regex-2024.7.24-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:68a8f8c046c6466ac61a36b65bb2395c74451df2ffb8458492ef49900efed293"}, + {file = "regex-2024.7.24-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dac8e84fff5d27420f3c1e879ce9929108e873667ec87e0c8eeb413a5311adfe"}, + {file = "regex-2024.7.24-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba2537ef2163db9e6ccdbeb6f6424282ae4dea43177402152c67ef869cf3978b"}, + {file = "regex-2024.7.24-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:43affe33137fcd679bdae93fb25924979517e011f9dea99163f80b82eadc7e53"}, + {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:c9bb87fdf2ab2370f21e4d5636e5317775e5d51ff32ebff2cf389f71b9b13750"}, + {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:945352286a541406f99b2655c973852da7911b3f4264e010218bbc1cc73168f2"}, + {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:8bc593dcce679206b60a538c302d03c29b18e3d862609317cb560e18b66d10cf"}, + {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:3f3b6ca8eae6d6c75a6cff525c8530c60e909a71a15e1b731723233331de4169"}, + {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c51edc3541e11fbe83f0c4d9412ef6c79f664a3745fab261457e84465ec9d5a8"}, + {file = "regex-2024.7.24-cp38-cp38-win32.whl", hash = "sha256:d0a07763776188b4db4c9c7fb1b8c494049f84659bb387b71c73bbc07f189e96"}, + {file = "regex-2024.7.24-cp38-cp38-win_amd64.whl", hash = "sha256:8fd5afd101dcf86a270d254364e0e8dddedebe6bd1ab9d5f732f274fa00499a5"}, + {file = "regex-2024.7.24-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0ffe3f9d430cd37d8fa5632ff6fb36d5b24818c5c986893063b4e5bdb84cdf24"}, + {file = "regex-2024.7.24-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:25419b70ba00a16abc90ee5fce061228206173231f004437730b67ac77323f0d"}, + {file = "regex-2024.7.24-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:33e2614a7ce627f0cdf2ad104797d1f68342d967de3695678c0cb84f530709f8"}, + {file = "regex-2024.7.24-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d33a0021893ede5969876052796165bab6006559ab845fd7b515a30abdd990dc"}, + {file = "regex-2024.7.24-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:04ce29e2c5fedf296b1a1b0acc1724ba93a36fb14031f3abfb7abda2806c1535"}, + {file = "regex-2024.7.24-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b16582783f44fbca6fcf46f61347340c787d7530d88b4d590a397a47583f31dd"}, + {file = "regex-2024.7.24-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:836d3cc225b3e8a943d0b02633fb2f28a66e281290302a79df0e1eaa984ff7c1"}, + {file = "regex-2024.7.24-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:438d9f0f4bc64e8dea78274caa5af971ceff0f8771e1a2333620969936ba10be"}, + {file = "regex-2024.7.24-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:973335b1624859cb0e52f96062a28aa18f3a5fc77a96e4a3d6d76e29811a0e6e"}, + {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:c5e69fd3eb0b409432b537fe3c6f44ac089c458ab6b78dcec14478422879ec5f"}, + {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:fbf8c2f00904eaf63ff37718eb13acf8e178cb940520e47b2f05027f5bb34ce3"}, + {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ae2757ace61bc4061b69af19e4689fa4416e1a04840f33b441034202b5cd02d4"}, + {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:44fc61b99035fd9b3b9453f1713234e5a7c92a04f3577252b45feefe1b327759"}, + {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:84c312cdf839e8b579f504afcd7b65f35d60b6285d892b19adea16355e8343c9"}, + {file = "regex-2024.7.24-cp39-cp39-win32.whl", hash = "sha256:ca5b2028c2f7af4e13fb9fc29b28d0ce767c38c7facdf64f6c2cd040413055f1"}, + {file = "regex-2024.7.24-cp39-cp39-win_amd64.whl", hash = "sha256:7c479f5ae937ec9985ecaf42e2e10631551d909f203e31308c12d703922742f9"}, + {file = "regex-2024.7.24.tar.gz", hash = "sha256:9cfd009eed1a46b27c14039ad5bbc5e71b6367c5b2e6d5f5da0ea91600817506"}, ] [[package]] @@ -7302,6 +7645,118 @@ pygments = ">=2.13.0,<3.0.0" [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] +[[package]] +name = "rpds-py" +version = "0.20.0" +description = "Python bindings to Rust's persistent data structures (rpds)" +optional = false +python-versions = ">=3.8" +files = [ + {file = "rpds_py-0.20.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3ad0fda1635f8439cde85c700f964b23ed5fc2d28016b32b9ee5fe30da5c84e2"}, + {file = "rpds_py-0.20.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9bb4a0d90fdb03437c109a17eade42dfbf6190408f29b2744114d11586611d6f"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6377e647bbfd0a0b159fe557f2c6c602c159fc752fa316572f012fc0bf67150"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb851b7df9dda52dc1415ebee12362047ce771fc36914586b2e9fcbd7d293b3e"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1e0f80b739e5a8f54837be5d5c924483996b603d5502bfff79bf33da06164ee2"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5a8c94dad2e45324fc74dce25e1645d4d14df9a4e54a30fa0ae8bad9a63928e3"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8e604fe73ba048c06085beaf51147eaec7df856824bfe7b98657cf436623daf"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:df3de6b7726b52966edf29663e57306b23ef775faf0ac01a3e9f4012a24a4140"}, + {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:cf258ede5bc22a45c8e726b29835b9303c285ab46fc7c3a4cc770736b5304c9f"}, + {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:55fea87029cded5df854ca7e192ec7bdb7ecd1d9a3f63d5c4eb09148acf4a7ce"}, + {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ae94bd0b2f02c28e199e9bc51485d0c5601f58780636185660f86bf80c89af94"}, + {file = "rpds_py-0.20.0-cp310-none-win32.whl", hash = "sha256:28527c685f237c05445efec62426d285e47a58fb05ba0090a4340b73ecda6dee"}, + {file = "rpds_py-0.20.0-cp310-none-win_amd64.whl", hash = "sha256:238a2d5b1cad28cdc6ed15faf93a998336eb041c4e440dd7f902528b8891b399"}, + {file = "rpds_py-0.20.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:ac2f4f7a98934c2ed6505aead07b979e6f999389f16b714448fb39bbaa86a489"}, + {file = "rpds_py-0.20.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:220002c1b846db9afd83371d08d239fdc865e8f8c5795bbaec20916a76db3318"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d7919548df3f25374a1f5d01fbcd38dacab338ef5f33e044744b5c36729c8db"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:758406267907b3781beee0f0edfe4a179fbd97c0be2e9b1154d7f0a1279cf8e5"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3d61339e9f84a3f0767b1995adfb171a0d00a1185192718a17af6e124728e0f5"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1259c7b3705ac0a0bd38197565a5d603218591d3f6cee6e614e380b6ba61c6f6"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c1dc0f53856b9cc9a0ccca0a7cc61d3d20a7088201c0937f3f4048c1718a209"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7e60cb630f674a31f0368ed32b2a6b4331b8350d67de53c0359992444b116dd3"}, + {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:dbe982f38565bb50cb7fb061ebf762c2f254ca3d8c20d4006878766e84266272"}, + {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:514b3293b64187172bc77c8fb0cdae26981618021053b30d8371c3a902d4d5ad"}, + {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d0a26ffe9d4dd35e4dfdd1e71f46401cff0181c75ac174711ccff0459135fa58"}, + {file = "rpds_py-0.20.0-cp311-none-win32.whl", hash = "sha256:89c19a494bf3ad08c1da49445cc5d13d8fefc265f48ee7e7556839acdacf69d0"}, + {file = "rpds_py-0.20.0-cp311-none-win_amd64.whl", hash = "sha256:c638144ce971df84650d3ed0096e2ae7af8e62ecbbb7b201c8935c370df00a2c"}, + {file = "rpds_py-0.20.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a84ab91cbe7aab97f7446652d0ed37d35b68a465aeef8fc41932a9d7eee2c1a6"}, + {file = "rpds_py-0.20.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:56e27147a5a4c2c21633ff8475d185734c0e4befd1c989b5b95a5d0db699b21b"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2580b0c34583b85efec8c5c5ec9edf2dfe817330cc882ee972ae650e7b5ef739"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b80d4a7900cf6b66bb9cee5c352b2d708e29e5a37fe9bf784fa97fc11504bf6c"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:50eccbf054e62a7b2209b28dc7a22d6254860209d6753e6b78cfaeb0075d7bee"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:49a8063ea4296b3a7e81a5dfb8f7b2d73f0b1c20c2af401fb0cdf22e14711a96"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea438162a9fcbee3ecf36c23e6c68237479f89f962f82dae83dc15feeceb37e4"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:18d7585c463087bddcfa74c2ba267339f14f2515158ac4db30b1f9cbdb62c8ef"}, + {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d4c7d1a051eeb39f5c9547e82ea27cbcc28338482242e3e0b7768033cb083821"}, + {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:e4df1e3b3bec320790f699890d41c59d250f6beda159ea3c44c3f5bac1976940"}, + {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2cf126d33a91ee6eedc7f3197b53e87a2acdac63602c0f03a02dd69e4b138174"}, + {file = "rpds_py-0.20.0-cp312-none-win32.whl", hash = "sha256:8bc7690f7caee50b04a79bf017a8d020c1f48c2a1077ffe172abec59870f1139"}, + {file = "rpds_py-0.20.0-cp312-none-win_amd64.whl", hash = "sha256:0e13e6952ef264c40587d510ad676a988df19adea20444c2b295e536457bc585"}, + {file = "rpds_py-0.20.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:aa9a0521aeca7d4941499a73ad7d4f8ffa3d1affc50b9ea11d992cd7eff18a29"}, + {file = "rpds_py-0.20.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4a1f1d51eccb7e6c32ae89243cb352389228ea62f89cd80823ea7dd1b98e0b91"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8a86a9b96070674fc88b6f9f71a97d2c1d3e5165574615d1f9168ecba4cecb24"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6c8ef2ebf76df43f5750b46851ed1cdf8f109d7787ca40035fe19fbdc1acc5a7"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b74b25f024b421d5859d156750ea9a65651793d51b76a2e9238c05c9d5f203a9"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57eb94a8c16ab08fef6404301c38318e2c5a32216bf5de453e2714c964c125c8"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1940dae14e715e2e02dfd5b0f64a52e8374a517a1e531ad9412319dc3ac7879"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d20277fd62e1b992a50c43f13fbe13277a31f8c9f70d59759c88f644d66c619f"}, + {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:06db23d43f26478303e954c34c75182356ca9aa7797d22c5345b16871ab9c45c"}, + {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b2a5db5397d82fa847e4c624b0c98fe59d2d9b7cf0ce6de09e4d2e80f8f5b3f2"}, + {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5a35df9f5548fd79cb2f52d27182108c3e6641a4feb0f39067911bf2adaa3e57"}, + {file = "rpds_py-0.20.0-cp313-none-win32.whl", hash = "sha256:fd2d84f40633bc475ef2d5490b9c19543fbf18596dcb1b291e3a12ea5d722f7a"}, + {file = "rpds_py-0.20.0-cp313-none-win_amd64.whl", hash = "sha256:9bc2d153989e3216b0559251b0c260cfd168ec78b1fac33dd485750a228db5a2"}, + {file = "rpds_py-0.20.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:f2fbf7db2012d4876fb0d66b5b9ba6591197b0f165db8d99371d976546472a24"}, + {file = "rpds_py-0.20.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1e5f3cd7397c8f86c8cc72d5a791071431c108edd79872cdd96e00abd8497d29"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce9845054c13696f7af7f2b353e6b4f676dab1b4b215d7fe5e05c6f8bb06f965"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c3e130fd0ec56cb76eb49ef52faead8ff09d13f4527e9b0c400307ff72b408e1"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4b16aa0107ecb512b568244ef461f27697164d9a68d8b35090e9b0c1c8b27752"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aa7f429242aae2947246587d2964fad750b79e8c233a2367f71b554e9447949c"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af0fc424a5842a11e28956e69395fbbeab2c97c42253169d87e90aac2886d751"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b8c00a3b1e70c1d3891f0db1b05292747f0dbcfb49c43f9244d04c70fbc40eb8"}, + {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:40ce74fc86ee4645d0a225498d091d8bc61f39b709ebef8204cb8b5a464d3c0e"}, + {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:4fe84294c7019456e56d93e8ababdad5a329cd25975be749c3f5f558abb48253"}, + {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:338ca4539aad4ce70a656e5187a3a31c5204f261aef9f6ab50e50bcdffaf050a"}, + {file = "rpds_py-0.20.0-cp38-none-win32.whl", hash = "sha256:54b43a2b07db18314669092bb2de584524d1ef414588780261e31e85846c26a5"}, + {file = "rpds_py-0.20.0-cp38-none-win_amd64.whl", hash = "sha256:a1862d2d7ce1674cffa6d186d53ca95c6e17ed2b06b3f4c476173565c862d232"}, + {file = "rpds_py-0.20.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:3fde368e9140312b6e8b6c09fb9f8c8c2f00999d1823403ae90cc00480221b22"}, + {file = "rpds_py-0.20.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9824fb430c9cf9af743cf7aaf6707bf14323fb51ee74425c380f4c846ea70789"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11ef6ce74616342888b69878d45e9f779b95d4bd48b382a229fe624a409b72c5"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c52d3f2f82b763a24ef52f5d24358553e8403ce05f893b5347098014f2d9eff2"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9d35cef91e59ebbeaa45214861874bc6f19eb35de96db73e467a8358d701a96c"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d72278a30111e5b5525c1dd96120d9e958464316f55adb030433ea905866f4de"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4c29cbbba378759ac5786730d1c3cb4ec6f8ababf5c42a9ce303dc4b3d08cda"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6632f2d04f15d1bd6fe0eedd3b86d9061b836ddca4c03d5cf5c7e9e6b7c14580"}, + {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:d0b67d87bb45ed1cd020e8fbf2307d449b68abc45402fe1a4ac9e46c3c8b192b"}, + {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:ec31a99ca63bf3cd7f1a5ac9fe95c5e2d060d3c768a09bc1d16e235840861420"}, + {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:22e6c9976e38f4d8c4a63bd8a8edac5307dffd3ee7e6026d97f3cc3a2dc02a0b"}, + {file = "rpds_py-0.20.0-cp39-none-win32.whl", hash = "sha256:569b3ea770c2717b730b61998b6c54996adee3cef69fc28d444f3e7920313cf7"}, + {file = "rpds_py-0.20.0-cp39-none-win_amd64.whl", hash = "sha256:e6900ecdd50ce0facf703f7a00df12374b74bbc8ad9fe0f6559947fb20f82364"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:617c7357272c67696fd052811e352ac54ed1d9b49ab370261a80d3b6ce385045"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9426133526f69fcaba6e42146b4e12d6bc6c839b8b555097020e2b78ce908dcc"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:deb62214c42a261cb3eb04d474f7155279c1a8a8c30ac89b7dcb1721d92c3c02"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fcaeb7b57f1a1e071ebd748984359fef83ecb026325b9d4ca847c95bc7311c92"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d454b8749b4bd70dd0a79f428731ee263fa6995f83ccb8bada706e8d1d3ff89d"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d807dc2051abe041b6649681dce568f8e10668e3c1c6543ebae58f2d7e617855"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3c20f0ddeb6e29126d45f89206b8291352b8c5b44384e78a6499d68b52ae511"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b7f19250ceef892adf27f0399b9e5afad019288e9be756d6919cb58892129f51"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:4f1ed4749a08379555cebf4650453f14452eaa9c43d0a95c49db50c18b7da075"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:dcedf0b42bcb4cfff4101d7771a10532415a6106062f005ab97d1d0ab5681c60"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:39ed0d010457a78f54090fafb5d108501b5aa5604cc22408fc1c0c77eac14344"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:bb273176be34a746bdac0b0d7e4e2c467323d13640b736c4c477881a3220a989"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f918a1a130a6dfe1d7fe0f105064141342e7dd1611f2e6a21cd2f5c8cb1cfb3e"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:f60012a73aa396be721558caa3a6fd49b3dd0033d1675c6d59c4502e870fcf0c"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d2b1ad682a3dfda2a4e8ad8572f3100f95fad98cb99faf37ff0ddfe9cbf9d03"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:614fdafe9f5f19c63ea02817fa4861c606a59a604a77c8cdef5aa01d28b97921"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fa518bcd7600c584bf42e6617ee8132869e877db2f76bcdc281ec6a4113a53ab"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f0475242f447cc6cb8a9dd486d68b2ef7fbee84427124c232bff5f63b1fe11e5"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f90a4cd061914a60bd51c68bcb4357086991bd0bb93d8aa66a6da7701370708f"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:def7400461c3a3f26e49078302e1c1b38f6752342c77e3cf72ce91ca69fb1bc1"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:65794e4048ee837494aea3c21a28ad5fc080994dfba5b036cf84de37f7ad5074"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:faefcc78f53a88f3076b7f8be0a8f8d35133a3ecf7f3770895c25f8813460f08"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:5b4f105deeffa28bbcdff6c49b34e74903139afa690e35d2d9e3c2c2fba18cec"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:fdfc3a892927458d98f3d55428ae46b921d1f7543b89382fdb483f5640daaec8"}, + {file = "rpds_py-0.20.0.tar.gz", hash = "sha256:d72a210824facfdaf8768cf2d7ca25a042c30320b3020de2fa04640920d4e121"}, +] + [[package]] name = "rsa" version = "4.9" @@ -7318,29 +7773,29 @@ pyasn1 = ">=0.1.3" [[package]] name = "ruff" -version = "0.5.4" +version = "0.6.1" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.5.4-py3-none-linux_armv6l.whl", hash = "sha256:82acef724fc639699b4d3177ed5cc14c2a5aacd92edd578a9e846d5b5ec18ddf"}, - {file = "ruff-0.5.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:da62e87637c8838b325e65beee485f71eb36202ce8e3cdbc24b9fcb8b99a37be"}, - {file = "ruff-0.5.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:e98ad088edfe2f3b85a925ee96da652028f093d6b9b56b76fc242d8abb8e2059"}, - {file = "ruff-0.5.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c55efbecc3152d614cfe6c2247a3054cfe358cefbf794f8c79c8575456efe19"}, - {file = "ruff-0.5.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f9b85eaa1f653abd0a70603b8b7008d9e00c9fa1bbd0bf40dad3f0c0bdd06793"}, - {file = "ruff-0.5.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0cf497a47751be8c883059c4613ba2f50dd06ec672692de2811f039432875278"}, - {file = "ruff-0.5.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:09c14ed6a72af9ccc8d2e313d7acf7037f0faff43cde4b507e66f14e812e37f7"}, - {file = "ruff-0.5.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:628f6b8f97b8bad2490240aa84f3e68f390e13fabc9af5c0d3b96b485921cd60"}, - {file = "ruff-0.5.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3520a00c0563d7a7a7c324ad7e2cde2355733dafa9592c671fb2e9e3cd8194c1"}, - {file = "ruff-0.5.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93789f14ca2244fb91ed481456f6d0bb8af1f75a330e133b67d08f06ad85b516"}, - {file = "ruff-0.5.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:029454e2824eafa25b9df46882f7f7844d36fd8ce51c1b7f6d97e2615a57bbcc"}, - {file = "ruff-0.5.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:9492320eed573a13a0bc09a2957f17aa733fff9ce5bf00e66e6d4a88ec33813f"}, - {file = "ruff-0.5.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a6e1f62a92c645e2919b65c02e79d1f61e78a58eddaebca6c23659e7c7cb4ac7"}, - {file = "ruff-0.5.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:768fa9208df2bec4b2ce61dbc7c2ddd6b1be9fb48f1f8d3b78b3332c7d71c1ff"}, - {file = "ruff-0.5.4-py3-none-win32.whl", hash = "sha256:e1e7393e9c56128e870b233c82ceb42164966f25b30f68acbb24ed69ce9c3a4e"}, - {file = "ruff-0.5.4-py3-none-win_amd64.whl", hash = "sha256:58b54459221fd3f661a7329f177f091eb35cf7a603f01d9eb3eb11cc348d38c4"}, - {file = "ruff-0.5.4-py3-none-win_arm64.whl", hash = "sha256:bd53da65f1085fb5b307c38fd3c0829e76acf7b2a912d8d79cadcdb4875c1eb7"}, - {file = "ruff-0.5.4.tar.gz", hash = "sha256:2795726d5f71c4f4e70653273d1c23a8182f07dd8e48c12de5d867bfb7557eed"}, + {file = "ruff-0.6.1-py3-none-linux_armv6l.whl", hash = "sha256:b4bb7de6a24169dc023f992718a9417380301b0c2da0fe85919f47264fb8add9"}, + {file = "ruff-0.6.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:45efaae53b360c81043e311cdec8a7696420b3d3e8935202c2846e7a97d4edae"}, + {file = "ruff-0.6.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:bc60c7d71b732c8fa73cf995efc0c836a2fd8b9810e115be8babb24ae87e0850"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c7477c3b9da822e2db0b4e0b59e61b8a23e87886e727b327e7dcaf06213c5cf"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3a0af7ab3f86e3dc9f157a928e08e26c4b40707d0612b01cd577cc84b8905cc9"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:392688dbb50fecf1bf7126731c90c11a9df1c3a4cdc3f481b53e851da5634fa5"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:5278d3e095ccc8c30430bcc9bc550f778790acc211865520f3041910a28d0024"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fe6d5f65d6f276ee7a0fc50a0cecaccb362d30ef98a110f99cac1c7872df2f18"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b2e0dd11e2ae553ee5c92a81731d88a9883af8db7408db47fc81887c1f8b672e"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d812615525a34ecfc07fd93f906ef5b93656be01dfae9a819e31caa6cfe758a1"}, + {file = "ruff-0.6.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:faaa4060f4064c3b7aaaa27328080c932fa142786f8142aff095b42b6a2eb631"}, + {file = "ruff-0.6.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:99d7ae0df47c62729d58765c593ea54c2546d5de213f2af2a19442d50a10cec9"}, + {file = "ruff-0.6.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9eb18dfd7b613eec000e3738b3f0e4398bf0153cb80bfa3e351b3c1c2f6d7b15"}, + {file = "ruff-0.6.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:c62bc04c6723a81e25e71715aa59489f15034d69bf641df88cb38bdc32fd1dbb"}, + {file = "ruff-0.6.1-py3-none-win32.whl", hash = "sha256:9fb4c4e8b83f19c9477a8745e56d2eeef07a7ff50b68a6998f7d9e2e3887bdc4"}, + {file = "ruff-0.6.1-py3-none-win_amd64.whl", hash = "sha256:c2ebfc8f51ef4aca05dad4552bbcf6fe8d1f75b2f6af546cc47cc1c1ca916b5b"}, + {file = "ruff-0.6.1-py3-none-win_arm64.whl", hash = "sha256:3bc81074971b0ffad1bd0c52284b22411f02a11a012082a76ac6da153536e014"}, + {file = "ruff-0.6.1.tar.gz", hash = "sha256:af3ffd8c6563acb8848d33cd19a69b9bfe943667f0419ca083f8ebe4224a3436"}, ] [[package]] @@ -7362,111 +7817,121 @@ crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"] [[package]] name = "safetensors" -version = "0.4.3" +version = "0.4.4" description = "" optional = false python-versions = ">=3.7" files = [ - {file = "safetensors-0.4.3-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:dcf5705cab159ce0130cd56057f5f3425023c407e170bca60b4868048bae64fd"}, - {file = "safetensors-0.4.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bb4f8c5d0358a31e9a08daeebb68f5e161cdd4018855426d3f0c23bb51087055"}, - {file = "safetensors-0.4.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70a5319ef409e7f88686a46607cbc3c428271069d8b770076feaf913664a07ac"}, - {file = "safetensors-0.4.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fb9c65bd82f9ef3ce4970dc19ee86be5f6f93d032159acf35e663c6bea02b237"}, - {file = "safetensors-0.4.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:edb5698a7bc282089f64c96c477846950358a46ede85a1c040e0230344fdde10"}, - {file = "safetensors-0.4.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:efcc860be094b8d19ac61b452ec635c7acb9afa77beb218b1d7784c6d41fe8ad"}, - {file = "safetensors-0.4.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d88b33980222085dd6001ae2cad87c6068e0991d4f5ccf44975d216db3b57376"}, - {file = "safetensors-0.4.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5fc6775529fb9f0ce2266edd3e5d3f10aab068e49f765e11f6f2a63b5367021d"}, - {file = "safetensors-0.4.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9c6ad011c1b4e3acff058d6b090f1da8e55a332fbf84695cf3100c649cc452d1"}, - {file = "safetensors-0.4.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8c496c5401c1b9c46d41a7688e8ff5b0310a3b9bae31ce0f0ae870e1ea2b8caf"}, - {file = "safetensors-0.4.3-cp310-none-win32.whl", hash = "sha256:38e2a8666178224a51cca61d3cb4c88704f696eac8f72a49a598a93bbd8a4af9"}, - {file = "safetensors-0.4.3-cp310-none-win_amd64.whl", hash = "sha256:393e6e391467d1b2b829c77e47d726f3b9b93630e6a045b1d1fca67dc78bf632"}, - {file = "safetensors-0.4.3-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:22f3b5d65e440cec0de8edaa672efa888030802e11c09b3d6203bff60ebff05a"}, - {file = "safetensors-0.4.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7c4fa560ebd4522adddb71dcd25d09bf211b5634003f015a4b815b7647d62ebe"}, - {file = "safetensors-0.4.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e9afd5358719f1b2cf425fad638fc3c887997d6782da317096877e5b15b2ce93"}, - {file = "safetensors-0.4.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d8c5093206ef4b198600ae484230402af6713dab1bd5b8e231905d754022bec7"}, - {file = "safetensors-0.4.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e0b2104df1579d6ba9052c0ae0e3137c9698b2d85b0645507e6fd1813b70931a"}, - {file = "safetensors-0.4.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8cf18888606dad030455d18f6c381720e57fc6a4170ee1966adb7ebc98d4d6a3"}, - {file = "safetensors-0.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0bf4f9d6323d9f86eef5567eabd88f070691cf031d4c0df27a40d3b4aaee755b"}, - {file = "safetensors-0.4.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:585c9ae13a205807b63bef8a37994f30c917ff800ab8a1ca9c9b5d73024f97ee"}, - {file = "safetensors-0.4.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:faefeb3b81bdfb4e5a55b9bbdf3d8d8753f65506e1d67d03f5c851a6c87150e9"}, - {file = "safetensors-0.4.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:befdf0167ad626f22f6aac6163477fcefa342224a22f11fdd05abb3995c1783c"}, - {file = "safetensors-0.4.3-cp311-none-win32.whl", hash = "sha256:a7cef55929dcbef24af3eb40bedec35d82c3c2fa46338bb13ecf3c5720af8a61"}, - {file = "safetensors-0.4.3-cp311-none-win_amd64.whl", hash = "sha256:840b7ac0eff5633e1d053cc9db12fdf56b566e9403b4950b2dc85393d9b88d67"}, - {file = "safetensors-0.4.3-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:22d21760dc6ebae42e9c058d75aa9907d9f35e38f896e3c69ba0e7b213033856"}, - {file = "safetensors-0.4.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d22c1a10dff3f64d0d68abb8298a3fd88ccff79f408a3e15b3e7f637ef5c980"}, - {file = "safetensors-0.4.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1648568667f820b8c48317c7006221dc40aced1869908c187f493838a1362bc"}, - {file = "safetensors-0.4.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:446e9fe52c051aeab12aac63d1017e0f68a02a92a027b901c4f8e931b24e5397"}, - {file = "safetensors-0.4.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fef5d70683643618244a4f5221053567ca3e77c2531e42ad48ae05fae909f542"}, - {file = "safetensors-0.4.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a1f4430cc0c9d6afa01214a4b3919d0a029637df8e09675ceef1ca3f0dfa0df"}, - {file = "safetensors-0.4.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d603846a8585b9432a0fd415db1d4c57c0f860eb4aea21f92559ff9902bae4d"}, - {file = "safetensors-0.4.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a844cdb5d7cbc22f5f16c7e2a0271170750763c4db08381b7f696dbd2c78a361"}, - {file = "safetensors-0.4.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:88887f69f7a00cf02b954cdc3034ffb383b2303bc0ab481d4716e2da51ddc10e"}, - {file = "safetensors-0.4.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ee463219d9ec6c2be1d331ab13a8e0cd50d2f32240a81d498266d77d07b7e71e"}, - {file = "safetensors-0.4.3-cp312-none-win32.whl", hash = "sha256:d0dd4a1db09db2dba0f94d15addc7e7cd3a7b0d393aa4c7518c39ae7374623c3"}, - {file = "safetensors-0.4.3-cp312-none-win_amd64.whl", hash = "sha256:d14d30c25897b2bf19b6fb5ff7e26cc40006ad53fd4a88244fdf26517d852dd7"}, - {file = "safetensors-0.4.3-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:d1456f814655b224d4bf6e7915c51ce74e389b413be791203092b7ff78c936dd"}, - {file = "safetensors-0.4.3-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:455d538aa1aae4a8b279344a08136d3f16334247907b18a5c3c7fa88ef0d3c46"}, - {file = "safetensors-0.4.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf476bca34e1340ee3294ef13e2c625833f83d096cfdf69a5342475602004f95"}, - {file = "safetensors-0.4.3-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:02ef3a24face643456020536591fbd3c717c5abaa2737ec428ccbbc86dffa7a4"}, - {file = "safetensors-0.4.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7de32d0d34b6623bb56ca278f90db081f85fb9c5d327e3c18fd23ac64f465768"}, - {file = "safetensors-0.4.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a0deb16a1d3ea90c244ceb42d2c6c276059616be21a19ac7101aa97da448faf"}, - {file = "safetensors-0.4.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c59d51f182c729f47e841510b70b967b0752039f79f1de23bcdd86462a9b09ee"}, - {file = "safetensors-0.4.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1f598b713cc1a4eb31d3b3203557ac308acf21c8f41104cdd74bf640c6e538e3"}, - {file = "safetensors-0.4.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5757e4688f20df083e233b47de43845d1adb7e17b6cf7da5f8444416fc53828d"}, - {file = "safetensors-0.4.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:fe746d03ed8d193674a26105e4f0fe6c726f5bb602ffc695b409eaf02f04763d"}, - {file = "safetensors-0.4.3-cp37-none-win32.whl", hash = "sha256:0d5ffc6a80f715c30af253e0e288ad1cd97a3d0086c9c87995e5093ebc075e50"}, - {file = "safetensors-0.4.3-cp37-none-win_amd64.whl", hash = "sha256:a11c374eb63a9c16c5ed146457241182f310902bd2a9c18255781bb832b6748b"}, - {file = "safetensors-0.4.3-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:b1e31be7945f66be23f4ec1682bb47faa3df34cb89fc68527de6554d3c4258a4"}, - {file = "safetensors-0.4.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:03a4447c784917c9bf01d8f2ac5080bc15c41692202cd5f406afba16629e84d6"}, - {file = "safetensors-0.4.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d244bcafeb1bc06d47cfee71727e775bca88a8efda77a13e7306aae3813fa7e4"}, - {file = "safetensors-0.4.3-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:53c4879b9c6bd7cd25d114ee0ef95420e2812e676314300624594940a8d6a91f"}, - {file = "safetensors-0.4.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:74707624b81f1b7f2b93f5619d4a9f00934d5948005a03f2c1845ffbfff42212"}, - {file = "safetensors-0.4.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0d52c958dc210265157573f81d34adf54e255bc2b59ded6218500c9b15a750eb"}, - {file = "safetensors-0.4.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f9568f380f513a60139971169c4a358b8731509cc19112369902eddb33faa4d"}, - {file = "safetensors-0.4.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0d9cd8e1560dfc514b6d7859247dc6a86ad2f83151a62c577428d5102d872721"}, - {file = "safetensors-0.4.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:89f9f17b0dacb913ed87d57afbc8aad85ea42c1085bd5de2f20d83d13e9fc4b2"}, - {file = "safetensors-0.4.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:1139eb436fd201c133d03c81209d39ac57e129f5e74e34bb9ab60f8d9b726270"}, - {file = "safetensors-0.4.3-cp38-none-win32.whl", hash = "sha256:d9c289f140a9ae4853fc2236a2ffc9a9f2d5eae0cb673167e0f1b8c18c0961ac"}, - {file = "safetensors-0.4.3-cp38-none-win_amd64.whl", hash = "sha256:622afd28968ef3e9786562d352659a37de4481a4070f4ebac883f98c5836563e"}, - {file = "safetensors-0.4.3-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:8651c7299cbd8b4161a36cd6a322fa07d39cd23535b144d02f1c1972d0c62f3c"}, - {file = "safetensors-0.4.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e375d975159ac534c7161269de24ddcd490df2157b55c1a6eeace6cbb56903f0"}, - {file = "safetensors-0.4.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:084fc436e317f83f7071fc6a62ca1c513b2103db325cd09952914b50f51cf78f"}, - {file = "safetensors-0.4.3-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:41a727a7f5e6ad9f1db6951adee21bbdadc632363d79dc434876369a17de6ad6"}, - {file = "safetensors-0.4.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e7dbbde64b6c534548696808a0e01276d28ea5773bc9a2dfb97a88cd3dffe3df"}, - {file = "safetensors-0.4.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bbae3b4b9d997971431c346edbfe6e41e98424a097860ee872721e176040a893"}, - {file = "safetensors-0.4.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01e4b22e3284cd866edeabe4f4d896229495da457229408d2e1e4810c5187121"}, - {file = "safetensors-0.4.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0dd37306546b58d3043eb044c8103a02792cc024b51d1dd16bd3dd1f334cb3ed"}, - {file = "safetensors-0.4.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d8815b5e1dac85fc534a97fd339e12404db557878c090f90442247e87c8aeaea"}, - {file = "safetensors-0.4.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e011cc162503c19f4b1fd63dfcddf73739c7a243a17dac09b78e57a00983ab35"}, - {file = "safetensors-0.4.3-cp39-none-win32.whl", hash = "sha256:01feb3089e5932d7e662eda77c3ecc389f97c0883c4a12b5cfdc32b589a811c3"}, - {file = "safetensors-0.4.3-cp39-none-win_amd64.whl", hash = "sha256:3f9cdca09052f585e62328c1c2923c70f46814715c795be65f0b93f57ec98a02"}, - {file = "safetensors-0.4.3-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:1b89381517891a7bb7d1405d828b2bf5d75528299f8231e9346b8eba092227f9"}, - {file = "safetensors-0.4.3-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:cd6fff9e56df398abc5866b19a32124815b656613c1c5ec0f9350906fd798aac"}, - {file = "safetensors-0.4.3-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:840caf38d86aa7014fe37ade5d0d84e23dcfbc798b8078015831996ecbc206a3"}, - {file = "safetensors-0.4.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9650713b2cfa9537a2baf7dd9fee458b24a0aaaa6cafcea8bdd5fb2b8efdc34"}, - {file = "safetensors-0.4.3-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e4119532cd10dba04b423e0f86aecb96cfa5a602238c0aa012f70c3a40c44b50"}, - {file = "safetensors-0.4.3-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:e066e8861eef6387b7c772344d1fe1f9a72800e04ee9a54239d460c400c72aab"}, - {file = "safetensors-0.4.3-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:90964917f5b0fa0fa07e9a051fbef100250c04d150b7026ccbf87a34a54012e0"}, - {file = "safetensors-0.4.3-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:c41e1893d1206aa7054029681778d9a58b3529d4c807002c156d58426c225173"}, - {file = "safetensors-0.4.3-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ae7613a119a71a497d012ccc83775c308b9c1dab454806291427f84397d852fd"}, - {file = "safetensors-0.4.3-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9bac020faba7f5dc481e881b14b6425265feabb5bfc552551d21189c0eddc3"}, - {file = "safetensors-0.4.3-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:420a98f593ff9930f5822560d14c395ccbc57342ddff3b463bc0b3d6b1951550"}, - {file = "safetensors-0.4.3-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:f5e6883af9a68c0028f70a4c19d5a6ab6238a379be36ad300a22318316c00cb0"}, - {file = "safetensors-0.4.3-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:cdd0a3b5da66e7f377474599814dbf5cbf135ff059cc73694de129b58a5e8a2c"}, - {file = "safetensors-0.4.3-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:9bfb92f82574d9e58401d79c70c716985dc049b635fef6eecbb024c79b2c46ad"}, - {file = "safetensors-0.4.3-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:3615a96dd2dcc30eb66d82bc76cda2565f4f7bfa89fcb0e31ba3cea8a1a9ecbb"}, - {file = "safetensors-0.4.3-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:868ad1b6fc41209ab6bd12f63923e8baeb1a086814cb2e81a65ed3d497e0cf8f"}, - {file = "safetensors-0.4.3-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b7ffba80aa49bd09195145a7fd233a7781173b422eeb995096f2b30591639517"}, - {file = "safetensors-0.4.3-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c0acbe31340ab150423347e5b9cc595867d814244ac14218932a5cf1dd38eb39"}, - {file = "safetensors-0.4.3-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:19bbdf95de2cf64f25cd614c5236c8b06eb2cfa47cbf64311f4b5d80224623a3"}, - {file = "safetensors-0.4.3-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:b852e47eb08475c2c1bd8131207b405793bfc20d6f45aff893d3baaad449ed14"}, - {file = "safetensors-0.4.3-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5d07cbca5b99babb692d76d8151bec46f461f8ad8daafbfd96b2fca40cadae65"}, - {file = "safetensors-0.4.3-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:1ab6527a20586d94291c96e00a668fa03f86189b8a9defa2cdd34a1a01acc7d5"}, - {file = "safetensors-0.4.3-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:02318f01e332cc23ffb4f6716e05a492c5f18b1d13e343c49265149396284a44"}, - {file = "safetensors-0.4.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec4b52ce9a396260eb9731eb6aea41a7320de22ed73a1042c2230af0212758ce"}, - {file = "safetensors-0.4.3-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:018b691383026a2436a22b648873ed11444a364324e7088b99cd2503dd828400"}, - {file = "safetensors-0.4.3-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:309b10dbcab63269ecbf0e2ca10ce59223bb756ca5d431ce9c9eeabd446569da"}, - {file = "safetensors-0.4.3-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:b277482120df46e27a58082df06a15aebda4481e30a1c21eefd0921ae7e03f65"}, - {file = "safetensors-0.4.3.tar.gz", hash = "sha256:2f85fc50c4e07a21e95c24e07460fe6f7e2859d0ce88092838352b798ce711c2"}, + {file = "safetensors-0.4.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2adb497ada13097f30e386e88c959c0fda855a5f6f98845710f5bb2c57e14f12"}, + {file = "safetensors-0.4.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7db7fdc2d71fd1444d85ca3f3d682ba2df7d61a637dfc6d80793f439eae264ab"}, + {file = "safetensors-0.4.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d4f0eed76b430f009fbefca1a0028ddb112891b03cb556d7440d5cd68eb89a9"}, + {file = "safetensors-0.4.4-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:57d216fab0b5c432aabf7170883d7c11671622bde8bd1436c46d633163a703f6"}, + {file = "safetensors-0.4.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7d9b76322e49c056bcc819f8bdca37a2daa5a6d42c07f30927b501088db03309"}, + {file = "safetensors-0.4.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:32f0d1f6243e90ee43bc6ee3e8c30ac5b09ca63f5dd35dbc985a1fc5208c451a"}, + {file = "safetensors-0.4.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44d464bdc384874601a177375028012a5f177f1505279f9456fea84bbc575c7f"}, + {file = "safetensors-0.4.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:63144e36209ad8e4e65384dbf2d52dd5b1866986079c00a72335402a38aacdc5"}, + {file = "safetensors-0.4.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:051d5ecd490af7245258000304b812825974d5e56f14a3ff7e1b8b2ba6dc2ed4"}, + {file = "safetensors-0.4.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:51bc8429d9376224cd3cf7e8ce4f208b4c930cd10e515b6ac6a72cbc3370f0d9"}, + {file = "safetensors-0.4.4-cp310-none-win32.whl", hash = "sha256:fb7b54830cee8cf9923d969e2df87ce20e625b1af2fd194222ab902d3adcc29c"}, + {file = "safetensors-0.4.4-cp310-none-win_amd64.whl", hash = "sha256:4b3e8aa8226d6560de8c2b9d5ff8555ea482599c670610758afdc97f3e021e9c"}, + {file = "safetensors-0.4.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:bbaa31f2cb49013818bde319232ccd72da62ee40f7d2aa532083eda5664e85ff"}, + {file = "safetensors-0.4.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9fdcb80f4e9fbb33b58e9bf95e7dbbedff505d1bcd1c05f7c7ce883632710006"}, + {file = "safetensors-0.4.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55c14c20be247b8a1aeaf3ab4476265e3ca83096bb8e09bb1a7aa806088def4f"}, + {file = "safetensors-0.4.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:949aaa1118660f992dbf0968487b3e3cfdad67f948658ab08c6b5762e90cc8b6"}, + {file = "safetensors-0.4.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c11a4ab7debc456326a2bac67f35ee0ac792bcf812c7562a4a28559a5c795e27"}, + {file = "safetensors-0.4.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c0cea44bba5c5601b297bc8307e4075535b95163402e4906b2e9b82788a2a6df"}, + {file = "safetensors-0.4.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9d752c97f6bbe327352f76e5b86442d776abc789249fc5e72eacb49e6916482"}, + {file = "safetensors-0.4.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:03f2bb92e61b055ef6cc22883ad1ae898010a95730fa988c60a23800eb742c2c"}, + {file = "safetensors-0.4.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:87bf3f91a9328a941acc44eceffd4e1f5f89b030985b2966637e582157173b98"}, + {file = "safetensors-0.4.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:20d218ec2b6899d29d6895419a58b6e44cc5ff8f0cc29fac8d236a8978ab702e"}, + {file = "safetensors-0.4.4-cp311-none-win32.whl", hash = "sha256:8079486118919f600c603536e2490ca37b3dbd3280e3ad6eaacfe6264605ac8a"}, + {file = "safetensors-0.4.4-cp311-none-win_amd64.whl", hash = "sha256:2f8c2eb0615e2e64ee27d478c7c13f51e5329d7972d9e15528d3e4cfc4a08f0d"}, + {file = "safetensors-0.4.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:baec5675944b4a47749c93c01c73d826ef7d42d36ba8d0dba36336fa80c76426"}, + {file = "safetensors-0.4.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f15117b96866401825f3e94543145028a2947d19974429246ce59403f49e77c6"}, + {file = "safetensors-0.4.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a13a9caea485df164c51be4eb0c87f97f790b7c3213d635eba2314d959fe929"}, + {file = "safetensors-0.4.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6b54bc4ca5f9b9bba8cd4fb91c24b2446a86b5ae7f8975cf3b7a277353c3127c"}, + {file = "safetensors-0.4.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:08332c22e03b651c8eb7bf5fc2de90044f3672f43403b3d9ac7e7e0f4f76495e"}, + {file = "safetensors-0.4.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bb62841e839ee992c37bb75e75891c7f4904e772db3691c59daaca5b4ab960e1"}, + {file = "safetensors-0.4.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e5b927acc5f2f59547270b0309a46d983edc44be64e1ca27a7fcb0474d6cd67"}, + {file = "safetensors-0.4.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2a69c71b1ae98a8021a09a0b43363b0143b0ce74e7c0e83cacba691b62655fb8"}, + {file = "safetensors-0.4.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:23654ad162c02a5636f0cd520a0310902c4421aab1d91a0b667722a4937cc445"}, + {file = "safetensors-0.4.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:0677c109d949cf53756859160b955b2e75b0eefe952189c184d7be30ecf7e858"}, + {file = "safetensors-0.4.4-cp312-none-win32.whl", hash = "sha256:a51d0ddd4deb8871c6de15a772ef40b3dbd26a3c0451bb9e66bc76fc5a784e5b"}, + {file = "safetensors-0.4.4-cp312-none-win_amd64.whl", hash = "sha256:2d065059e75a798bc1933c293b68d04d79b586bb7f8c921e0ca1e82759d0dbb1"}, + {file = "safetensors-0.4.4-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:9d625692578dd40a112df30c02a1adf068027566abd8e6a74893bb13d441c150"}, + {file = "safetensors-0.4.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7cabcf39c81e5b988d0adefdaea2eb9b4fd9bd62d5ed6559988c62f36bfa9a89"}, + {file = "safetensors-0.4.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8359bef65f49d51476e9811d59c015f0ddae618ee0e44144f5595278c9f8268c"}, + {file = "safetensors-0.4.4-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1a32c662e7df9226fd850f054a3ead0e4213a96a70b5ce37b2d26ba27004e013"}, + {file = "safetensors-0.4.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c329a4dcc395364a1c0d2d1574d725fe81a840783dda64c31c5a60fc7d41472c"}, + {file = "safetensors-0.4.4-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:239ee093b1db877c9f8fe2d71331a97f3b9c7c0d3ab9f09c4851004a11f44b65"}, + {file = "safetensors-0.4.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bd574145d930cf9405a64f9923600879a5ce51d9f315443a5f706374841327b6"}, + {file = "safetensors-0.4.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f6784eed29f9e036acb0b7769d9e78a0dc2c72c2d8ba7903005350d817e287a4"}, + {file = "safetensors-0.4.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:65a4a6072436bf0a4825b1c295d248cc17e5f4651e60ee62427a5bcaa8622a7a"}, + {file = "safetensors-0.4.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:df81e3407630de060ae8313da49509c3caa33b1a9415562284eaf3d0c7705f9f"}, + {file = "safetensors-0.4.4-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:e4a0f374200e8443d9746e947ebb346c40f83a3970e75a685ade0adbba5c48d9"}, + {file = "safetensors-0.4.4-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:181fb5f3dee78dae7fd7ec57d02e58f7936498d587c6b7c1c8049ef448c8d285"}, + {file = "safetensors-0.4.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2cb4ac1d8f6b65ec84ddfacd275079e89d9df7c92f95675ba96c4f790a64df6e"}, + {file = "safetensors-0.4.4-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:76897944cd9239e8a70955679b531b9a0619f76e25476e57ed373322d9c2075d"}, + {file = "safetensors-0.4.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2a9e9d1a27e51a0f69e761a3d581c3af46729ec1c988fa1f839e04743026ae35"}, + {file = "safetensors-0.4.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:005ef9fc0f47cb9821c40793eb029f712e97278dae84de91cb2b4809b856685d"}, + {file = "safetensors-0.4.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26987dac3752688c696c77c3576f951dbbdb8c57f0957a41fb6f933cf84c0b62"}, + {file = "safetensors-0.4.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c05270b290acd8d249739f40d272a64dd597d5a4b90f27d830e538bc2549303c"}, + {file = "safetensors-0.4.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:068d3a33711fc4d93659c825a04480ff5a3854e1d78632cdc8f37fee917e8a60"}, + {file = "safetensors-0.4.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:063421ef08ca1021feea8b46951251b90ae91f899234dd78297cbe7c1db73b99"}, + {file = "safetensors-0.4.4-cp37-none-win32.whl", hash = "sha256:d52f5d0615ea83fd853d4e1d8acf93cc2e0223ad4568ba1e1f6ca72e94ea7b9d"}, + {file = "safetensors-0.4.4-cp37-none-win_amd64.whl", hash = "sha256:88a5ac3280232d4ed8e994cbc03b46a1807ce0aa123867b40c4a41f226c61f94"}, + {file = "safetensors-0.4.4-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:3467ab511bfe3360967d7dc53b49f272d59309e57a067dd2405b4d35e7dcf9dc"}, + {file = "safetensors-0.4.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2ab4c96d922e53670ce25fbb9b63d5ea972e244de4fa1dd97b590d9fd66aacef"}, + {file = "safetensors-0.4.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:87df18fce4440477c3ef1fd7ae17c704a69a74a77e705a12be135ee0651a0c2d"}, + {file = "safetensors-0.4.4-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0e5fe345b2bc7d88587149ac11def1f629d2671c4c34f5df38aed0ba59dc37f8"}, + {file = "safetensors-0.4.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9f1a3e01dce3cd54060791e7e24588417c98b941baa5974700eeb0b8eb65b0a0"}, + {file = "safetensors-0.4.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1c6bf35e9a8998d8339fd9a05ac4ce465a4d2a2956cc0d837b67c4642ed9e947"}, + {file = "safetensors-0.4.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:166c0c52f6488b8538b2a9f3fbc6aad61a7261e170698779b371e81b45f0440d"}, + {file = "safetensors-0.4.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:87e9903b8668a16ef02c08ba4ebc91e57a49c481e9b5866e31d798632805014b"}, + {file = "safetensors-0.4.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:a9c421153aa23c323bd8483d4155b4eee82c9a50ac11cccd83539104a8279c64"}, + {file = "safetensors-0.4.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a4b8617499b2371c7353302c5116a7e0a3a12da66389ce53140e607d3bf7b3d3"}, + {file = "safetensors-0.4.4-cp38-none-win32.whl", hash = "sha256:c6280f5aeafa1731f0a3709463ab33d8e0624321593951aefada5472f0b313fd"}, + {file = "safetensors-0.4.4-cp38-none-win_amd64.whl", hash = "sha256:6ceed6247fc2d33b2a7b7d25d8a0fe645b68798856e0bc7a9800c5fd945eb80f"}, + {file = "safetensors-0.4.4-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:5cf6c6f6193797372adf50c91d0171743d16299491c75acad8650107dffa9269"}, + {file = "safetensors-0.4.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:419010156b914a3e5da4e4adf992bee050924d0fe423c4b329e523e2c14c3547"}, + {file = "safetensors-0.4.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88f6fd5a5c1302ce79993cc5feeadcc795a70f953c762544d01fb02b2db4ea33"}, + {file = "safetensors-0.4.4-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d468cffb82d90789696d5b4d8b6ab8843052cba58a15296691a7a3df55143cd2"}, + {file = "safetensors-0.4.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9353c2af2dd467333d4850a16edb66855e795561cd170685178f706c80d2c71e"}, + {file = "safetensors-0.4.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:83c155b4a33368d9b9c2543e78f2452090fb030c52401ca608ef16fa58c98353"}, + {file = "safetensors-0.4.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9850754c434e636ce3dc586f534bb23bcbd78940c304775bee9005bf610e98f1"}, + {file = "safetensors-0.4.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:275f500b4d26f67b6ec05629a4600645231bd75e4ed42087a7c1801bff04f4b3"}, + {file = "safetensors-0.4.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5c2308de665b7130cd0e40a2329278226e4cf083f7400c51ca7e19ccfb3886f3"}, + {file = "safetensors-0.4.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e06a9ebc8656e030ccfe44634f2a541b4b1801cd52e390a53ad8bacbd65f8518"}, + {file = "safetensors-0.4.4-cp39-none-win32.whl", hash = "sha256:ef73df487b7c14b477016947c92708c2d929e1dee2bacdd6fff5a82ed4539537"}, + {file = "safetensors-0.4.4-cp39-none-win_amd64.whl", hash = "sha256:83d054818a8d1198d8bd8bc3ea2aac112a2c19def2bf73758321976788706398"}, + {file = "safetensors-0.4.4-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:1d1f34c71371f0e034004a0b583284b45d233dd0b5f64a9125e16b8a01d15067"}, + {file = "safetensors-0.4.4-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1a8043a33d58bc9b30dfac90f75712134ca34733ec3d8267b1bd682afe7194f5"}, + {file = "safetensors-0.4.4-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8db8f0c59c84792c12661f8efa85de160f80efe16b87a9d5de91b93f9e0bce3c"}, + {file = "safetensors-0.4.4-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cfc1fc38e37630dd12d519bdec9dcd4b345aec9930bb9ce0ed04461f49e58b52"}, + {file = "safetensors-0.4.4-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e5c9d86d9b13b18aafa88303e2cd21e677f5da2a14c828d2c460fe513af2e9a5"}, + {file = "safetensors-0.4.4-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:43251d7f29a59120a26f5a0d9583b9e112999e500afabcfdcb91606d3c5c89e3"}, + {file = "safetensors-0.4.4-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:2c42e9b277513b81cf507e6121c7b432b3235f980cac04f39f435b7902857f91"}, + {file = "safetensors-0.4.4-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3daacc9a4e3f428a84dd56bf31f20b768eb0b204af891ed68e1f06db9edf546f"}, + {file = "safetensors-0.4.4-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:218bbb9b883596715fc9997bb42470bf9f21bb832c3b34c2bf744d6fa8f2bbba"}, + {file = "safetensors-0.4.4-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7bd5efc26b39f7fc82d4ab1d86a7f0644c8e34f3699c33f85bfa9a717a030e1b"}, + {file = "safetensors-0.4.4-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:56ad9776b65d8743f86698a1973292c966cf3abff627efc44ed60e66cc538ddd"}, + {file = "safetensors-0.4.4-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:30f23e6253c5f43a809dea02dc28a9f5fa747735dc819f10c073fe1b605e97d4"}, + {file = "safetensors-0.4.4-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:5512078d00263de6cb04e9d26c9ae17611098f52357fea856213e38dc462f81f"}, + {file = "safetensors-0.4.4-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b96c3d9266439d17f35fc2173111d93afc1162f168e95aed122c1ca517b1f8f1"}, + {file = "safetensors-0.4.4-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:08d464aa72a9a13826946b4fb9094bb4b16554bbea2e069e20bd903289b6ced9"}, + {file = "safetensors-0.4.4-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:210160816d5a36cf41f48f38473b6f70d7bcb4b0527bedf0889cc0b4c3bb07db"}, + {file = "safetensors-0.4.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eb276a53717f2bcfb6df0bcf284d8a12069002508d4c1ca715799226024ccd45"}, + {file = "safetensors-0.4.4-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a2c28c6487f17d8db0089e8b2cdc13de859366b94cc6cdc50e1b0a4147b56551"}, + {file = "safetensors-0.4.4-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:7915f0c60e4e6e65d90f136d85dd3b429ae9191c36b380e626064694563dbd9f"}, + {file = "safetensors-0.4.4-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:00eea99ae422fbfa0b46065acbc58b46bfafadfcec179d4b4a32d5c45006af6c"}, + {file = "safetensors-0.4.4-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:bb1ed4fcb0b3c2f3ea2c5767434622fe5d660e5752f21ac2e8d737b1e5e480bb"}, + {file = "safetensors-0.4.4-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:73fc9a0a4343188bdb421783e600bfaf81d0793cd4cce6bafb3c2ed567a74cd5"}, + {file = "safetensors-0.4.4-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c37e6b714200824c73ca6eaf007382de76f39466a46e97558b8dc4cf643cfbf"}, + {file = "safetensors-0.4.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f75698c5c5c542417ac4956acfc420f7d4a2396adca63a015fd66641ea751759"}, + {file = "safetensors-0.4.4-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ca1a209157f242eb183e209040097118472e169f2e069bfbd40c303e24866543"}, + {file = "safetensors-0.4.4-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:177f2b60a058f92a3cec7a1786c9106c29eca8987ecdfb79ee88126e5f47fa31"}, + {file = "safetensors-0.4.4-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:ee9622e84fe6e4cd4f020e5fda70d6206feff3157731df7151d457fdae18e541"}, + {file = "safetensors-0.4.4.tar.gz", hash = "sha256:5fe3e9b705250d0172ed4e100a811543108653fb2b66b9e702a088ad03772a07"}, ] [package.extras] @@ -7529,36 +7994,44 @@ tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc ( [[package]] name = "scipy" -version = "1.14.0" +version = "1.14.1" description = "Fundamental algorithms for scientific computing in Python" optional = false python-versions = ">=3.10" files = [ - {file = "scipy-1.14.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7e911933d54ead4d557c02402710c2396529540b81dd554fc1ba270eb7308484"}, - {file = "scipy-1.14.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:687af0a35462402dd851726295c1a5ae5f987bd6e9026f52e9505994e2f84ef6"}, - {file = "scipy-1.14.0-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:07e179dc0205a50721022344fb85074f772eadbda1e1b3eecdc483f8033709b7"}, - {file = "scipy-1.14.0-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:6a9c9a9b226d9a21e0a208bdb024c3982932e43811b62d202aaf1bb59af264b1"}, - {file = "scipy-1.14.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:076c27284c768b84a45dcf2e914d4000aac537da74236a0d45d82c6fa4b7b3c0"}, - {file = "scipy-1.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42470ea0195336df319741e230626b6225a740fd9dce9642ca13e98f667047c0"}, - {file = "scipy-1.14.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:176c6f0d0470a32f1b2efaf40c3d37a24876cebf447498a4cefb947a79c21e9d"}, - {file = "scipy-1.14.0-cp310-cp310-win_amd64.whl", hash = "sha256:ad36af9626d27a4326c8e884917b7ec321d8a1841cd6dacc67d2a9e90c2f0359"}, - {file = "scipy-1.14.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6d056a8709ccda6cf36cdd2eac597d13bc03dba38360f418560a93050c76a16e"}, - {file = "scipy-1.14.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:f0a50da861a7ec4573b7c716b2ebdcdf142b66b756a0d392c236ae568b3a93fb"}, - {file = "scipy-1.14.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:94c164a9e2498e68308e6e148646e486d979f7fcdb8b4cf34b5441894bdb9caf"}, - {file = "scipy-1.14.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:a7d46c3e0aea5c064e734c3eac5cf9eb1f8c4ceee756262f2c7327c4c2691c86"}, - {file = "scipy-1.14.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9eee2989868e274aae26125345584254d97c56194c072ed96cb433f32f692ed8"}, - {file = "scipy-1.14.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e3154691b9f7ed73778d746da2df67a19d046a6c8087c8b385bc4cdb2cfca74"}, - {file = "scipy-1.14.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c40003d880f39c11c1edbae8144e3813904b10514cd3d3d00c277ae996488cdb"}, - {file = "scipy-1.14.0-cp311-cp311-win_amd64.whl", hash = "sha256:5b083c8940028bb7e0b4172acafda6df762da1927b9091f9611b0bcd8676f2bc"}, - {file = "scipy-1.14.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:bff2438ea1330e06e53c424893ec0072640dac00f29c6a43a575cbae4c99b2b9"}, - {file = "scipy-1.14.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:bbc0471b5f22c11c389075d091d3885693fd3f5e9a54ce051b46308bc787e5d4"}, - {file = "scipy-1.14.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:64b2ff514a98cf2bb734a9f90d32dc89dc6ad4a4a36a312cd0d6327170339eb0"}, - {file = "scipy-1.14.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:7d3da42fbbbb860211a811782504f38ae7aaec9de8764a9bef6b262de7a2b50f"}, - {file = "scipy-1.14.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d91db2c41dd6c20646af280355d41dfa1ec7eead235642178bd57635a3f82209"}, - {file = "scipy-1.14.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a01cc03bcdc777c9da3cfdcc74b5a75caffb48a6c39c8450a9a05f82c4250a14"}, - {file = "scipy-1.14.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:65df4da3c12a2bb9ad52b86b4dcf46813e869afb006e58be0f516bc370165159"}, - {file = "scipy-1.14.0-cp312-cp312-win_amd64.whl", hash = "sha256:4c4161597c75043f7154238ef419c29a64ac4a7c889d588ea77690ac4d0d9b20"}, - {file = "scipy-1.14.0.tar.gz", hash = "sha256:b5923f48cb840380f9854339176ef21763118a7300a88203ccd0bdd26e58527b"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3"}, + {file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d"}, + {file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69"}, + {file = "scipy-1.14.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad"}, + {file = "scipy-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8"}, + {file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37"}, + {file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2"}, + {file = "scipy-1.14.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2"}, + {file = "scipy-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc"}, + {file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310"}, + {file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066"}, + {file = "scipy-1.14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1"}, + {file = "scipy-1.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e"}, + {file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d"}, + {file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e"}, + {file = "scipy-1.14.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06"}, + {file = "scipy-1.14.1-cp313-cp313-win_amd64.whl", hash = "sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84"}, + {file = "scipy-1.14.1.tar.gz", hash = "sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417"}, ] [package.dependencies] @@ -7566,8 +8039,8 @@ numpy = ">=1.23.5,<2.3" [package.extras] dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodestyle", "pydevtool", "rich-click", "ruff (>=0.0.292)", "types-psutil", "typing_extensions"] -doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"] -test = ["Cython", "array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] +doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<=7.3.7)", "sphinx-design (>=0.4.0)"] +test = ["Cython", "array-api-strict (>=2.0)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] [[package]] name = "sentry-sdk" @@ -7621,19 +8094,19 @@ tornado = ["tornado (>=5)"] [[package]] name = "setuptools" -version = "71.1.0" +version = "73.0.1" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-71.1.0-py3-none-any.whl", hash = "sha256:33874fdc59b3188304b2e7c80d9029097ea31627180896fb549c578ceb8a0855"}, - {file = "setuptools-71.1.0.tar.gz", hash = "sha256:032d42ee9fb536e33087fb66cac5f840eb9391ed05637b3f2a76a7c8fb477936"}, + {file = "setuptools-73.0.1-py3-none-any.whl", hash = "sha256:b208925fcb9f7af924ed2dc04708ea89791e24bde0d3020b27df0e116088b34e"}, + {file = "setuptools-73.0.1.tar.gz", hash = "sha256:d59a3e788ab7e012ab2c4baed1b376da6366883ee20d7a5fc426816e3d7b1193"}, ] [package.extras] -core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.text (>=3.7)", "more-itertools (>=8.8)", "ordered-set (>=3.1.1)", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] -doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "mypy (==1.11.*)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (<0.4)", "pytest-ruff (>=0.2.1)", "pytest-ruff (>=0.3.2)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.text (>=3.7)", "more-itertools (>=8.8)", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "mypy (==1.11.*)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (<0.4)", "pytest-ruff (>=0.2.1)", "pytest-ruff (>=0.3.2)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] [[package]] name = "sgmllib3k" @@ -7647,47 +8120,53 @@ files = [ [[package]] name = "shapely" -version = "2.0.5" +version = "2.0.6" description = "Manipulation and analysis of geometric objects" optional = false python-versions = ">=3.7" files = [ - {file = "shapely-2.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:89d34787c44f77a7d37d55ae821f3a784fa33592b9d217a45053a93ade899375"}, - {file = "shapely-2.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:798090b426142df2c5258779c1d8d5734ec6942f778dab6c6c30cfe7f3bf64ff"}, - {file = "shapely-2.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45211276900c4790d6bfc6105cbf1030742da67594ea4161a9ce6812a6721e68"}, - {file = "shapely-2.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e119444bc27ca33e786772b81760f2028d930ac55dafe9bc50ef538b794a8e1"}, - {file = "shapely-2.0.5-cp310-cp310-win32.whl", hash = "sha256:9a4492a2b2ccbeaebf181e7310d2dfff4fdd505aef59d6cb0f217607cb042fb3"}, - {file = "shapely-2.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:1e5cb5ee72f1bc7ace737c9ecd30dc174a5295fae412972d3879bac2e82c8fae"}, - {file = "shapely-2.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5bbfb048a74cf273db9091ff3155d373020852805a37dfc846ab71dde4be93ec"}, - {file = "shapely-2.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:93be600cbe2fbaa86c8eb70656369f2f7104cd231f0d6585c7d0aa555d6878b8"}, - {file = "shapely-2.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f8e71bb9a46814019f6644c4e2560a09d44b80100e46e371578f35eaaa9da1c"}, - {file = "shapely-2.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5251c28a29012e92de01d2e84f11637eb1d48184ee8f22e2df6c8c578d26760"}, - {file = "shapely-2.0.5-cp311-cp311-win32.whl", hash = "sha256:35110e80070d664781ec7955c7de557456b25727a0257b354830abb759bf8311"}, - {file = "shapely-2.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:6c6b78c0007a34ce7144f98b7418800e0a6a5d9a762f2244b00ea560525290c9"}, - {file = "shapely-2.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:03bd7b5fa5deb44795cc0a503999d10ae9d8a22df54ae8d4a4cd2e8a93466195"}, - {file = "shapely-2.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ff9521991ed9e201c2e923da014e766c1aa04771bc93e6fe97c27dcf0d40ace"}, - {file = "shapely-2.0.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b65365cfbf657604e50d15161ffcc68de5cdb22a601bbf7823540ab4918a98d"}, - {file = "shapely-2.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21f64e647a025b61b19585d2247137b3a38a35314ea68c66aaf507a1c03ef6fe"}, - {file = "shapely-2.0.5-cp312-cp312-win32.whl", hash = "sha256:3ac7dc1350700c139c956b03d9c3df49a5b34aaf91d024d1510a09717ea39199"}, - {file = "shapely-2.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:30e8737983c9d954cd17feb49eb169f02f1da49e24e5171122cf2c2b62d65c95"}, - {file = "shapely-2.0.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:ff7731fea5face9ec08a861ed351734a79475631b7540ceb0b66fb9732a5f529"}, - {file = "shapely-2.0.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ff9e520af0c5a578e174bca3c18713cd47a6c6a15b6cf1f50ac17dc8bb8db6a2"}, - {file = "shapely-2.0.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49b299b91557b04acb75e9732645428470825061f871a2edc36b9417d66c1fc5"}, - {file = "shapely-2.0.5-cp37-cp37m-win32.whl", hash = "sha256:b5870633f8e684bf6d1ae4df527ddcb6f3895f7b12bced5c13266ac04f47d231"}, - {file = "shapely-2.0.5-cp37-cp37m-win_amd64.whl", hash = "sha256:401cb794c5067598f50518e5a997e270cd7642c4992645479b915c503866abed"}, - {file = "shapely-2.0.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e91ee179af539100eb520281ba5394919067c6b51824e6ab132ad4b3b3e76dd0"}, - {file = "shapely-2.0.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8af6f7260f809c0862741ad08b1b89cb60c130ae30efab62320bbf4ee9cc71fa"}, - {file = "shapely-2.0.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f5456dd522800306ba3faef77c5ba847ec30a0bd73ab087a25e0acdd4db2514f"}, - {file = "shapely-2.0.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b714a840402cde66fd7b663bb08cacb7211fa4412ea2a209688f671e0d0631fd"}, - {file = "shapely-2.0.5-cp38-cp38-win32.whl", hash = "sha256:7e8cf5c252fac1ea51b3162be2ec3faddedc82c256a1160fc0e8ddbec81b06d2"}, - {file = "shapely-2.0.5-cp38-cp38-win_amd64.whl", hash = "sha256:4461509afdb15051e73ab178fae79974387f39c47ab635a7330d7fee02c68a3f"}, - {file = "shapely-2.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7545a39c55cad1562be302d74c74586f79e07b592df8ada56b79a209731c0219"}, - {file = "shapely-2.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4c83a36f12ec8dee2066946d98d4d841ab6512a6ed7eb742e026a64854019b5f"}, - {file = "shapely-2.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89e640c2cd37378480caf2eeda9a51be64201f01f786d127e78eaeff091ec897"}, - {file = "shapely-2.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06efe39beafde3a18a21dde169d32f315c57da962826a6d7d22630025200c5e6"}, - {file = "shapely-2.0.5-cp39-cp39-win32.whl", hash = "sha256:8203a8b2d44dcb366becbc8c3d553670320e4acf0616c39e218c9561dd738d92"}, - {file = "shapely-2.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:7fed9dbfbcfec2682d9a047b9699db8dcc890dfca857ecba872c42185fc9e64e"}, - {file = "shapely-2.0.5.tar.gz", hash = "sha256:bff2366bc786bfa6cb353d6b47d0443c570c32776612e527ee47b6df63fcfe32"}, + {file = "shapely-2.0.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:29a34e068da2d321e926b5073539fd2a1d4429a2c656bd63f0bd4c8f5b236d0b"}, + {file = "shapely-2.0.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e1c84c3f53144febf6af909d6b581bc05e8785d57e27f35ebaa5c1ab9baba13b"}, + {file = "shapely-2.0.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ad2fae12dca8d2b727fa12b007e46fbc522148a584f5d6546c539f3464dccde"}, + {file = "shapely-2.0.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3304883bd82d44be1b27a9d17f1167fda8c7f5a02a897958d86c59ec69b705e"}, + {file = "shapely-2.0.6-cp310-cp310-win32.whl", hash = "sha256:3ec3a0eab496b5e04633a39fa3d5eb5454628228201fb24903d38174ee34565e"}, + {file = "shapely-2.0.6-cp310-cp310-win_amd64.whl", hash = "sha256:28f87cdf5308a514763a5c38de295544cb27429cfa655d50ed8431a4796090c4"}, + {file = "shapely-2.0.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5aeb0f51a9db176da9a30cb2f4329b6fbd1e26d359012bb0ac3d3c7781667a9e"}, + {file = "shapely-2.0.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9a7a78b0d51257a367ee115f4d41ca4d46edbd0dd280f697a8092dd3989867b2"}, + {file = "shapely-2.0.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f32c23d2f43d54029f986479f7c1f6e09c6b3a19353a3833c2ffb226fb63a855"}, + {file = "shapely-2.0.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3dc9fb0eb56498912025f5eb352b5126f04801ed0e8bdbd867d21bdbfd7cbd0"}, + {file = "shapely-2.0.6-cp311-cp311-win32.whl", hash = "sha256:d93b7e0e71c9f095e09454bf18dad5ea716fb6ced5df3cb044564a00723f339d"}, + {file = "shapely-2.0.6-cp311-cp311-win_amd64.whl", hash = "sha256:c02eb6bf4cfb9fe6568502e85bb2647921ee49171bcd2d4116c7b3109724ef9b"}, + {file = "shapely-2.0.6-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cec9193519940e9d1b86a3b4f5af9eb6910197d24af02f247afbfb47bcb3fab0"}, + {file = "shapely-2.0.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:83b94a44ab04a90e88be69e7ddcc6f332da7c0a0ebb1156e1c4f568bbec983c3"}, + {file = "shapely-2.0.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:537c4b2716d22c92036d00b34aac9d3775e3691f80c7aa517c2c290351f42cd8"}, + {file = "shapely-2.0.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98fea108334be345c283ce74bf064fa00cfdd718048a8af7343c59eb40f59726"}, + {file = "shapely-2.0.6-cp312-cp312-win32.whl", hash = "sha256:42fd4cd4834747e4990227e4cbafb02242c0cffe9ce7ef9971f53ac52d80d55f"}, + {file = "shapely-2.0.6-cp312-cp312-win_amd64.whl", hash = "sha256:665990c84aece05efb68a21b3523a6b2057e84a1afbef426ad287f0796ef8a48"}, + {file = "shapely-2.0.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:42805ef90783ce689a4dde2b6b2f261e2c52609226a0438d882e3ced40bb3013"}, + {file = "shapely-2.0.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6d2cb146191a47bd0cee8ff5f90b47547b82b6345c0d02dd8b25b88b68af62d7"}, + {file = "shapely-2.0.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e3fdef0a1794a8fe70dc1f514440aa34426cc0ae98d9a1027fb299d45741c381"}, + {file = "shapely-2.0.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c665a0301c645615a107ff7f52adafa2153beab51daf34587170d85e8ba6805"}, + {file = "shapely-2.0.6-cp313-cp313-win32.whl", hash = "sha256:0334bd51828f68cd54b87d80b3e7cee93f249d82ae55a0faf3ea21c9be7b323a"}, + {file = "shapely-2.0.6-cp313-cp313-win_amd64.whl", hash = "sha256:d37d070da9e0e0f0a530a621e17c0b8c3c9d04105655132a87cfff8bd77cc4c2"}, + {file = "shapely-2.0.6-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:fa7468e4f5b92049c0f36d63c3e309f85f2775752e076378e36c6387245c5462"}, + {file = "shapely-2.0.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ed5867e598a9e8ac3291da6cc9baa62ca25706eea186117034e8ec0ea4355653"}, + {file = "shapely-2.0.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81d9dfe155f371f78c8d895a7b7f323bb241fb148d848a2bf2244f79213123fe"}, + {file = "shapely-2.0.6-cp37-cp37m-win32.whl", hash = "sha256:fbb7bf02a7542dba55129062570211cfb0defa05386409b3e306c39612e7fbcc"}, + {file = "shapely-2.0.6-cp37-cp37m-win_amd64.whl", hash = "sha256:837d395fac58aa01aa544495b97940995211e3e25f9aaf87bc3ba5b3a8cd1ac7"}, + {file = "shapely-2.0.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c6d88ade96bf02f6bfd667ddd3626913098e243e419a0325ebef2bbd481d1eb6"}, + {file = "shapely-2.0.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8b3b818c4407eaa0b4cb376fd2305e20ff6df757bf1356651589eadc14aab41b"}, + {file = "shapely-2.0.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1bbc783529a21f2bd50c79cef90761f72d41c45622b3e57acf78d984c50a5d13"}, + {file = "shapely-2.0.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2423f6c0903ebe5df6d32e0066b3d94029aab18425ad4b07bf98c3972a6e25a1"}, + {file = "shapely-2.0.6-cp38-cp38-win32.whl", hash = "sha256:2de00c3bfa80d6750832bde1d9487e302a6dd21d90cb2f210515cefdb616e5f5"}, + {file = "shapely-2.0.6-cp38-cp38-win_amd64.whl", hash = "sha256:3a82d58a1134d5e975f19268710e53bddd9c473743356c90d97ce04b73e101ee"}, + {file = "shapely-2.0.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:392f66f458a0a2c706254f473290418236e52aa4c9b476a072539d63a2460595"}, + {file = "shapely-2.0.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:eba5bae271d523c938274c61658ebc34de6c4b33fdf43ef7e938b5776388c1be"}, + {file = "shapely-2.0.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7060566bc4888b0c8ed14b5d57df8a0ead5c28f9b69fb6bed4476df31c51b0af"}, + {file = "shapely-2.0.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b02154b3e9d076a29a8513dffcb80f047a5ea63c897c0cd3d3679f29363cf7e5"}, + {file = "shapely-2.0.6-cp39-cp39-win32.whl", hash = "sha256:44246d30124a4f1a638a7d5419149959532b99dfa25b54393512e6acc9c211ac"}, + {file = "shapely-2.0.6-cp39-cp39-win_amd64.whl", hash = "sha256:2b542d7f1dbb89192d3512c52b679c822ba916f93479fa5d4fc2fe4fa0b3c9e8"}, + {file = "shapely-2.0.6.tar.gz", hash = "sha256:997f6159b1484059ec239cacaa53467fd8b5564dabe186cd84ac2944663b0bf6"}, ] [package.dependencies] @@ -7760,71 +8239,71 @@ files = [ [[package]] name = "soupsieve" -version = "2.5" +version = "2.6" description = "A modern CSS selector implementation for Beautiful Soup." optional = false python-versions = ">=3.8" files = [ - {file = "soupsieve-2.5-py3-none-any.whl", hash = "sha256:eaa337ff55a1579b6549dc679565eac1e3d000563bcb1c8ab0d0fefbc0c2cdc7"}, - {file = "soupsieve-2.5.tar.gz", hash = "sha256:5663d5a7b3bfaeee0bc4372e7fc48f9cff4940b3eec54a6451cc5299f1097690"}, + {file = "soupsieve-2.6-py3-none-any.whl", hash = "sha256:e72c4ff06e4fb6e4b5a9f0f55fe6e81514581fca1515028625d0f299c602ccc9"}, + {file = "soupsieve-2.6.tar.gz", hash = "sha256:e2e68417777af359ec65daac1057404a3c8a5455bb8abc36f1a9866ab1a51abb"}, ] [[package]] name = "sqlalchemy" -version = "2.0.31" +version = "2.0.32" description = "Database Abstraction Library" optional = false python-versions = ">=3.7" files = [ - {file = "SQLAlchemy-2.0.31-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f2a213c1b699d3f5768a7272de720387ae0122f1becf0901ed6eaa1abd1baf6c"}, - {file = "SQLAlchemy-2.0.31-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9fea3d0884e82d1e33226935dac990b967bef21315cbcc894605db3441347443"}, - {file = "SQLAlchemy-2.0.31-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3ad7f221d8a69d32d197e5968d798217a4feebe30144986af71ada8c548e9fa"}, - {file = "SQLAlchemy-2.0.31-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f2bee229715b6366f86a95d497c347c22ddffa2c7c96143b59a2aa5cc9eebbc"}, - {file = "SQLAlchemy-2.0.31-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:cd5b94d4819c0c89280b7c6109c7b788a576084bf0a480ae17c227b0bc41e109"}, - {file = "SQLAlchemy-2.0.31-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:750900a471d39a7eeba57580b11983030517a1f512c2cb287d5ad0fcf3aebd58"}, - {file = "SQLAlchemy-2.0.31-cp310-cp310-win32.whl", hash = "sha256:7bd112be780928c7f493c1a192cd8c5fc2a2a7b52b790bc5a84203fb4381c6be"}, - {file = "SQLAlchemy-2.0.31-cp310-cp310-win_amd64.whl", hash = "sha256:5a48ac4d359f058474fadc2115f78a5cdac9988d4f99eae44917f36aa1476327"}, - {file = "SQLAlchemy-2.0.31-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f68470edd70c3ac3b6cd5c2a22a8daf18415203ca1b036aaeb9b0fb6f54e8298"}, - {file = "SQLAlchemy-2.0.31-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e2c38c2a4c5c634fe6c3c58a789712719fa1bf9b9d6ff5ebfce9a9e5b89c1ca"}, - {file = "SQLAlchemy-2.0.31-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd15026f77420eb2b324dcb93551ad9c5f22fab2c150c286ef1dc1160f110203"}, - {file = "SQLAlchemy-2.0.31-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2196208432deebdfe3b22185d46b08f00ac9d7b01284e168c212919891289396"}, - {file = "SQLAlchemy-2.0.31-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:352b2770097f41bff6029b280c0e03b217c2dcaddc40726f8f53ed58d8a85da4"}, - {file = "SQLAlchemy-2.0.31-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:56d51ae825d20d604583f82c9527d285e9e6d14f9a5516463d9705dab20c3740"}, - {file = "SQLAlchemy-2.0.31-cp311-cp311-win32.whl", hash = "sha256:6e2622844551945db81c26a02f27d94145b561f9d4b0c39ce7bfd2fda5776dac"}, - {file = "SQLAlchemy-2.0.31-cp311-cp311-win_amd64.whl", hash = "sha256:ccaf1b0c90435b6e430f5dd30a5aede4764942a695552eb3a4ab74ed63c5b8d3"}, - {file = "SQLAlchemy-2.0.31-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3b74570d99126992d4b0f91fb87c586a574a5872651185de8297c6f90055ae42"}, - {file = "SQLAlchemy-2.0.31-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6f77c4f042ad493cb8595e2f503c7a4fe44cd7bd59c7582fd6d78d7e7b8ec52c"}, - {file = "SQLAlchemy-2.0.31-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cd1591329333daf94467e699e11015d9c944f44c94d2091f4ac493ced0119449"}, - {file = "SQLAlchemy-2.0.31-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:74afabeeff415e35525bf7a4ecdab015f00e06456166a2eba7590e49f8db940e"}, - {file = "SQLAlchemy-2.0.31-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b9c01990d9015df2c6f818aa8f4297d42ee71c9502026bb074e713d496e26b67"}, - {file = "SQLAlchemy-2.0.31-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:66f63278db425838b3c2b1c596654b31939427016ba030e951b292e32b99553e"}, - {file = "SQLAlchemy-2.0.31-cp312-cp312-win32.whl", hash = "sha256:0b0f658414ee4e4b8cbcd4a9bb0fd743c5eeb81fc858ca517217a8013d282c96"}, - {file = "SQLAlchemy-2.0.31-cp312-cp312-win_amd64.whl", hash = "sha256:fa4b1af3e619b5b0b435e333f3967612db06351217c58bfb50cee5f003db2a5a"}, - {file = "SQLAlchemy-2.0.31-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:f43e93057cf52a227eda401251c72b6fbe4756f35fa6bfebb5d73b86881e59b0"}, - {file = "SQLAlchemy-2.0.31-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d337bf94052856d1b330d5fcad44582a30c532a2463776e1651bd3294ee7e58b"}, - {file = "SQLAlchemy-2.0.31-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c06fb43a51ccdff3b4006aafee9fcf15f63f23c580675f7734245ceb6b6a9e05"}, - {file = "SQLAlchemy-2.0.31-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:b6e22630e89f0e8c12332b2b4c282cb01cf4da0d26795b7eae16702a608e7ca1"}, - {file = "SQLAlchemy-2.0.31-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:79a40771363c5e9f3a77f0e28b3302801db08040928146e6808b5b7a40749c88"}, - {file = "SQLAlchemy-2.0.31-cp37-cp37m-win32.whl", hash = "sha256:501ff052229cb79dd4c49c402f6cb03b5a40ae4771efc8bb2bfac9f6c3d3508f"}, - {file = "SQLAlchemy-2.0.31-cp37-cp37m-win_amd64.whl", hash = "sha256:597fec37c382a5442ffd471f66ce12d07d91b281fd474289356b1a0041bdf31d"}, - {file = "SQLAlchemy-2.0.31-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:dc6d69f8829712a4fd799d2ac8d79bdeff651c2301b081fd5d3fe697bd5b4ab9"}, - {file = "SQLAlchemy-2.0.31-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:23b9fbb2f5dd9e630db70fbe47d963c7779e9c81830869bd7d137c2dc1ad05fb"}, - {file = "SQLAlchemy-2.0.31-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a21c97efcbb9f255d5c12a96ae14da873233597dfd00a3a0c4ce5b3e5e79704"}, - {file = "SQLAlchemy-2.0.31-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26a6a9837589c42b16693cf7bf836f5d42218f44d198f9343dd71d3164ceeeac"}, - {file = "SQLAlchemy-2.0.31-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:dc251477eae03c20fae8db9c1c23ea2ebc47331bcd73927cdcaecd02af98d3c3"}, - {file = "SQLAlchemy-2.0.31-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:2fd17e3bb8058359fa61248c52c7b09a97cf3c820e54207a50af529876451808"}, - {file = "SQLAlchemy-2.0.31-cp38-cp38-win32.whl", hash = "sha256:c76c81c52e1e08f12f4b6a07af2b96b9b15ea67ccdd40ae17019f1c373faa227"}, - {file = "SQLAlchemy-2.0.31-cp38-cp38-win_amd64.whl", hash = "sha256:4b600e9a212ed59355813becbcf282cfda5c93678e15c25a0ef896b354423238"}, - {file = "SQLAlchemy-2.0.31-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5b6cf796d9fcc9b37011d3f9936189b3c8074a02a4ed0c0fbbc126772c31a6d4"}, - {file = "SQLAlchemy-2.0.31-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:78fe11dbe37d92667c2c6e74379f75746dc947ee505555a0197cfba9a6d4f1a4"}, - {file = "SQLAlchemy-2.0.31-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2fc47dc6185a83c8100b37acda27658fe4dbd33b7d5e7324111f6521008ab4fe"}, - {file = "SQLAlchemy-2.0.31-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a41514c1a779e2aa9a19f67aaadeb5cbddf0b2b508843fcd7bafdf4c6864005"}, - {file = "SQLAlchemy-2.0.31-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:afb6dde6c11ea4525318e279cd93c8734b795ac8bb5dda0eedd9ebaca7fa23f1"}, - {file = "SQLAlchemy-2.0.31-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:3f9faef422cfbb8fd53716cd14ba95e2ef655400235c3dfad1b5f467ba179c8c"}, - {file = "SQLAlchemy-2.0.31-cp39-cp39-win32.whl", hash = "sha256:fc6b14e8602f59c6ba893980bea96571dd0ed83d8ebb9c4479d9ed5425d562e9"}, - {file = "SQLAlchemy-2.0.31-cp39-cp39-win_amd64.whl", hash = "sha256:3cb8a66b167b033ec72c3812ffc8441d4e9f5f78f5e31e54dcd4c90a4ca5bebc"}, - {file = "SQLAlchemy-2.0.31-py3-none-any.whl", hash = "sha256:69f3e3c08867a8e4856e92d7afb618b95cdee18e0bc1647b77599722c9a28911"}, - {file = "SQLAlchemy-2.0.31.tar.gz", hash = "sha256:b607489dd4a54de56984a0c7656247504bd5523d9d0ba799aef59d4add009484"}, + {file = "SQLAlchemy-2.0.32-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0c9045ecc2e4db59bfc97b20516dfdf8e41d910ac6fb667ebd3a79ea54084619"}, + {file = "SQLAlchemy-2.0.32-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1467940318e4a860afd546ef61fefb98a14d935cd6817ed07a228c7f7c62f389"}, + {file = "SQLAlchemy-2.0.32-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5954463675cb15db8d4b521f3566a017c8789222b8316b1e6934c811018ee08b"}, + {file = "SQLAlchemy-2.0.32-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:167e7497035c303ae50651b351c28dc22a40bb98fbdb8468cdc971821b1ae533"}, + {file = "SQLAlchemy-2.0.32-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:b27dfb676ac02529fb6e343b3a482303f16e6bc3a4d868b73935b8792edb52d0"}, + {file = "SQLAlchemy-2.0.32-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:bf2360a5e0f7bd75fa80431bf8ebcfb920c9f885e7956c7efde89031695cafb8"}, + {file = "SQLAlchemy-2.0.32-cp310-cp310-win32.whl", hash = "sha256:306fe44e754a91cd9d600a6b070c1f2fadbb4a1a257b8781ccf33c7067fd3e4d"}, + {file = "SQLAlchemy-2.0.32-cp310-cp310-win_amd64.whl", hash = "sha256:99db65e6f3ab42e06c318f15c98f59a436f1c78179e6a6f40f529c8cc7100b22"}, + {file = "SQLAlchemy-2.0.32-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:21b053be28a8a414f2ddd401f1be8361e41032d2ef5884b2f31d31cb723e559f"}, + {file = "SQLAlchemy-2.0.32-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b178e875a7a25b5938b53b006598ee7645172fccafe1c291a706e93f48499ff5"}, + {file = "SQLAlchemy-2.0.32-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:723a40ee2cc7ea653645bd4cf024326dea2076673fc9d3d33f20f6c81db83e1d"}, + {file = "SQLAlchemy-2.0.32-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:295ff8689544f7ee7e819529633d058bd458c1fd7f7e3eebd0f9268ebc56c2a0"}, + {file = "SQLAlchemy-2.0.32-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:49496b68cd190a147118af585173ee624114dfb2e0297558c460ad7495f9dfe2"}, + {file = "SQLAlchemy-2.0.32-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:acd9b73c5c15f0ec5ce18128b1fe9157ddd0044abc373e6ecd5ba376a7e5d961"}, + {file = "SQLAlchemy-2.0.32-cp311-cp311-win32.whl", hash = "sha256:9365a3da32dabd3e69e06b972b1ffb0c89668994c7e8e75ce21d3e5e69ddef28"}, + {file = "SQLAlchemy-2.0.32-cp311-cp311-win_amd64.whl", hash = "sha256:8bd63d051f4f313b102a2af1cbc8b80f061bf78f3d5bd0843ff70b5859e27924"}, + {file = "SQLAlchemy-2.0.32-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6bab3db192a0c35e3c9d1560eb8332463e29e5507dbd822e29a0a3c48c0a8d92"}, + {file = "SQLAlchemy-2.0.32-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:19d98f4f58b13900d8dec4ed09dd09ef292208ee44cc9c2fe01c1f0a2fe440e9"}, + {file = "SQLAlchemy-2.0.32-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3cd33c61513cb1b7371fd40cf221256456d26a56284e7d19d1f0b9f1eb7dd7e8"}, + {file = "SQLAlchemy-2.0.32-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d6ba0497c1d066dd004e0f02a92426ca2df20fac08728d03f67f6960271feec"}, + {file = "SQLAlchemy-2.0.32-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2b6be53e4fde0065524f1a0a7929b10e9280987b320716c1509478b712a7688c"}, + {file = "SQLAlchemy-2.0.32-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:916a798f62f410c0b80b63683c8061f5ebe237b0f4ad778739304253353bc1cb"}, + {file = "SQLAlchemy-2.0.32-cp312-cp312-win32.whl", hash = "sha256:31983018b74908ebc6c996a16ad3690301a23befb643093fcfe85efd292e384d"}, + {file = "SQLAlchemy-2.0.32-cp312-cp312-win_amd64.whl", hash = "sha256:4363ed245a6231f2e2957cccdda3c776265a75851f4753c60f3004b90e69bfeb"}, + {file = "SQLAlchemy-2.0.32-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b8afd5b26570bf41c35c0121801479958b4446751a3971fb9a480c1afd85558e"}, + {file = "SQLAlchemy-2.0.32-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c750987fc876813f27b60d619b987b057eb4896b81117f73bb8d9918c14f1cad"}, + {file = "SQLAlchemy-2.0.32-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ada0102afff4890f651ed91120c1120065663506b760da4e7823913ebd3258be"}, + {file = "SQLAlchemy-2.0.32-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:78c03d0f8a5ab4f3034c0e8482cfcc415a3ec6193491cfa1c643ed707d476f16"}, + {file = "SQLAlchemy-2.0.32-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:3bd1cae7519283ff525e64645ebd7a3e0283f3c038f461ecc1c7b040a0c932a1"}, + {file = "SQLAlchemy-2.0.32-cp37-cp37m-win32.whl", hash = "sha256:01438ebcdc566d58c93af0171c74ec28efe6a29184b773e378a385e6215389da"}, + {file = "SQLAlchemy-2.0.32-cp37-cp37m-win_amd64.whl", hash = "sha256:4979dc80fbbc9d2ef569e71e0896990bc94df2b9fdbd878290bd129b65ab579c"}, + {file = "SQLAlchemy-2.0.32-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6c742be912f57586ac43af38b3848f7688863a403dfb220193a882ea60e1ec3a"}, + {file = "SQLAlchemy-2.0.32-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:62e23d0ac103bcf1c5555b6c88c114089587bc64d048fef5bbdb58dfd26f96da"}, + {file = "SQLAlchemy-2.0.32-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:251f0d1108aab8ea7b9aadbd07fb47fb8e3a5838dde34aa95a3349876b5a1f1d"}, + {file = "SQLAlchemy-2.0.32-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ef18a84e5116340e38eca3e7f9eeaaef62738891422e7c2a0b80feab165905f"}, + {file = "SQLAlchemy-2.0.32-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:3eb6a97a1d39976f360b10ff208c73afb6a4de86dd2a6212ddf65c4a6a2347d5"}, + {file = "SQLAlchemy-2.0.32-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0c1c9b673d21477cec17ab10bc4decb1322843ba35b481585facd88203754fc5"}, + {file = "SQLAlchemy-2.0.32-cp38-cp38-win32.whl", hash = "sha256:c41a2b9ca80ee555decc605bd3c4520cc6fef9abde8fd66b1cf65126a6922d65"}, + {file = "SQLAlchemy-2.0.32-cp38-cp38-win_amd64.whl", hash = "sha256:8a37e4d265033c897892279e8adf505c8b6b4075f2b40d77afb31f7185cd6ecd"}, + {file = "SQLAlchemy-2.0.32-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:52fec964fba2ef46476312a03ec8c425956b05c20220a1a03703537824b5e8e1"}, + {file = "SQLAlchemy-2.0.32-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:328429aecaba2aee3d71e11f2477c14eec5990fb6d0e884107935f7fb6001632"}, + {file = "SQLAlchemy-2.0.32-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85a01b5599e790e76ac3fe3aa2f26e1feba56270023d6afd5550ed63c68552b3"}, + {file = "SQLAlchemy-2.0.32-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aaf04784797dcdf4c0aa952c8d234fa01974c4729db55c45732520ce12dd95b4"}, + {file = "SQLAlchemy-2.0.32-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:4488120becf9b71b3ac718f4138269a6be99a42fe023ec457896ba4f80749525"}, + {file = "SQLAlchemy-2.0.32-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:14e09e083a5796d513918a66f3d6aedbc131e39e80875afe81d98a03312889e6"}, + {file = "SQLAlchemy-2.0.32-cp39-cp39-win32.whl", hash = "sha256:0d322cc9c9b2154ba7e82f7bf25ecc7c36fbe2d82e2933b3642fc095a52cfc78"}, + {file = "SQLAlchemy-2.0.32-cp39-cp39-win_amd64.whl", hash = "sha256:7dd8583df2f98dea28b5cd53a1beac963f4f9d087888d75f22fcc93a07cf8d84"}, + {file = "SQLAlchemy-2.0.32-py3-none-any.whl", hash = "sha256:e567a8793a692451f706b363ccf3c45e056b67d90ead58c3bc9471af5d212202"}, + {file = "SQLAlchemy-2.0.32.tar.gz", hash = "sha256:c1b88cc8b02b6a5f0efb0345a03672d4c897dc7d92585176f88c67346f565ea8"}, ] [package.dependencies] @@ -7873,13 +8352,13 @@ doc = ["sphinx"] [[package]] name = "starlette" -version = "0.37.2" +version = "0.38.2" description = "The little ASGI library that shines." optional = false python-versions = ">=3.8" files = [ - {file = "starlette-0.37.2-py3-none-any.whl", hash = "sha256:6fe59f29268538e5d0d182f2791a479a0c64638e6935d1c6989e63fb2699c6ee"}, - {file = "starlette-0.37.2.tar.gz", hash = "sha256:9af890290133b79fc3db55474ade20f6220a364a0402e0b556e7cd5e1e093823"}, + {file = "starlette-0.38.2-py3-none-any.whl", hash = "sha256:4ec6a59df6bbafdab5f567754481657f7ed90dc9d69b0c9ff017907dd54faeff"}, + {file = "starlette-0.38.2.tar.gz", hash = "sha256:c7c0441065252160993a1a37cf2a73bb64d271b17303e0b0c1eb7191cfb12d75"}, ] [package.dependencies] @@ -7888,19 +8367,36 @@ anyio = ">=3.4.0,<5" [package.extras] full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"] +[[package]] +name = "strictyaml" +version = "1.7.3" +description = "Strict, typed YAML parser" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "strictyaml-1.7.3-py3-none-any.whl", hash = "sha256:fb5c8a4edb43bebb765959e420f9b3978d7f1af88c80606c03fb420888f5d1c7"}, + {file = "strictyaml-1.7.3.tar.gz", hash = "sha256:22f854a5fcab42b5ddba8030a0e4be51ca89af0267961c8d6cfa86395586c407"}, +] + +[package.dependencies] +python-dateutil = ">=2.6.0" + [[package]] name = "sympy" -version = "1.12" +version = "1.13.2" description = "Computer algebra system (CAS) in Python" optional = false python-versions = ">=3.8" files = [ - {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"}, - {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"}, + {file = "sympy-1.13.2-py3-none-any.whl", hash = "sha256:c51d75517712f1aed280d4ce58506a4a88d635d6b5dd48b39102a7ae1f3fcfe9"}, + {file = "sympy-1.13.2.tar.gz", hash = "sha256:401449d84d07be9d0c7a46a64bd54fe097667d5e7181bfe67ec777be9e01cb13"}, ] [package.dependencies] -mpmath = ">=0.19" +mpmath = ">=1.1.0,<1.4" + +[package.extras] +dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"] [[package]] name = "tabulate" @@ -7933,13 +8429,13 @@ requests = "*" [[package]] name = "tenacity" -version = "8.3.0" +version = "9.0.0" description = "Retry code until it succeeds" optional = false python-versions = ">=3.8" files = [ - {file = "tenacity-8.3.0-py3-none-any.whl", hash = "sha256:3649f6443dbc0d9b01b9d8020a9c4ec7a1ff5f6f3c6c8a036ef371f573fe9185"}, - {file = "tenacity-8.3.0.tar.gz", hash = "sha256:953d4e6ad24357bceffbc9707bc74349aca9d245f68eb65419cf0c249a1949a2"}, + {file = "tenacity-9.0.0-py3-none-any.whl", hash = "sha256:93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539"}, + {file = "tenacity-9.0.0.tar.gz", hash = "sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b"}, ] [package.extras] @@ -7948,13 +8444,13 @@ test = ["pytest", "tornado (>=4.5)", "typeguard"] [[package]] name = "tencentcloud-sdk-python-common" -version = "3.0.1196" +version = "3.0.1216" description = "Tencent Cloud Common SDK for Python" optional = false python-versions = "*" files = [ - {file = "tencentcloud-sdk-python-common-3.0.1196.tar.gz", hash = "sha256:a8acd14f7480987ff0fd1d961ad934b2b7533ab1937d7e3adb74d95dc49954bd"}, - {file = "tencentcloud_sdk_python_common-3.0.1196-py2.py3-none-any.whl", hash = "sha256:5ed438bc3e2818ca8e84b3896aaa2746798fba981bd94b27528eb36efa5b4a30"}, + {file = "tencentcloud-sdk-python-common-3.0.1216.tar.gz", hash = "sha256:7ad83b100574068fe25439fe47fd27253ff1c730348a309567a7ff88eda63cf8"}, + {file = "tencentcloud_sdk_python_common-3.0.1216-py2.py3-none-any.whl", hash = "sha256:5e1cf9b685923d567d379f96a7008084006ad68793cfa0a0524e65dc59fd09d7"}, ] [package.dependencies] @@ -7962,17 +8458,17 @@ requests = ">=2.16.0" [[package]] name = "tencentcloud-sdk-python-hunyuan" -version = "3.0.1196" +version = "3.0.1216" description = "Tencent Cloud Hunyuan SDK for Python" optional = false python-versions = "*" files = [ - {file = "tencentcloud-sdk-python-hunyuan-3.0.1196.tar.gz", hash = "sha256:ced26497ae5f1b8fcc6cbd12238109274251e82fa1cfedfd6700df776306a36c"}, - {file = "tencentcloud_sdk_python_hunyuan-3.0.1196-py2.py3-none-any.whl", hash = "sha256:d18a19cffeaf4ff8a60670dc2bdb644f3d7ae6a51c30d21b50ded24a9c542248"}, + {file = "tencentcloud-sdk-python-hunyuan-3.0.1216.tar.gz", hash = "sha256:b295d67f97dba52ed358a1d9e061f94b1a4a87e45714efbf0987edab12642206"}, + {file = "tencentcloud_sdk_python_hunyuan-3.0.1216-py2.py3-none-any.whl", hash = "sha256:62d925b41424017929b532389061a076dca72dde455e85ec089947645010e691"}, ] [package.dependencies] -tencentcloud-sdk-python-common = "3.0.1196" +tencentcloud-sdk-python-common = "3.0.1216" [[package]] name = "threadpoolctl" @@ -8236,13 +8732,13 @@ files = [ [[package]] name = "tqdm" -version = "4.66.4" +version = "4.66.5" description = "Fast, Extensible Progress Meter" optional = false python-versions = ">=3.7" files = [ - {file = "tqdm-4.66.4-py3-none-any.whl", hash = "sha256:b75ca56b413b030bc3f00af51fd2c1a1a5eac6a0c1cca83cbb37a5c52abce644"}, - {file = "tqdm-4.66.4.tar.gz", hash = "sha256:e4d936c9de8727928f3be6079590e97d9abfe8d39a590be678eb5919ffc186bb"}, + {file = "tqdm-4.66.5-py3-none-any.whl", hash = "sha256:90279a3770753eafc9194a0364852159802111925aa30eb3f9d85b0e805ac7cd"}, + {file = "tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad"}, ] [package.dependencies] @@ -8341,13 +8837,13 @@ requests = ">=2.0.0" [[package]] name = "typer" -version = "0.12.3" +version = "0.12.4" description = "Typer, build great CLIs. Easy to code. Based on Python type hints." optional = false python-versions = ">=3.7" files = [ - {file = "typer-0.12.3-py3-none-any.whl", hash = "sha256:070d7ca53f785acbccba8e7d28b08dcd88f79f1fbda035ade0aecec71ca5c914"}, - {file = "typer-0.12.3.tar.gz", hash = "sha256:49e73131481d804288ef62598d97a1ceef3058905aa536a1134f90891ba35482"}, + {file = "typer-0.12.4-py3-none-any.whl", hash = "sha256:819aa03699f438397e876aa12b0d63766864ecba1b579092cc9fe35d886e34b6"}, + {file = "typer-0.12.4.tar.gz", hash = "sha256:c9c1613ed6a166162705b3347b8d10b661ccc5d95692654d0fb628118f2c34e6"}, ] [package.dependencies] @@ -8624,13 +9120,13 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "uvicorn" -version = "0.30.3" +version = "0.30.6" description = "The lightning-fast ASGI server." optional = false python-versions = ">=3.8" files = [ - {file = "uvicorn-0.30.3-py3-none-any.whl", hash = "sha256:94a3608da0e530cea8f69683aa4126364ac18e3826b6630d1a65f4638aade503"}, - {file = "uvicorn-0.30.3.tar.gz", hash = "sha256:0d114d0831ff1adbf231d358cbf42f17333413042552a624ea6a9b4c33dcfd81"}, + {file = "uvicorn-0.30.6-py3-none-any.whl", hash = "sha256:65fd46fe3fda5bdc1b03b94eb634923ff18cd35b2f084813ea79d1f103f711b5"}, + {file = "uvicorn-0.30.6.tar.gz", hash = "sha256:4b15decdda1e72be08209e860a1e10e92439ad5b97cf44cc945fcbee66fc5788"}, ] [package.dependencies] @@ -8650,42 +9146,42 @@ standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", [[package]] name = "uvloop" -version = "0.19.0" +version = "0.20.0" description = "Fast implementation of asyncio event loop on top of libuv" optional = false python-versions = ">=3.8.0" files = [ - {file = "uvloop-0.19.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:de4313d7f575474c8f5a12e163f6d89c0a878bc49219641d49e6f1444369a90e"}, - {file = "uvloop-0.19.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5588bd21cf1fcf06bded085f37e43ce0e00424197e7c10e77afd4bbefffef428"}, - {file = "uvloop-0.19.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b1fd71c3843327f3bbc3237bedcdb6504fd50368ab3e04d0410e52ec293f5b8"}, - {file = "uvloop-0.19.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a05128d315e2912791de6088c34136bfcdd0c7cbc1cf85fd6fd1bb321b7c849"}, - {file = "uvloop-0.19.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:cd81bdc2b8219cb4b2556eea39d2e36bfa375a2dd021404f90a62e44efaaf957"}, - {file = "uvloop-0.19.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5f17766fb6da94135526273080f3455a112f82570b2ee5daa64d682387fe0dcd"}, - {file = "uvloop-0.19.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4ce6b0af8f2729a02a5d1575feacb2a94fc7b2e983868b009d51c9a9d2149bef"}, - {file = "uvloop-0.19.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:31e672bb38b45abc4f26e273be83b72a0d28d074d5b370fc4dcf4c4eb15417d2"}, - {file = "uvloop-0.19.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:570fc0ed613883d8d30ee40397b79207eedd2624891692471808a95069a007c1"}, - {file = "uvloop-0.19.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5138821e40b0c3e6c9478643b4660bd44372ae1e16a322b8fc07478f92684e24"}, - {file = "uvloop-0.19.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:91ab01c6cd00e39cde50173ba4ec68a1e578fee9279ba64f5221810a9e786533"}, - {file = "uvloop-0.19.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:47bf3e9312f63684efe283f7342afb414eea4d3011542155c7e625cd799c3b12"}, - {file = "uvloop-0.19.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:da8435a3bd498419ee8c13c34b89b5005130a476bda1d6ca8cfdde3de35cd650"}, - {file = "uvloop-0.19.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:02506dc23a5d90e04d4f65c7791e65cf44bd91b37f24cfc3ef6cf2aff05dc7ec"}, - {file = "uvloop-0.19.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2693049be9d36fef81741fddb3f441673ba12a34a704e7b4361efb75cf30befc"}, - {file = "uvloop-0.19.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7010271303961c6f0fe37731004335401eb9075a12680738731e9c92ddd96ad6"}, - {file = "uvloop-0.19.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:5daa304d2161d2918fa9a17d5635099a2f78ae5b5960e742b2fcfbb7aefaa593"}, - {file = "uvloop-0.19.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:7207272c9520203fea9b93843bb775d03e1cf88a80a936ce760f60bb5add92f3"}, - {file = "uvloop-0.19.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:78ab247f0b5671cc887c31d33f9b3abfb88d2614b84e4303f1a63b46c046c8bd"}, - {file = "uvloop-0.19.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:472d61143059c84947aa8bb74eabbace30d577a03a1805b77933d6bd13ddebbd"}, - {file = "uvloop-0.19.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45bf4c24c19fb8a50902ae37c5de50da81de4922af65baf760f7c0c42e1088be"}, - {file = "uvloop-0.19.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:271718e26b3e17906b28b67314c45d19106112067205119dddbd834c2b7ce797"}, - {file = "uvloop-0.19.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:34175c9fd2a4bc3adc1380e1261f60306344e3407c20a4d684fd5f3be010fa3d"}, - {file = "uvloop-0.19.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e27f100e1ff17f6feeb1f33968bc185bf8ce41ca557deee9d9bbbffeb72030b7"}, - {file = "uvloop-0.19.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:13dfdf492af0aa0a0edf66807d2b465607d11c4fa48f4a1fd41cbea5b18e8e8b"}, - {file = "uvloop-0.19.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6e3d4e85ac060e2342ff85e90d0c04157acb210b9ce508e784a944f852a40e67"}, - {file = "uvloop-0.19.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8ca4956c9ab567d87d59d49fa3704cf29e37109ad348f2d5223c9bf761a332e7"}, - {file = "uvloop-0.19.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f467a5fd23b4fc43ed86342641f3936a68ded707f4627622fa3f82a120e18256"}, - {file = "uvloop-0.19.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:492e2c32c2af3f971473bc22f086513cedfc66a130756145a931a90c3958cb17"}, - {file = "uvloop-0.19.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2df95fca285a9f5bfe730e51945ffe2fa71ccbfdde3b0da5772b4ee4f2e770d5"}, - {file = "uvloop-0.19.0.tar.gz", hash = "sha256:0246f4fd1bf2bf702e06b0d45ee91677ee5c31242f39aab4ea6fe0c51aedd0fd"}, + {file = "uvloop-0.20.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:9ebafa0b96c62881d5cafa02d9da2e44c23f9f0cd829f3a32a6aff771449c996"}, + {file = "uvloop-0.20.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:35968fc697b0527a06e134999eef859b4034b37aebca537daeb598b9d45a137b"}, + {file = "uvloop-0.20.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b16696f10e59d7580979b420eedf6650010a4a9c3bd8113f24a103dfdb770b10"}, + {file = "uvloop-0.20.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9b04d96188d365151d1af41fa2d23257b674e7ead68cfd61c725a422764062ae"}, + {file = "uvloop-0.20.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:94707205efbe809dfa3a0d09c08bef1352f5d3d6612a506f10a319933757c006"}, + {file = "uvloop-0.20.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:89e8d33bb88d7263f74dc57d69f0063e06b5a5ce50bb9a6b32f5fcbe655f9e73"}, + {file = "uvloop-0.20.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e50289c101495e0d1bb0bfcb4a60adde56e32f4449a67216a1ab2750aa84f037"}, + {file = "uvloop-0.20.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e237f9c1e8a00e7d9ddaa288e535dc337a39bcbf679f290aee9d26df9e72bce9"}, + {file = "uvloop-0.20.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:746242cd703dc2b37f9d8b9f173749c15e9a918ddb021575a0205ec29a38d31e"}, + {file = "uvloop-0.20.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:82edbfd3df39fb3d108fc079ebc461330f7c2e33dbd002d146bf7c445ba6e756"}, + {file = "uvloop-0.20.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:80dc1b139516be2077b3e57ce1cb65bfed09149e1d175e0478e7a987863b68f0"}, + {file = "uvloop-0.20.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4f44af67bf39af25db4c1ac27e82e9665717f9c26af2369c404be865c8818dcf"}, + {file = "uvloop-0.20.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:4b75f2950ddb6feed85336412b9a0c310a2edbcf4cf931aa5cfe29034829676d"}, + {file = "uvloop-0.20.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:77fbc69c287596880ecec2d4c7a62346bef08b6209749bf6ce8c22bbaca0239e"}, + {file = "uvloop-0.20.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6462c95f48e2d8d4c993a2950cd3d31ab061864d1c226bbf0ee2f1a8f36674b9"}, + {file = "uvloop-0.20.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:649c33034979273fa71aa25d0fe120ad1777c551d8c4cd2c0c9851d88fcb13ab"}, + {file = "uvloop-0.20.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3a609780e942d43a275a617c0839d85f95c334bad29c4c0918252085113285b5"}, + {file = "uvloop-0.20.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:aea15c78e0d9ad6555ed201344ae36db5c63d428818b4b2a42842b3870127c00"}, + {file = "uvloop-0.20.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:f0e94b221295b5e69de57a1bd4aeb0b3a29f61be6e1b478bb8a69a73377db7ba"}, + {file = "uvloop-0.20.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:fee6044b64c965c425b65a4e17719953b96e065c5b7e09b599ff332bb2744bdf"}, + {file = "uvloop-0.20.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:265a99a2ff41a0fd56c19c3838b29bf54d1d177964c300dad388b27e84fd7847"}, + {file = "uvloop-0.20.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b10c2956efcecb981bf9cfb8184d27d5d64b9033f917115a960b83f11bfa0d6b"}, + {file = "uvloop-0.20.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:e7d61fe8e8d9335fac1bf8d5d82820b4808dd7a43020c149b63a1ada953d48a6"}, + {file = "uvloop-0.20.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2beee18efd33fa6fdb0976e18475a4042cd31c7433c866e8a09ab604c7c22ff2"}, + {file = "uvloop-0.20.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:d8c36fdf3e02cec92aed2d44f63565ad1522a499c654f07935c8f9d04db69e95"}, + {file = "uvloop-0.20.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a0fac7be202596c7126146660725157d4813aa29a4cc990fe51346f75ff8fde7"}, + {file = "uvloop-0.20.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d0fba61846f294bce41eb44d60d58136090ea2b5b99efd21cbdf4e21927c56a"}, + {file = "uvloop-0.20.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95720bae002ac357202e0d866128eb1ac82545bcf0b549b9abe91b5178d9b541"}, + {file = "uvloop-0.20.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:36c530d8fa03bfa7085af54a48f2ca16ab74df3ec7108a46ba82fd8b411a2315"}, + {file = "uvloop-0.20.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e97152983442b499d7a71e44f29baa75b3b02e65d9c44ba53b10338e98dedb66"}, + {file = "uvloop-0.20.0.tar.gz", hash = "sha256:4603ca714a754fc8d9b197e325db25b2ea045385e8a3ad05d3463de725fdf469"}, ] [package.extras] @@ -8765,88 +9261,122 @@ files = [ {file = "vine-5.1.0.tar.gz", hash = "sha256:8b62e981d35c41049211cf62a0a1242d8c1ee9bd15bb196ce38aefd6799e61e0"}, ] +[[package]] +name = "volcengine-python-sdk" +version = "1.0.98" +description = "Volcengine SDK for Python" +optional = false +python-versions = "*" +files = [ + {file = "volcengine-python-sdk-1.0.98.tar.gz", hash = "sha256:1515e8d46cdcda387f9b45abbcaf0b04b982f7be68068de83f1e388281441784"}, +] + +[package.dependencies] +anyio = {version = ">=3.5.0,<5", optional = true, markers = "extra == \"ark\""} +certifi = ">=2017.4.17" +httpx = {version = ">=0.23.0,<1", optional = true, markers = "extra == \"ark\""} +pydantic = {version = ">=1.9.0,<3", optional = true, markers = "extra == \"ark\""} +python-dateutil = ">=2.1" +six = ">=1.10" +urllib3 = ">=1.23" + +[package.extras] +ark = ["anyio (>=3.5.0,<5)", "cached-property", "httpx (>=0.23.0,<1)", "pydantic (>=1.9.0,<3)"] + [[package]] name = "watchfiles" -version = "0.22.0" +version = "0.23.0" description = "Simple, modern and high performance file watching and code reload in python." optional = false python-versions = ">=3.8" files = [ - {file = "watchfiles-0.22.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:da1e0a8caebf17976e2ffd00fa15f258e14749db5e014660f53114b676e68538"}, - {file = "watchfiles-0.22.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:61af9efa0733dc4ca462347becb82e8ef4945aba5135b1638bfc20fad64d4f0e"}, - {file = "watchfiles-0.22.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d9188979a58a096b6f8090e816ccc3f255f137a009dd4bbec628e27696d67c1"}, - {file = "watchfiles-0.22.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2bdadf6b90c099ca079d468f976fd50062905d61fae183f769637cb0f68ba59a"}, - {file = "watchfiles-0.22.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:067dea90c43bf837d41e72e546196e674f68c23702d3ef80e4e816937b0a3ffd"}, - {file = "watchfiles-0.22.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bbf8a20266136507abf88b0df2328e6a9a7c7309e8daff124dda3803306a9fdb"}, - {file = "watchfiles-0.22.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1235c11510ea557fe21be5d0e354bae2c655a8ee6519c94617fe63e05bca4171"}, - {file = "watchfiles-0.22.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2444dc7cb9d8cc5ab88ebe792a8d75709d96eeef47f4c8fccb6df7c7bc5be71"}, - {file = "watchfiles-0.22.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c5af2347d17ab0bd59366db8752d9e037982e259cacb2ba06f2c41c08af02c39"}, - {file = "watchfiles-0.22.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9624a68b96c878c10437199d9a8b7d7e542feddda8d5ecff58fdc8e67b460848"}, - {file = "watchfiles-0.22.0-cp310-none-win32.whl", hash = "sha256:4b9f2a128a32a2c273d63eb1fdbf49ad64852fc38d15b34eaa3f7ca2f0d2b797"}, - {file = "watchfiles-0.22.0-cp310-none-win_amd64.whl", hash = "sha256:2627a91e8110b8de2406d8b2474427c86f5a62bf7d9ab3654f541f319ef22bcb"}, - {file = "watchfiles-0.22.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:8c39987a1397a877217be1ac0fb1d8b9f662c6077b90ff3de2c05f235e6a8f96"}, - {file = "watchfiles-0.22.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a927b3034d0672f62fb2ef7ea3c9fc76d063c4b15ea852d1db2dc75fe2c09696"}, - {file = "watchfiles-0.22.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:052d668a167e9fc345c24203b104c313c86654dd6c0feb4b8a6dfc2462239249"}, - {file = "watchfiles-0.22.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5e45fb0d70dda1623a7045bd00c9e036e6f1f6a85e4ef2c8ae602b1dfadf7550"}, - {file = "watchfiles-0.22.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c49b76a78c156979759d759339fb62eb0549515acfe4fd18bb151cc07366629c"}, - {file = "watchfiles-0.22.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4a65474fd2b4c63e2c18ac67a0c6c66b82f4e73e2e4d940f837ed3d2fd9d4da"}, - {file = "watchfiles-0.22.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1cc0cba54f47c660d9fa3218158b8963c517ed23bd9f45fe463f08262a4adae1"}, - {file = "watchfiles-0.22.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94ebe84a035993bb7668f58a0ebf998174fb723a39e4ef9fce95baabb42b787f"}, - {file = "watchfiles-0.22.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e0f0a874231e2839abbf473256efffe577d6ee2e3bfa5b540479e892e47c172d"}, - {file = "watchfiles-0.22.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:213792c2cd3150b903e6e7884d40660e0bcec4465e00563a5fc03f30ea9c166c"}, - {file = "watchfiles-0.22.0-cp311-none-win32.whl", hash = "sha256:b44b70850f0073b5fcc0b31ede8b4e736860d70e2dbf55701e05d3227a154a67"}, - {file = "watchfiles-0.22.0-cp311-none-win_amd64.whl", hash = "sha256:00f39592cdd124b4ec5ed0b1edfae091567c72c7da1487ae645426d1b0ffcad1"}, - {file = "watchfiles-0.22.0-cp311-none-win_arm64.whl", hash = "sha256:3218a6f908f6a276941422b035b511b6d0d8328edd89a53ae8c65be139073f84"}, - {file = "watchfiles-0.22.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:c7b978c384e29d6c7372209cbf421d82286a807bbcdeb315427687f8371c340a"}, - {file = "watchfiles-0.22.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bd4c06100bce70a20c4b81e599e5886cf504c9532951df65ad1133e508bf20be"}, - {file = "watchfiles-0.22.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:425440e55cd735386ec7925f64d5dde392e69979d4c8459f6bb4e920210407f2"}, - {file = "watchfiles-0.22.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:68fe0c4d22332d7ce53ad094622b27e67440dacefbaedd29e0794d26e247280c"}, - {file = "watchfiles-0.22.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a8a31bfd98f846c3c284ba694c6365620b637debdd36e46e1859c897123aa232"}, - {file = "watchfiles-0.22.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dc2e8fe41f3cac0660197d95216c42910c2b7e9c70d48e6d84e22f577d106fc1"}, - {file = "watchfiles-0.22.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55b7cc10261c2786c41d9207193a85c1db1b725cf87936df40972aab466179b6"}, - {file = "watchfiles-0.22.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28585744c931576e535860eaf3f2c0ec7deb68e3b9c5a85ca566d69d36d8dd27"}, - {file = "watchfiles-0.22.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:00095dd368f73f8f1c3a7982a9801190cc88a2f3582dd395b289294f8975172b"}, - {file = "watchfiles-0.22.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:52fc9b0dbf54d43301a19b236b4a4614e610605f95e8c3f0f65c3a456ffd7d35"}, - {file = "watchfiles-0.22.0-cp312-none-win32.whl", hash = "sha256:581f0a051ba7bafd03e17127735d92f4d286af941dacf94bcf823b101366249e"}, - {file = "watchfiles-0.22.0-cp312-none-win_amd64.whl", hash = "sha256:aec83c3ba24c723eac14225194b862af176d52292d271c98820199110e31141e"}, - {file = "watchfiles-0.22.0-cp312-none-win_arm64.whl", hash = "sha256:c668228833c5619f6618699a2c12be057711b0ea6396aeaece4ded94184304ea"}, - {file = "watchfiles-0.22.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:d47e9ef1a94cc7a536039e46738e17cce058ac1593b2eccdede8bf72e45f372a"}, - {file = "watchfiles-0.22.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:28f393c1194b6eaadcdd8f941307fc9bbd7eb567995232c830f6aef38e8a6e88"}, - {file = "watchfiles-0.22.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd64f3a4db121bc161644c9e10a9acdb836853155a108c2446db2f5ae1778c3d"}, - {file = "watchfiles-0.22.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2abeb79209630da981f8ebca30a2c84b4c3516a214451bfc5f106723c5f45843"}, - {file = "watchfiles-0.22.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4cc382083afba7918e32d5ef12321421ef43d685b9a67cc452a6e6e18920890e"}, - {file = "watchfiles-0.22.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d048ad5d25b363ba1d19f92dcf29023988524bee6f9d952130b316c5802069cb"}, - {file = "watchfiles-0.22.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:103622865599f8082f03af4214eaff90e2426edff5e8522c8f9e93dc17caee13"}, - {file = "watchfiles-0.22.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d3e1f3cf81f1f823e7874ae563457828e940d75573c8fbf0ee66818c8b6a9099"}, - {file = "watchfiles-0.22.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8597b6f9dc410bdafc8bb362dac1cbc9b4684a8310e16b1ff5eee8725d13dcd6"}, - {file = "watchfiles-0.22.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:0b04a2cbc30e110303baa6d3ddce8ca3664bc3403be0f0ad513d1843a41c97d1"}, - {file = "watchfiles-0.22.0-cp38-none-win32.whl", hash = "sha256:b610fb5e27825b570554d01cec427b6620ce9bd21ff8ab775fc3a32f28bba63e"}, - {file = "watchfiles-0.22.0-cp38-none-win_amd64.whl", hash = "sha256:fe82d13461418ca5e5a808a9e40f79c1879351fcaeddbede094028e74d836e86"}, - {file = "watchfiles-0.22.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:3973145235a38f73c61474d56ad6199124e7488822f3a4fc97c72009751ae3b0"}, - {file = "watchfiles-0.22.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:280a4afbc607cdfc9571b9904b03a478fc9f08bbeec382d648181c695648202f"}, - {file = "watchfiles-0.22.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a0d883351a34c01bd53cfa75cd0292e3f7e268bacf2f9e33af4ecede7e21d1d"}, - {file = "watchfiles-0.22.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9165bcab15f2b6d90eedc5c20a7f8a03156b3773e5fb06a790b54ccecdb73385"}, - {file = "watchfiles-0.22.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dc1b9b56f051209be458b87edb6856a449ad3f803315d87b2da4c93b43a6fe72"}, - {file = "watchfiles-0.22.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8dc1fc25a1dedf2dd952909c8e5cb210791e5f2d9bc5e0e8ebc28dd42fed7562"}, - {file = "watchfiles-0.22.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dc92d2d2706d2b862ce0568b24987eba51e17e14b79a1abcd2edc39e48e743c8"}, - {file = "watchfiles-0.22.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97b94e14b88409c58cdf4a8eaf0e67dfd3ece7e9ce7140ea6ff48b0407a593ec"}, - {file = "watchfiles-0.22.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:96eec15e5ea7c0b6eb5bfffe990fc7c6bd833acf7e26704eb18387fb2f5fd087"}, - {file = "watchfiles-0.22.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:28324d6b28bcb8d7c1041648d7b63be07a16db5510bea923fc80b91a2a6cbed6"}, - {file = "watchfiles-0.22.0-cp39-none-win32.whl", hash = "sha256:8c3e3675e6e39dc59b8fe5c914a19d30029e36e9f99468dddffd432d8a7b1c93"}, - {file = "watchfiles-0.22.0-cp39-none-win_amd64.whl", hash = "sha256:25c817ff2a86bc3de3ed2df1703e3d24ce03479b27bb4527c57e722f8554d971"}, - {file = "watchfiles-0.22.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b810a2c7878cbdecca12feae2c2ae8af59bea016a78bc353c184fa1e09f76b68"}, - {file = "watchfiles-0.22.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:f7e1f9c5d1160d03b93fc4b68a0aeb82fe25563e12fbcdc8507f8434ab6f823c"}, - {file = "watchfiles-0.22.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:030bc4e68d14bcad2294ff68c1ed87215fbd9a10d9dea74e7cfe8a17869785ab"}, - {file = "watchfiles-0.22.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ace7d060432acde5532e26863e897ee684780337afb775107c0a90ae8dbccfd2"}, - {file = "watchfiles-0.22.0-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5834e1f8b71476a26df97d121c0c0ed3549d869124ed2433e02491553cb468c2"}, - {file = "watchfiles-0.22.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:0bc3b2f93a140df6806c8467c7f51ed5e55a931b031b5c2d7ff6132292e803d6"}, - {file = "watchfiles-0.22.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8fdebb655bb1ba0122402352b0a4254812717a017d2dc49372a1d47e24073795"}, - {file = "watchfiles-0.22.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c8e0aa0e8cc2a43561e0184c0513e291ca891db13a269d8d47cb9841ced7c71"}, - {file = "watchfiles-0.22.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:2f350cbaa4bb812314af5dab0eb8d538481e2e2279472890864547f3fe2281ed"}, - {file = "watchfiles-0.22.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:7a74436c415843af2a769b36bf043b6ccbc0f8d784814ba3d42fc961cdb0a9dc"}, - {file = "watchfiles-0.22.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00ad0bcd399503a84cc688590cdffbe7a991691314dde5b57b3ed50a41319a31"}, - {file = "watchfiles-0.22.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:72a44e9481afc7a5ee3291b09c419abab93b7e9c306c9ef9108cb76728ca58d2"}, - {file = "watchfiles-0.22.0.tar.gz", hash = "sha256:988e981aaab4f3955209e7e28c7794acdb690be1efa7f16f8ea5aba7ffdadacb"}, + {file = "watchfiles-0.23.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:bee8ce357a05c20db04f46c22be2d1a2c6a8ed365b325d08af94358e0688eeb4"}, + {file = "watchfiles-0.23.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4ccd3011cc7ee2f789af9ebe04745436371d36afe610028921cab9f24bb2987b"}, + {file = "watchfiles-0.23.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb02d41c33be667e6135e6686f1bb76104c88a312a18faa0ef0262b5bf7f1a0f"}, + {file = "watchfiles-0.23.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7cf12ac34c444362f3261fb3ff548f0037ddd4c5bb85f66c4be30d2936beb3c5"}, + {file = "watchfiles-0.23.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a0b2c25040a3c0ce0e66c7779cc045fdfbbb8d59e5aabfe033000b42fe44b53e"}, + {file = "watchfiles-0.23.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ecf2be4b9eece4f3da8ba5f244b9e51932ebc441c0867bd6af46a3d97eb068d6"}, + {file = "watchfiles-0.23.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:40cb8fa00028908211eb9f8d47744dca21a4be6766672e1ff3280bee320436f1"}, + {file = "watchfiles-0.23.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f48c917ffd36ff9a5212614c2d0d585fa8b064ca7e66206fb5c095015bc8207"}, + {file = "watchfiles-0.23.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9d183e3888ada88185ab17064079c0db8c17e32023f5c278d7bf8014713b1b5b"}, + {file = "watchfiles-0.23.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9837edf328b2805346f91209b7e660f65fb0e9ca18b7459d075d58db082bf981"}, + {file = "watchfiles-0.23.0-cp310-none-win32.whl", hash = "sha256:296e0b29ab0276ca59d82d2da22cbbdb39a23eed94cca69aed274595fb3dfe42"}, + {file = "watchfiles-0.23.0-cp310-none-win_amd64.whl", hash = "sha256:4ea756e425ab2dfc8ef2a0cb87af8aa7ef7dfc6fc46c6f89bcf382121d4fff75"}, + {file = "watchfiles-0.23.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:e397b64f7aaf26915bf2ad0f1190f75c855d11eb111cc00f12f97430153c2eab"}, + {file = "watchfiles-0.23.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b4ac73b02ca1824ec0a7351588241fd3953748d3774694aa7ddb5e8e46aef3e3"}, + {file = "watchfiles-0.23.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:130a896d53b48a1cecccfa903f37a1d87dbb74295305f865a3e816452f6e49e4"}, + {file = "watchfiles-0.23.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c5e7803a65eb2d563c73230e9d693c6539e3c975ccfe62526cadde69f3fda0cf"}, + {file = "watchfiles-0.23.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1aa4cc85202956d1a65c88d18c7b687b8319dbe6b1aec8969784ef7a10e7d1a"}, + {file = "watchfiles-0.23.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:87f889f6e58849ddb7c5d2cb19e2e074917ed1c6e3ceca50405775166492cca8"}, + {file = "watchfiles-0.23.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:37fd826dac84c6441615aa3f04077adcc5cac7194a021c9f0d69af20fb9fa788"}, + {file = "watchfiles-0.23.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ee7db6e36e7a2c15923072e41ea24d9a0cf39658cb0637ecc9307b09d28827e1"}, + {file = "watchfiles-0.23.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2368c5371c17fdcb5a2ea71c5c9d49f9b128821bfee69503cc38eae00feb3220"}, + {file = "watchfiles-0.23.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:857af85d445b9ba9178db95658c219dbd77b71b8264e66836a6eba4fbf49c320"}, + {file = "watchfiles-0.23.0-cp311-none-win32.whl", hash = "sha256:1d636c8aeb28cdd04a4aa89030c4b48f8b2954d8483e5f989774fa441c0ed57b"}, + {file = "watchfiles-0.23.0-cp311-none-win_amd64.whl", hash = "sha256:46f1d8069a95885ca529645cdbb05aea5837d799965676e1b2b1f95a4206313e"}, + {file = "watchfiles-0.23.0-cp311-none-win_arm64.whl", hash = "sha256:e495ed2a7943503766c5d1ff05ae9212dc2ce1c0e30a80d4f0d84889298fa304"}, + {file = "watchfiles-0.23.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:1db691bad0243aed27c8354b12d60e8e266b75216ae99d33e927ff5238d270b5"}, + {file = "watchfiles-0.23.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:62d2b18cb1edaba311fbbfe83fb5e53a858ba37cacb01e69bc20553bb70911b8"}, + {file = "watchfiles-0.23.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e087e8fdf1270d000913c12e6eca44edd02aad3559b3e6b8ef00f0ce76e0636f"}, + {file = "watchfiles-0.23.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:dd41d5c72417b87c00b1b635738f3c283e737d75c5fa5c3e1c60cd03eac3af77"}, + {file = "watchfiles-0.23.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e5f3ca0ff47940ce0a389457b35d6df601c317c1e1a9615981c474452f98de1"}, + {file = "watchfiles-0.23.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6991e3a78f642368b8b1b669327eb6751439f9f7eaaa625fae67dd6070ecfa0b"}, + {file = "watchfiles-0.23.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7f7252f52a09f8fa5435dc82b6af79483118ce6bd51eb74e6269f05ee22a7b9f"}, + {file = "watchfiles-0.23.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e01bcb8d767c58865207a6c2f2792ad763a0fe1119fb0a430f444f5b02a5ea0"}, + {file = "watchfiles-0.23.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:8e56fbcdd27fce061854ddec99e015dd779cae186eb36b14471fc9ae713b118c"}, + {file = "watchfiles-0.23.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:bd3e2d64500a6cad28bcd710ee6269fbeb2e5320525acd0cfab5f269ade68581"}, + {file = "watchfiles-0.23.0-cp312-none-win32.whl", hash = "sha256:eb99c954291b2fad0eff98b490aa641e128fbc4a03b11c8a0086de8b7077fb75"}, + {file = "watchfiles-0.23.0-cp312-none-win_amd64.whl", hash = "sha256:dccc858372a56080332ea89b78cfb18efb945da858fabeb67f5a44fa0bcb4ebb"}, + {file = "watchfiles-0.23.0-cp312-none-win_arm64.whl", hash = "sha256:6c21a5467f35c61eafb4e394303720893066897fca937bade5b4f5877d350ff8"}, + {file = "watchfiles-0.23.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:ba31c32f6b4dceeb2be04f717811565159617e28d61a60bb616b6442027fd4b9"}, + {file = "watchfiles-0.23.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:85042ab91814fca99cec4678fc063fb46df4cbb57b4835a1cc2cb7a51e10250e"}, + {file = "watchfiles-0.23.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:24655e8c1c9c114005c3868a3d432c8aa595a786b8493500071e6a52f3d09217"}, + {file = "watchfiles-0.23.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6b1a950ab299a4a78fd6369a97b8763732bfb154fdb433356ec55a5bce9515c1"}, + {file = "watchfiles-0.23.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8d3c5cd327dd6ce0edfc94374fb5883d254fe78a5e9d9dfc237a1897dc73cd1"}, + {file = "watchfiles-0.23.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ff785af8bacdf0be863ec0c428e3288b817e82f3d0c1d652cd9c6d509020dd0"}, + {file = "watchfiles-0.23.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:02b7ba9d4557149410747353e7325010d48edcfe9d609a85cb450f17fd50dc3d"}, + {file = "watchfiles-0.23.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:48a1b05c0afb2cd2f48c1ed2ae5487b116e34b93b13074ed3c22ad5c743109f0"}, + {file = "watchfiles-0.23.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:109a61763e7318d9f821b878589e71229f97366fa6a5c7720687d367f3ab9eef"}, + {file = "watchfiles-0.23.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:9f8e6bb5ac007d4a4027b25f09827ed78cbbd5b9700fd6c54429278dacce05d1"}, + {file = "watchfiles-0.23.0-cp313-none-win32.whl", hash = "sha256:f46c6f0aec8d02a52d97a583782d9af38c19a29900747eb048af358a9c1d8e5b"}, + {file = "watchfiles-0.23.0-cp313-none-win_amd64.whl", hash = "sha256:f449afbb971df5c6faeb0a27bca0427d7b600dd8f4a068492faec18023f0dcff"}, + {file = "watchfiles-0.23.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:2dddc2487d33e92f8b6222b5fb74ae2cfde5e8e6c44e0248d24ec23befdc5366"}, + {file = "watchfiles-0.23.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e75695cc952e825fa3e0684a7f4a302f9128721f13eedd8dbd3af2ba450932b8"}, + {file = "watchfiles-0.23.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2537ef60596511df79b91613a5bb499b63f46f01a11a81b0a2b0dedf645d0a9c"}, + {file = "watchfiles-0.23.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:20b423b58f5fdde704a226b598a2d78165fe29eb5621358fe57ea63f16f165c4"}, + {file = "watchfiles-0.23.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b98732ec893975455708d6fc9a6daab527fc8bbe65be354a3861f8c450a632a4"}, + {file = "watchfiles-0.23.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ee1f5fcbf5bc33acc0be9dd31130bcba35d6d2302e4eceafafd7d9018c7755ab"}, + {file = "watchfiles-0.23.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a8f195338a5a7b50a058522b39517c50238358d9ad8284fd92943643144c0c03"}, + {file = "watchfiles-0.23.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:524fcb8d59b0dbee2c9b32207084b67b2420f6431ed02c18bd191e6c575f5c48"}, + {file = "watchfiles-0.23.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:0eff099a4df36afaa0eea7a913aa64dcf2cbd4e7a4f319a73012210af4d23810"}, + {file = "watchfiles-0.23.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a8323daae27ea290ba3350c70c836c0d2b0fb47897fa3b0ca6a5375b952b90d3"}, + {file = "watchfiles-0.23.0-cp38-none-win32.whl", hash = "sha256:aafea64a3ae698695975251f4254df2225e2624185a69534e7fe70581066bc1b"}, + {file = "watchfiles-0.23.0-cp38-none-win_amd64.whl", hash = "sha256:c846884b2e690ba62a51048a097acb6b5cd263d8bd91062cd6137e2880578472"}, + {file = "watchfiles-0.23.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:a753993635eccf1ecb185dedcc69d220dab41804272f45e4aef0a67e790c3eb3"}, + {file = "watchfiles-0.23.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6bb91fa4d0b392f0f7e27c40981e46dda9eb0fbc84162c7fb478fe115944f491"}, + {file = "watchfiles-0.23.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1f67312efa3902a8e8496bfa9824d3bec096ff83c4669ea555c6bdd213aa516"}, + {file = "watchfiles-0.23.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7ca6b71dcc50d320c88fb2d88ecd63924934a8abc1673683a242a7ca7d39e781"}, + {file = "watchfiles-0.23.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2aec5c29915caf08771d2507da3ac08e8de24a50f746eb1ed295584ba1820330"}, + {file = "watchfiles-0.23.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1733b9bc2c8098c6bdb0ff7a3d7cb211753fecb7bd99bdd6df995621ee1a574b"}, + {file = "watchfiles-0.23.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:02ff5d7bd066c6a7673b17c8879cd8ee903078d184802a7ee851449c43521bdd"}, + {file = "watchfiles-0.23.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18e2de19801b0eaa4c5292a223effb7cfb43904cb742c5317a0ac686ed604765"}, + {file = "watchfiles-0.23.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:8ada449e22198c31fb013ae7e9add887e8d2bd2335401abd3cbc55f8c5083647"}, + {file = "watchfiles-0.23.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3af1b05361e1cc497bf1be654a664750ae61f5739e4bb094a2be86ec8c6db9b6"}, + {file = "watchfiles-0.23.0-cp39-none-win32.whl", hash = "sha256:486bda18be5d25ab5d932699ceed918f68eb91f45d018b0343e3502e52866e5e"}, + {file = "watchfiles-0.23.0-cp39-none-win_amd64.whl", hash = "sha256:d2d42254b189a346249424fb9bb39182a19289a2409051ee432fb2926bad966a"}, + {file = "watchfiles-0.23.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:6a9265cf87a5b70147bfb2fec14770ed5b11a5bb83353f0eee1c25a81af5abfe"}, + {file = "watchfiles-0.23.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9f02a259fcbbb5fcfe7a0805b1097ead5ba7a043e318eef1db59f93067f0b49b"}, + {file = "watchfiles-0.23.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1ebaebb53b34690da0936c256c1cdb0914f24fb0e03da76d185806df9328abed"}, + {file = "watchfiles-0.23.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd257f98cff9c6cb39eee1a83c7c3183970d8a8d23e8cf4f47d9a21329285cee"}, + {file = "watchfiles-0.23.0-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:aba037c1310dd108411d27b3d5815998ef0e83573e47d4219f45753c710f969f"}, + {file = "watchfiles-0.23.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:a96ac14e184aa86dc43b8a22bb53854760a58b2966c2b41580de938e9bf26ed0"}, + {file = "watchfiles-0.23.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11698bb2ea5e991d10f1f4f83a39a02f91e44e4bd05f01b5c1ec04c9342bf63c"}, + {file = "watchfiles-0.23.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:efadd40fca3a04063d40c4448c9303ce24dd6151dc162cfae4a2a060232ebdcb"}, + {file = "watchfiles-0.23.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:556347b0abb4224c5ec688fc58214162e92a500323f50182f994f3ad33385dcb"}, + {file = "watchfiles-0.23.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:1cf7f486169986c4b9d34087f08ce56a35126600b6fef3028f19ca16d5889071"}, + {file = "watchfiles-0.23.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f18de0f82c62c4197bea5ecf4389288ac755896aac734bd2cc44004c56e4ac47"}, + {file = "watchfiles-0.23.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:532e1f2c491274d1333a814e4c5c2e8b92345d41b12dc806cf07aaff786beb66"}, + {file = "watchfiles-0.23.0.tar.gz", hash = "sha256:9338ade39ff24f8086bb005d16c29f8e9f19e55b18dcb04dfa26fcbc09da497b"}, ] [package.dependencies] @@ -8912,94 +9442,108 @@ test = ["websockets"] [[package]] name = "websockets" -version = "12.0" +version = "13.0" description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" optional = false python-versions = ">=3.8" files = [ - {file = "websockets-12.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d554236b2a2006e0ce16315c16eaa0d628dab009c33b63ea03f41c6107958374"}, - {file = "websockets-12.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2d225bb6886591b1746b17c0573e29804619c8f755b5598d875bb4235ea639be"}, - {file = "websockets-12.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:eb809e816916a3b210bed3c82fb88eaf16e8afcf9c115ebb2bacede1797d2547"}, - {file = "websockets-12.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c588f6abc13f78a67044c6b1273a99e1cf31038ad51815b3b016ce699f0d75c2"}, - {file = "websockets-12.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5aa9348186d79a5f232115ed3fa9020eab66d6c3437d72f9d2c8ac0c6858c558"}, - {file = "websockets-12.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6350b14a40c95ddd53e775dbdbbbc59b124a5c8ecd6fbb09c2e52029f7a9f480"}, - {file = "websockets-12.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:70ec754cc2a769bcd218ed8d7209055667b30860ffecb8633a834dde27d6307c"}, - {file = "websockets-12.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6e96f5ed1b83a8ddb07909b45bd94833b0710f738115751cdaa9da1fb0cb66e8"}, - {file = "websockets-12.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4d87be612cbef86f994178d5186add3d94e9f31cc3cb499a0482b866ec477603"}, - {file = "websockets-12.0-cp310-cp310-win32.whl", hash = "sha256:befe90632d66caaf72e8b2ed4d7f02b348913813c8b0a32fae1cc5fe3730902f"}, - {file = "websockets-12.0-cp310-cp310-win_amd64.whl", hash = "sha256:363f57ca8bc8576195d0540c648aa58ac18cf85b76ad5202b9f976918f4219cf"}, - {file = "websockets-12.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5d873c7de42dea355d73f170be0f23788cf3fa9f7bed718fd2830eefedce01b4"}, - {file = "websockets-12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3f61726cae9f65b872502ff3c1496abc93ffbe31b278455c418492016e2afc8f"}, - {file = "websockets-12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ed2fcf7a07334c77fc8a230755c2209223a7cc44fc27597729b8ef5425aa61a3"}, - {file = "websockets-12.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e332c210b14b57904869ca9f9bf4ca32f5427a03eeb625da9b616c85a3a506c"}, - {file = "websockets-12.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5693ef74233122f8ebab026817b1b37fe25c411ecfca084b29bc7d6efc548f45"}, - {file = "websockets-12.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e9e7db18b4539a29cc5ad8c8b252738a30e2b13f033c2d6e9d0549b45841c04"}, - {file = "websockets-12.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6e2df67b8014767d0f785baa98393725739287684b9f8d8a1001eb2839031447"}, - {file = "websockets-12.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:bea88d71630c5900690fcb03161ab18f8f244805c59e2e0dc4ffadae0a7ee0ca"}, - {file = "websockets-12.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:dff6cdf35e31d1315790149fee351f9e52978130cef6c87c4b6c9b3baf78bc53"}, - {file = "websockets-12.0-cp311-cp311-win32.whl", hash = "sha256:3e3aa8c468af01d70332a382350ee95f6986db479ce7af14d5e81ec52aa2b402"}, - {file = "websockets-12.0-cp311-cp311-win_amd64.whl", hash = "sha256:25eb766c8ad27da0f79420b2af4b85d29914ba0edf69f547cc4f06ca6f1d403b"}, - {file = "websockets-12.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0e6e2711d5a8e6e482cacb927a49a3d432345dfe7dea8ace7b5790df5932e4df"}, - {file = "websockets-12.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:dbcf72a37f0b3316e993e13ecf32f10c0e1259c28ffd0a85cee26e8549595fbc"}, - {file = "websockets-12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:12743ab88ab2af1d17dd4acb4645677cb7063ef4db93abffbf164218a5d54c6b"}, - {file = "websockets-12.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b645f491f3c48d3f8a00d1fce07445fab7347fec54a3e65f0725d730d5b99cb"}, - {file = "websockets-12.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9893d1aa45a7f8b3bc4510f6ccf8db8c3b62120917af15e3de247f0780294b92"}, - {file = "websockets-12.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f38a7b376117ef7aff996e737583172bdf535932c9ca021746573bce40165ed"}, - {file = "websockets-12.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:f764ba54e33daf20e167915edc443b6f88956f37fb606449b4a5b10ba42235a5"}, - {file = "websockets-12.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:1e4b3f8ea6a9cfa8be8484c9221ec0257508e3a1ec43c36acdefb2a9c3b00aa2"}, - {file = "websockets-12.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:9fdf06fd06c32205a07e47328ab49c40fc1407cdec801d698a7c41167ea45113"}, - {file = "websockets-12.0-cp312-cp312-win32.whl", hash = "sha256:baa386875b70cbd81798fa9f71be689c1bf484f65fd6fb08d051a0ee4e79924d"}, - {file = "websockets-12.0-cp312-cp312-win_amd64.whl", hash = "sha256:ae0a5da8f35a5be197f328d4727dbcfafa53d1824fac3d96cdd3a642fe09394f"}, - {file = "websockets-12.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5f6ffe2c6598f7f7207eef9a1228b6f5c818f9f4d53ee920aacd35cec8110438"}, - {file = "websockets-12.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9edf3fc590cc2ec20dc9d7a45108b5bbaf21c0d89f9fd3fd1685e223771dc0b2"}, - {file = "websockets-12.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8572132c7be52632201a35f5e08348137f658e5ffd21f51f94572ca6c05ea81d"}, - {file = "websockets-12.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:604428d1b87edbf02b233e2c207d7d528460fa978f9e391bd8aaf9c8311de137"}, - {file = "websockets-12.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1a9d160fd080c6285e202327aba140fc9a0d910b09e423afff4ae5cbbf1c7205"}, - {file = "websockets-12.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87b4aafed34653e465eb77b7c93ef058516cb5acf3eb21e42f33928616172def"}, - {file = "websockets-12.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b2ee7288b85959797970114deae81ab41b731f19ebcd3bd499ae9ca0e3f1d2c8"}, - {file = "websockets-12.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:7fa3d25e81bfe6a89718e9791128398a50dec6d57faf23770787ff441d851967"}, - {file = "websockets-12.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a571f035a47212288e3b3519944f6bf4ac7bc7553243e41eac50dd48552b6df7"}, - {file = "websockets-12.0-cp38-cp38-win32.whl", hash = "sha256:3c6cc1360c10c17463aadd29dd3af332d4a1adaa8796f6b0e9f9df1fdb0bad62"}, - {file = "websockets-12.0-cp38-cp38-win_amd64.whl", hash = "sha256:1bf386089178ea69d720f8db6199a0504a406209a0fc23e603b27b300fdd6892"}, - {file = "websockets-12.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:ab3d732ad50a4fbd04a4490ef08acd0517b6ae6b77eb967251f4c263011a990d"}, - {file = "websockets-12.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a1d9697f3337a89691e3bd8dc56dea45a6f6d975f92e7d5f773bc715c15dde28"}, - {file = "websockets-12.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1df2fbd2c8a98d38a66f5238484405b8d1d16f929bb7a33ed73e4801222a6f53"}, - {file = "websockets-12.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23509452b3bc38e3a057382c2e941d5ac2e01e251acce7adc74011d7d8de434c"}, - {file = "websockets-12.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e5fc14ec6ea568200ea4ef46545073da81900a2b67b3e666f04adf53ad452ec"}, - {file = "websockets-12.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46e71dbbd12850224243f5d2aeec90f0aaa0f2dde5aeeb8fc8df21e04d99eff9"}, - {file = "websockets-12.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b81f90dcc6c85a9b7f29873beb56c94c85d6f0dac2ea8b60d995bd18bf3e2aae"}, - {file = "websockets-12.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:a02413bc474feda2849c59ed2dfb2cddb4cd3d2f03a2fedec51d6e959d9b608b"}, - {file = "websockets-12.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:bbe6013f9f791944ed31ca08b077e26249309639313fff132bfbf3ba105673b9"}, - {file = "websockets-12.0-cp39-cp39-win32.whl", hash = "sha256:cbe83a6bbdf207ff0541de01e11904827540aa069293696dd528a6640bd6a5f6"}, - {file = "websockets-12.0-cp39-cp39-win_amd64.whl", hash = "sha256:fc4e7fa5414512b481a2483775a8e8be7803a35b30ca805afa4998a84f9fd9e8"}, - {file = "websockets-12.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:248d8e2446e13c1d4326e0a6a4e9629cb13a11195051a73acf414812700badbd"}, - {file = "websockets-12.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f44069528d45a933997a6fef143030d8ca8042f0dfaad753e2906398290e2870"}, - {file = "websockets-12.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c4e37d36f0d19f0a4413d3e18c0d03d0c268ada2061868c1e6f5ab1a6d575077"}, - {file = "websockets-12.0-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d829f975fc2e527a3ef2f9c8f25e553eb7bc779c6665e8e1d52aa22800bb38b"}, - {file = "websockets-12.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:2c71bd45a777433dd9113847af751aae36e448bc6b8c361a566cb043eda6ec30"}, - {file = "websockets-12.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:0bee75f400895aef54157b36ed6d3b308fcab62e5260703add87f44cee9c82a6"}, - {file = "websockets-12.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:423fc1ed29f7512fceb727e2d2aecb952c46aa34895e9ed96071821309951123"}, - {file = "websockets-12.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:27a5e9964ef509016759f2ef3f2c1e13f403725a5e6a1775555994966a66e931"}, - {file = "websockets-12.0-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3181df4583c4d3994d31fb235dc681d2aaad744fbdbf94c4802485ececdecf2"}, - {file = "websockets-12.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:b067cb952ce8bf40115f6c19f478dc71c5e719b7fbaa511359795dfd9d1a6468"}, - {file = "websockets-12.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:00700340c6c7ab788f176d118775202aadea7602c5cc6be6ae127761c16d6b0b"}, - {file = "websockets-12.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e469d01137942849cff40517c97a30a93ae79917752b34029f0ec72df6b46399"}, - {file = "websockets-12.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffefa1374cd508d633646d51a8e9277763a9b78ae71324183693959cf94635a7"}, - {file = "websockets-12.0-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba0cab91b3956dfa9f512147860783a1829a8d905ee218a9837c18f683239611"}, - {file = "websockets-12.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2cb388a5bfb56df4d9a406783b7f9dbefb888c09b71629351cc6b036e9259370"}, - {file = "websockets-12.0-py3-none-any.whl", hash = "sha256:dc284bbc8d7c78a6c69e0c7325ab46ee5e40bb4d50e494d8131a07ef47500e9e"}, - {file = "websockets-12.0.tar.gz", hash = "sha256:81df9cbcbb6c260de1e007e58c011bfebe2dafc8435107b0537f393dd38c8b1b"}, + {file = "websockets-13.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ad4fa707ff9e2ffee019e946257b5300a45137a58f41fbd9a4db8e684ab61528"}, + {file = "websockets-13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6fd757f313c13c34dae9f126d3ba4cf97175859c719e57c6a614b781c86b617e"}, + {file = "websockets-13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cbac2eb7ce0fac755fb983c9247c4a60c4019bcde4c0e4d167aeb17520cc7ef1"}, + {file = "websockets-13.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4b83cf7354cbbc058e97b3e545dceb75b8d9cf17fd5a19db419c319ddbaaf7a"}, + {file = "websockets-13.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9202c0010c78fad1041e1c5285232b6508d3633f92825687549540a70e9e5901"}, + {file = "websockets-13.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e6566e79c8c7cbea75ec450f6e1828945fc5c9a4769ceb1c7b6e22470539712"}, + {file = "websockets-13.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e7fcad070dcd9ad37a09d89a4cbc2a5e3e45080b88977c0da87b3090f9f55ead"}, + {file = "websockets-13.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:0a8f7d65358a25172db00c69bcc7df834155ee24229f560d035758fd6613111a"}, + {file = "websockets-13.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:63b702fb31e3f058f946ccdfa551f4d57a06f7729c369e8815eb18643099db37"}, + {file = "websockets-13.0-cp310-cp310-win32.whl", hash = "sha256:3a20cf14ba7b482c4a1924b5e061729afb89c890ca9ed44ac4127c6c5986e424"}, + {file = "websockets-13.0-cp310-cp310-win_amd64.whl", hash = "sha256:587245f0704d0bb675f919898d7473e8827a6d578e5a122a21756ca44b811ec8"}, + {file = "websockets-13.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:06df8306c241c235075d2ae77367038e701e53bc8c1bb4f6644f4f53aa6dedd0"}, + {file = "websockets-13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:85a1f92a02f0b8c1bf02699731a70a8a74402bb3f82bee36e7768b19a8ed9709"}, + {file = "websockets-13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9ed02c604349068d46d87ef4c2012c112c791f2bec08671903a6bb2bd9c06784"}, + {file = "websockets-13.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b89849171b590107f6724a7b0790736daead40926ddf47eadf998b4ff51d6414"}, + {file = "websockets-13.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:939a16849d71203628157a5e4a495da63967c744e1e32018e9b9e2689aca64d4"}, + {file = "websockets-13.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad818cdac37c0ad4c58e51cb4964eae4f18b43c4a83cb37170b0d90c31bd80cf"}, + {file = "websockets-13.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:cbfe82a07596a044de78bb7a62519e71690c5812c26c5f1d4b877e64e4f46309"}, + {file = "websockets-13.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e07e76c49f39c5b45cbd7362b94f001ae209a3ea4905ae9a09cfd53b3c76373d"}, + {file = "websockets-13.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:372f46a0096cfda23c88f7e42349a33f8375e10912f712e6b496d3a9a557290f"}, + {file = "websockets-13.0-cp311-cp311-win32.whl", hash = "sha256:376a43a4fd96725f13450d3d2e98f4f36c3525c562ab53d9a98dd2950dca9a8a"}, + {file = "websockets-13.0-cp311-cp311-win_amd64.whl", hash = "sha256:2be1382a4daa61e2f3e2be3b3c86932a8db9d1f85297feb6e9df22f391f94452"}, + {file = "websockets-13.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:b5407c34776b9b77bd89a5f95eb0a34aaf91889e3f911c63f13035220eb50107"}, + {file = "websockets-13.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:4782ec789f059f888c1e8fdf94383d0e64b531cffebbf26dd55afd53ab487ca4"}, + {file = "websockets-13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c8feb8e19ef65c9994e652c5b0324abd657bedd0abeb946fb4f5163012c1e730"}, + {file = "websockets-13.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3f3d2e20c442b58dbac593cb1e02bc02d149a86056cc4126d977ad902472e3b"}, + {file = "websockets-13.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e39d393e0ab5b8bd01717cc26f2922026050188947ff54fe6a49dc489f7750b7"}, + {file = "websockets-13.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f661a4205741bdc88ac9c2b2ec003c72cee97e4acd156eb733662ff004ba429"}, + {file = "websockets-13.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:384129ad0490e06bab2b98c1da9b488acb35bb11e2464c728376c6f55f0d45f3"}, + {file = "websockets-13.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:df5c0eff91f61b8205a6c9f7b255ff390cdb77b61c7b41f79ca10afcbb22b6cb"}, + {file = "websockets-13.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:02cc9bb1a887dac0e08bf657c5d00aa3fac0d03215d35a599130c2034ae6663a"}, + {file = "websockets-13.0-cp312-cp312-win32.whl", hash = "sha256:d9726d2c9bd6aed8cb994d89b3910ca0079406edce3670886ec828a73e7bdd53"}, + {file = "websockets-13.0-cp312-cp312-win_amd64.whl", hash = "sha256:fa0839f35322f7b038d8adcf679e2698c3a483688cc92e3bd15ee4fb06669e9a"}, + {file = "websockets-13.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:da7e501e59857e8e3e9d10586139dc196b80445a591451ca9998aafba1af5278"}, + {file = "websockets-13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a00e1e587c655749afb5b135d8d3edcfe84ec6db864201e40a882e64168610b3"}, + {file = "websockets-13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a7fbf2a8fe7556a8f4e68cb3e736884af7bf93653e79f6219f17ebb75e97d8f0"}, + {file = "websockets-13.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ea9c9c7443a97ea4d84d3e4d42d0e8c4235834edae652993abcd2aff94affd7"}, + {file = "websockets-13.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:35c2221b539b360203f3f9ad168e527bf16d903e385068ae842c186efb13d0ea"}, + {file = "websockets-13.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:358d37c5c431dd050ffb06b4b075505aae3f4f795d7fff9794e5ed96ce99b998"}, + {file = "websockets-13.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:038e7a0f1bfafc7bf52915ab3506b7a03d1e06381e9f60440c856e8918138151"}, + {file = "websockets-13.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:fd038bc9e2c134847f1e0ce3191797fad110756e690c2fdd9702ed34e7a43abb"}, + {file = "websockets-13.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:93b8c2008f372379fb6e5d2b3f7c9ec32f7b80316543fd3a5ace6610c5cde1b0"}, + {file = "websockets-13.0-cp313-cp313-win32.whl", hash = "sha256:851fd0afb3bc0b73f7c5b5858975d42769a5fdde5314f4ef2c106aec63100687"}, + {file = "websockets-13.0-cp313-cp313-win_amd64.whl", hash = "sha256:7d14901fdcf212804970c30ab9ee8f3f0212e620c7ea93079d6534863444fb4e"}, + {file = "websockets-13.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ae7a519a56a714f64c3445cabde9fc2fc927e7eae44f413eae187cddd9e54178"}, + {file = "websockets-13.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5575031472ca87302aeb2ce2c2349f4c6ea978c86a9d1289bc5d16058ad4c10a"}, + {file = "websockets-13.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9895df6cd0bfe79d09bcd1dbdc03862846f26fbd93797153de954306620c1d00"}, + {file = "websockets-13.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a4de299c947a54fca9ce1c5fd4a08eb92ffce91961becb13bd9195f7c6e71b47"}, + {file = "websockets-13.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:05c25f7b849702950b6fd0e233989bb73a0d2bc83faa3b7233313ca395205f6d"}, + {file = "websockets-13.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ede95125a30602b1691a4b1da88946bf27dae283cf30f22cd2cb8ca4b2e0d119"}, + {file = "websockets-13.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:addf0a16e4983280efed272d8cb3b2e05f0051755372461e7d966b80a6554e16"}, + {file = "websockets-13.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:06b3186e97bf9a33921fa60734d5ed90f2a9b407cce8d23c7333a0984049ef61"}, + {file = "websockets-13.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:eae368cac85adc4c7dc3b0d5f84ffcca609d658db6447387300478e44db70796"}, + {file = "websockets-13.0-cp38-cp38-win32.whl", hash = "sha256:337837ac788d955728b1ab01876d72b73da59819a3388e1c5e8e05c3999f1afa"}, + {file = "websockets-13.0-cp38-cp38-win_amd64.whl", hash = "sha256:f66e00e42f25ca7e91076366303e11c82572ca87cc5aae51e6e9c094f315ab41"}, + {file = "websockets-13.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:94c1c02721139fe9940b38d28fb15b4b782981d800d5f40f9966264fbf23dcc8"}, + {file = "websockets-13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bd4ba86513430513e2aa25a441bb538f6f83734dc368a2c5d18afdd39097aa33"}, + {file = "websockets-13.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a1ab8f0e0cadc5be5f3f9fa11a663957fecbf483d434762c8dfb8aa44948944a"}, + {file = "websockets-13.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3670def5d3dfd5af6f6e2b3b243ea8f1f72d8da1ef927322f0703f85c90d9603"}, + {file = "websockets-13.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6058b6be92743358885ad6dcdecb378fde4a4c74d4dd16a089d07580c75a0e80"}, + {file = "websockets-13.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:516062a0a8ef5ecbfa4acbaec14b199fc070577834f9fe3d40800a99f92523ca"}, + {file = "websockets-13.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:da7e918d82e7bdfc6f66d31febe1b2e28a1ca3387315f918de26f5e367f61572"}, + {file = "websockets-13.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:9cc7f35dcb49a4e32db82a849fcc0714c4d4acc9d2273aded2d61f87d7f660b7"}, + {file = "websockets-13.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:f5737c53eb2c8ed8f64b50d3dafd3c1dae739f78aa495a288421ac1b3de82717"}, + {file = "websockets-13.0-cp39-cp39-win32.whl", hash = "sha256:265e1f0d3f788ce8ef99dca591a1aec5263b26083ca0934467ad9a1d1181067c"}, + {file = "websockets-13.0-cp39-cp39-win_amd64.whl", hash = "sha256:4d70c89e3d3b347a7c4d3c33f8d323f0584c9ceb69b82c2ef8a174ca84ea3d4a"}, + {file = "websockets-13.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:602cbd010d8c21c8475f1798b705bb18567eb189c533ab5ef568bc3033fdf417"}, + {file = "websockets-13.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:bf8eb5dca4f484a60f5327b044e842e0d7f7cdbf02ea6dc4a4f811259f1f1f0b"}, + {file = "websockets-13.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89d795c1802d99a643bf689b277e8604c14b5af1bc0a31dade2cd7a678087212"}, + {file = "websockets-13.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:788bc841d250beccff67a20a5a53a15657a60111ef9c0c0a97fbdd614fae0fe2"}, + {file = "websockets-13.0-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7334752052532c156d28b8eaf3558137e115c7871ea82adff69b6d94a7bee273"}, + {file = "websockets-13.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:e7a1963302947332c3039e3f66209ec73b1626f8a0191649e0713c391e9f5b0d"}, + {file = "websockets-13.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:2e1cf4e1eb84b4fd74a47688e8b0940c89a04ad9f6937afa43d468e71128cd68"}, + {file = "websockets-13.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:c026ee729c4ce55708a14b839ba35086dfae265fc12813b62d34ce33f4980c1c"}, + {file = "websockets-13.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f5f9d23fbbf96eefde836d9692670bfc89e2d159f456d499c5efcf6a6281c1af"}, + {file = "websockets-13.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ad684cb7efce227d756bae3e8484f2e56aa128398753b54245efdfbd1108f2c"}, + {file = "websockets-13.0-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1e10b3fbed7be4a59831d3a939900e50fcd34d93716e433d4193a4d0d1d335d"}, + {file = "websockets-13.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:d42a818e634f789350cd8fb413a3f5eec1cf0400a53d02062534c41519f5125c"}, + {file = "websockets-13.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:e5ba5e9b332267d0f2c33ede390061850f1ac3ee6cd1bdcf4c5ea33ead971966"}, + {file = "websockets-13.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:f9af457ed593e35f467140d8b61d425495b127744a9d65d45a366f8678449a23"}, + {file = "websockets-13.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bcea3eb58c09c3a31cc83b45c06d5907f02ddaf10920aaa6443975310f699b95"}, + {file = "websockets-13.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c210d1460dc8d326ffdef9703c2f83269b7539a1690ad11ae04162bc1878d33d"}, + {file = "websockets-13.0-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b32f38bc81170fd56d0482d505b556e52bf9078b36819a8ba52624bd6667e39e"}, + {file = "websockets-13.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:81a11a1ddd5320429db47c04d35119c3e674d215173d87aaeb06ae80f6e9031f"}, + {file = "websockets-13.0-py3-none-any.whl", hash = "sha256:dbbac01e80aee253d44c4f098ab3cc17c822518519e869b284cfbb8cd16cc9de"}, + {file = "websockets-13.0.tar.gz", hash = "sha256:b7bf950234a482b7461afdb2ec99eee3548ec4d53f418c7990bb79c620476602"}, ] [[package]] name = "werkzeug" -version = "3.0.3" +version = "3.0.4" description = "The comprehensive WSGI web application library." optional = false python-versions = ">=3.8" files = [ - {file = "werkzeug-3.0.3-py3-none-any.whl", hash = "sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8"}, - {file = "werkzeug-3.0.3.tar.gz", hash = "sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18"}, + {file = "werkzeug-3.0.4-py3-none-any.whl", hash = "sha256:02c9eb92b7d6c06f31a782811505d2157837cea66aaede3e217c7c27c039476c"}, + {file = "werkzeug-3.0.4.tar.gz", hash = "sha256:34f2371506b250df4d4f84bfe7b0921e4762525762bbd936614909fe25cd7306"}, ] [package.dependencies] @@ -9322,13 +9866,13 @@ requests = "*" [[package]] name = "zipp" -version = "3.19.2" +version = "3.20.0" description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.8" files = [ - {file = "zipp-3.19.2-py3-none-any.whl", hash = "sha256:f091755f667055f2d02b32c53771a7a6c8b47e1fdbc4b72a8b9072b3eef8015c"}, - {file = "zipp-3.19.2.tar.gz", hash = "sha256:bf1dcf6450f873a13e952a29504887c89e6de7506209e5b1bcc3460135d4de19"}, + {file = "zipp-3.20.0-py3-none-any.whl", hash = "sha256:58da6168be89f0be59beb194da1250516fdaa062ccebd30127ac65d30045e10d"}, + {file = "zipp-3.20.0.tar.gz", hash = "sha256:0145e43d89664cfe1a2e533adc75adafed82fe2da404b4bbb6b026c0157bdb31"}, ] [package.extras] @@ -9355,47 +9899,45 @@ test = ["zope.testrunner"] [[package]] name = "zope-interface" -version = "6.4.post2" +version = "7.0.1" description = "Interfaces for Python" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "zope.interface-6.4.post2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2eccd5bef45883802848f821d940367c1d0ad588de71e5cabe3813175444202c"}, - {file = "zope.interface-6.4.post2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:762e616199f6319bb98e7f4f27d254c84c5fb1c25c908c2a9d0f92b92fb27530"}, - {file = "zope.interface-6.4.post2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ef8356f16b1a83609f7a992a6e33d792bb5eff2370712c9eaae0d02e1924341"}, - {file = "zope.interface-6.4.post2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0e4fa5d34d7973e6b0efa46fe4405090f3b406f64b6290facbb19dcbf642ad6b"}, - {file = "zope.interface-6.4.post2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d22fce0b0f5715cdac082e35a9e735a1752dc8585f005d045abb1a7c20e197f9"}, - {file = "zope.interface-6.4.post2-cp310-cp310-win_amd64.whl", hash = "sha256:97e615eab34bd8477c3f34197a17ce08c648d38467489359cb9eb7394f1083f7"}, - {file = "zope.interface-6.4.post2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:599f3b07bde2627e163ce484d5497a54a0a8437779362395c6b25e68c6590ede"}, - {file = "zope.interface-6.4.post2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:136cacdde1a2c5e5bc3d0b2a1beed733f97e2dad8c2ad3c2e17116f6590a3827"}, - {file = "zope.interface-6.4.post2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:47937cf2e7ed4e0e37f7851c76edeb8543ec9b0eae149b36ecd26176ff1ca874"}, - {file = "zope.interface-6.4.post2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6f0a6be264afb094975b5ef55c911379d6989caa87c4e558814ec4f5125cfa2e"}, - {file = "zope.interface-6.4.post2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47654177e675bafdf4e4738ce58cdc5c6d6ee2157ac0a78a3fa460942b9d64a8"}, - {file = "zope.interface-6.4.post2-cp311-cp311-win_amd64.whl", hash = "sha256:e2fb8e8158306567a3a9a41670c1ff99d0567d7fc96fa93b7abf8b519a46b250"}, - {file = "zope.interface-6.4.post2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b912750b13d76af8aac45ddf4679535def304b2a48a07989ec736508d0bbfbde"}, - {file = "zope.interface-6.4.post2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4ac46298e0143d91e4644a27a769d1388d5d89e82ee0cf37bf2b0b001b9712a4"}, - {file = "zope.interface-6.4.post2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:86a94af4a88110ed4bb8961f5ac72edf782958e665d5bfceaab6bf388420a78b"}, - {file = "zope.interface-6.4.post2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:73f9752cf3596771c7726f7eea5b9e634ad47c6d863043589a1c3bb31325c7eb"}, - {file = "zope.interface-6.4.post2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:00b5c3e9744dcdc9e84c24ed6646d5cf0cf66551347b310b3ffd70f056535854"}, - {file = "zope.interface-6.4.post2-cp312-cp312-win_amd64.whl", hash = "sha256:551db2fe892fcbefb38f6f81ffa62de11090c8119fd4e66a60f3adff70751ec7"}, - {file = "zope.interface-6.4.post2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e96ac6b3169940a8cd57b4f2b8edcad8f5213b60efcd197d59fbe52f0accd66e"}, - {file = "zope.interface-6.4.post2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cebff2fe5dc82cb22122e4e1225e00a4a506b1a16fafa911142ee124febf2c9e"}, - {file = "zope.interface-6.4.post2-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33ee982237cffaf946db365c3a6ebaa37855d8e3ca5800f6f48890209c1cfefc"}, - {file = "zope.interface-6.4.post2-cp37-cp37m-macosx_11_0_x86_64.whl", hash = "sha256:fbf649bc77510ef2521cf797700b96167bb77838c40780da7ea3edd8b78044d1"}, - {file = "zope.interface-6.4.post2-cp37-cp37m-win_amd64.whl", hash = "sha256:4c0b208a5d6c81434bdfa0f06d9b667e5de15af84d8cae5723c3a33ba6611b82"}, - {file = "zope.interface-6.4.post2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d3fe667935e9562407c2511570dca14604a654988a13d8725667e95161d92e9b"}, - {file = "zope.interface-6.4.post2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a96e6d4074db29b152222c34d7eec2e2db2f92638d2b2b2c704f9e8db3ae0edc"}, - {file = "zope.interface-6.4.post2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:866a0f583be79f0def667a5d2c60b7b4cc68f0c0a470f227e1122691b443c934"}, - {file = "zope.interface-6.4.post2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5fe919027f29b12f7a2562ba0daf3e045cb388f844e022552a5674fcdf5d21f1"}, - {file = "zope.interface-6.4.post2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e0343a6e06d94f6b6ac52fbc75269b41dd3c57066541a6c76517f69fe67cb43"}, - {file = "zope.interface-6.4.post2-cp38-cp38-win_amd64.whl", hash = "sha256:dabb70a6e3d9c22df50e08dc55b14ca2a99da95a2d941954255ac76fd6982bc5"}, - {file = "zope.interface-6.4.post2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:706efc19f9679a1b425d6fa2b4bc770d976d0984335eaea0869bd32f627591d2"}, - {file = "zope.interface-6.4.post2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3d136e5b8821073e1a09dde3eb076ea9988e7010c54ffe4d39701adf0c303438"}, - {file = "zope.interface-6.4.post2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1730c93a38b5a18d24549bc81613223962a19d457cfda9bdc66e542f475a36f4"}, - {file = "zope.interface-6.4.post2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bc2676312cc3468a25aac001ec727168994ea3b69b48914944a44c6a0b251e79"}, - {file = "zope.interface-6.4.post2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a62fd6cd518693568e23e02f41816adedfca637f26716837681c90b36af3671"}, - {file = "zope.interface-6.4.post2-cp39-cp39-win_amd64.whl", hash = "sha256:d3f7e001328bd6466b3414215f66dde3c7c13d8025a9c160a75d7b2687090d15"}, - {file = "zope.interface-6.4.post2.tar.gz", hash = "sha256:1c207e6f6dfd5749a26f5a5fd966602d6b824ec00d2df84a7e9a924e8933654e"}, + {file = "zope.interface-7.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ec4e87e6fdc511a535254daa122c20e11959ce043b4e3425494b237692a34f1c"}, + {file = "zope.interface-7.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:51d5713e8e38f2d3ec26e0dfdca398ed0c20abda2eb49ffc15a15a23eb8e5f6d"}, + {file = "zope.interface-7.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea8d51e5eb29e57d34744369cd08267637aa5a0fefc9b5d33775ab7ff2ebf2e3"}, + {file = "zope.interface-7.0.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:55bbcc74dc0c7ab489c315c28b61d7a1d03cf938cc99cc58092eb065f120c3a5"}, + {file = "zope.interface-7.0.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10ebac566dd0cec66f942dc759d46a994a2b3ba7179420f0e2130f88f8a5f400"}, + {file = "zope.interface-7.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:7039e624bcb820f77cc2ff3d1adcce531932990eee16121077eb51d9c76b6c14"}, + {file = "zope.interface-7.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:03bd5c0db82237bbc47833a8b25f1cc090646e212f86b601903d79d7e6b37031"}, + {file = "zope.interface-7.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3f52050c6a10d4a039ec6f2c58e5b3ade5cc570d16cf9d102711e6b8413c90e6"}, + {file = "zope.interface-7.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af0b33f04677b57843d529b9257a475d2865403300b48c67654c40abac2f9f24"}, + {file = "zope.interface-7.0.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:696c2a381fc7876b3056711717dba5eddd07c2c9e5ccd50da54029a1293b6e43"}, + {file = "zope.interface-7.0.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f89a420cf5a6f2aa7849dd59e1ff0e477f562d97cf8d6a1ee03461e1eec39887"}, + {file = "zope.interface-7.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:b59deb0ddc7b431e41d720c00f99d68b52cb9bd1d5605a085dc18f502fe9c47f"}, + {file = "zope.interface-7.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:52f5253cca1b35eaeefa51abd366b87f48f8714097c99b131ba61f3fdbbb58e7"}, + {file = "zope.interface-7.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:88d108d004e0df25224de77ce349a7e73494ea2cb194031f7c9687e68a88ec9b"}, + {file = "zope.interface-7.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c203d82069ba31e1f3bc7ba530b2461ec86366cd4bfc9b95ec6ce58b1b559c34"}, + {file = "zope.interface-7.0.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3f3495462bc0438b76536a0e10d765b168ae636092082531b88340dc40dcd118"}, + {file = "zope.interface-7.0.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:192b7a792e3145ed880ff6b1a206fdb783697cfdb4915083bfca7065ec845e60"}, + {file = "zope.interface-7.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:400d06c9ec8dbcc96f56e79376297e7be07a315605c9a2208720da263d44d76f"}, + {file = "zope.interface-7.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c1dff87b30fd150c61367d0e2cdc49bb55f8b9fd2a303560bbc24b951573ae1"}, + {file = "zope.interface-7.0.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f749ca804648d00eda62fe1098f229b082dfca930d8bad8386e572a6eafa7525"}, + {file = "zope.interface-7.0.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4ec212037becf6d2f705b7ed4538d56980b1e7bba237df0d8995cbbed29961dc"}, + {file = "zope.interface-7.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d33cb526efdc235a2531433fc1287fcb80d807d5b401f9b801b78bf22df560dd"}, + {file = "zope.interface-7.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b419f2144e1762ab845f20316f1df36b15431f2622ebae8a6d5f7e8e712b413c"}, + {file = "zope.interface-7.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03f1452d5d1f279184d5bdb663a3dc39902d9320eceb63276240791e849054b6"}, + {file = "zope.interface-7.0.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ba4b3638d014918b918aa90a9c8370bd74a03abf8fcf9deb353b3a461a59a84"}, + {file = "zope.interface-7.0.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc0615351221926a36a0fbcb2520fb52e0b23e8c22a43754d9cb8f21358c33c0"}, + {file = "zope.interface-7.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:ce6cbb852fb8f2f9bb7b9cdca44e2e37bce783b5f4c167ff82cb5f5128163c8f"}, + {file = "zope.interface-7.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5566fd9271c89ad03d81b0831c37d46ae5e2ed211122c998637130159a120cf1"}, + {file = "zope.interface-7.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:da0cef4d7e3f19c3bd1d71658d6900321af0492fee36ec01b550a10924cffb9c"}, + {file = "zope.interface-7.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f32ca483e6ade23c7caaee9d5ee5d550cf4146e9b68d2fb6c68bac183aa41c37"}, + {file = "zope.interface-7.0.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:da21e7eec49252df34d426c2ee9cf0361c923026d37c24728b0fa4cc0599fd03"}, + {file = "zope.interface-7.0.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a8195b99e650e6f329ce4e5eb22d448bdfef0406404080812bc96e2a05674cb"}, + {file = "zope.interface-7.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:19c829d52e921b9fe0b2c0c6a8f9a2508c49678ee1be598f87d143335b6a35dc"}, + {file = "zope.interface-7.0.1.tar.gz", hash = "sha256:f0f5fda7cbf890371a59ab1d06512da4f2c89a6ea194e595808123c863c38eff"}, ] [package.dependencies] @@ -9521,4 +10063,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "6eb1649ed473ab7916683beb3a9a09c1fc97f99845ee77adb811ea95b93b32e4" +content-hash = "e4c00268514d26bd07c6b72925e0e3b4558ec972895d252e60e9571e3ac38895" diff --git a/api/pyproject.toml b/api/pyproject.toml index 6af9bad5b0198a..f1d5e213ae8f94 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -69,7 +69,11 @@ ignore = [ ] [tool.ruff.format] -quote-style = "single" +exclude = [ + "core/**/*.py", + "models/**/*.py", + "migrations/**/*", +] [tool.pytest_env] OPENAI_API_KEY = "sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii" @@ -93,6 +97,8 @@ CODE_MAX_STRING_LENGTH = "80000" CODE_EXECUTION_ENDPOINT = "http://127.0.0.1:8194" CODE_EXECUTION_API_KEY = "dify-sandbox" FIRECRAWL_API_KEY = "fc-" +TEI_EMBEDDING_SERVER_URL = "http://a.abc.com:11451" +TEI_RERANK_SERVER_URL = "http://a.abc.com:11451" [tool.poetry] name = "dify-api" @@ -108,7 +114,7 @@ authlib = "1.3.1" azure-identity = "1.16.1" azure-storage-blob = "12.13.0" beautifulsoup4 = "4.12.2" -boto3 = "1.34.136" +boto3 = "1.34.148" bs4 = "~0.0.1" cachetools = "~5.3.0" celery = "~5.3.6" @@ -142,17 +148,17 @@ langfuse = "^2.36.1" langsmith = "^0.1.77" mailchimp-transactional = "~1.0.50" markdown = "~3.5.1" -novita-client = "^0.5.6" +novita-client = "^0.5.7" numpy = "~1.26.4" openai = "~1.29.0" +openpyxl = "~3.1.5" oss2 = "2.18.5" pandas = { version = "~2.2.2", extras = ["performance", "excel"] } psycopg2-binary = "~2.9.6" pycryptodome = "3.19.1" pydantic = "~2.8.2" -pydantic-settings = "~2.3.4" +pydantic-settings = "~2.4.0" pydantic_extra_types = "~2.9.0" -pydub = "~0.25.1" pyjwt = "~2.8.0" pypdfium2 = "~4.17.0" python = ">=3.10,<3.13" @@ -163,7 +169,6 @@ readabilipy = "0.2.0" redis = { version = "~5.0.3", extras = ["hiredis"] } replicate = "~0.22.0" resend = "~0.7.0" -safetensors = "~0.4.3" scikit-learn = "^1.5.1" sentry-sdk = { version = "~1.44.1", extras = ["flask"] } sqlalchemy = "~2.0.29" @@ -177,8 +182,19 @@ werkzeug = "~3.0.1" xinference-client = "0.13.3" yarl = "~1.9.4" zhipuai = "1.0.7" +# Before adding new dependency, consider place it in alphabet order (a-z) and suitable group. + +############################################################ +# Related transparent dependencies with pinned verion +# required by main implementations +############################################################ +azure-ai-ml = "^1.19.0" +azure-ai-inference = "^1.0.0b3" +volcengine-python-sdk = {extras = ["ark"], version = "^1.0.98"} +[tool.poetry.group.indriect.dependencies] +kaleido = "0.2.1" rank-bm25 = "~0.2.2" -openpyxl = "^3.1.5" +safetensors = "~0.4.3" ############################################################ # Tool dependencies required by tool implementations @@ -186,9 +202,10 @@ openpyxl = "^3.1.5" [tool.poetry.group.tool.dependencies] arxiv = "2.1.0" +cloudscraper = "1.2.71" matplotlib = "~3.8.2" newspaper3k = "0.2.8" -duckduckgo-search = "^6.1.8" +duckduckgo-search = "^6.2.6" jsonpath-ng = "1.6.1" numexpr = "~2.9.0" opensearch-py = "2.4.0" @@ -197,26 +214,25 @@ twilio = "~9.0.4" vanna = { version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"] } wikipedia = "1.4.0" yfinance = "~0.2.40" -cloudscraper = "1.2.71" - +nltk = "3.8.1" ############################################################ # VDB dependencies required by vector store clients ############################################################ [tool.poetry.group.vdb.dependencies] +alibabacloud_gpdb20160503 = "~3.8.0" +alibabacloud_tea_openapi = "~0.3.9" chromadb = "0.5.1" +clickhouse-connect = "~0.7.16" +elasticsearch = "8.14.0" oracledb = "~2.2.1" pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] } pgvector = "0.2.5" pymilvus = "~2.4.4" -pymysql = "1.1.1" tcvectordb = "1.3.2" tidb-vector = "0.0.9" qdrant-client = "1.7.3" weaviate-client = "~3.21.0" -alibabacloud_gpdb20160503 = "~3.8.0" -alibabacloud_tea_openapi = "~0.3.9" -clickhouse-connect = "~0.7.16" ############################################################ # Dev dependencies for running tests @@ -227,7 +243,7 @@ optional = true [tool.poetry.group.dev.dependencies] coverage = "~7.2.4" -pytest = "~8.1.1" +pytest = "~8.3.2" pytest-benchmark = "~4.0.0" pytest-env = "~1.1.3" pytest-mock = "~3.14.0" @@ -240,5 +256,5 @@ pytest-mock = "~3.14.0" optional = true [tool.poetry.group.lint.dependencies] -ruff = "~0.5.1" dotenv-linter = "~0.5.0" +ruff = "~0.6.1" diff --git a/api/schedule/clean_embedding_cache_task.py b/api/schedule/clean_embedding_cache_task.py index ccc1062266a02f..67d070682867bb 100644 --- a/api/schedule/clean_embedding_cache_task.py +++ b/api/schedule/clean_embedding_cache_task.py @@ -11,27 +11,32 @@ from models.dataset import Embedding -@app.celery.task(queue='dataset') +@app.celery.task(queue="dataset") def clean_embedding_cache_task(): - click.echo(click.style('Start clean embedding cache.', fg='green')) + click.echo(click.style("Start clean embedding cache.", fg="green")) clean_days = int(dify_config.CLEAN_DAY_SETTING) start_at = time.perf_counter() thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days) while True: try: - embedding_ids = db.session.query(Embedding.id).filter(Embedding.created_at < thirty_days_ago) \ - .order_by(Embedding.created_at.desc()).limit(100).all() + embedding_ids = ( + db.session.query(Embedding.id) + .filter(Embedding.created_at < thirty_days_ago) + .order_by(Embedding.created_at.desc()) + .limit(100) + .all() + ) embedding_ids = [embedding_id[0] for embedding_id in embedding_ids] except NotFound: break if embedding_ids: for embedding_id in embedding_ids: - db.session.execute(text( - "DELETE FROM embeddings WHERE id = :embedding_id" - ), {'embedding_id': embedding_id}) + db.session.execute( + text("DELETE FROM embeddings WHERE id = :embedding_id"), {"embedding_id": embedding_id} + ) db.session.commit() else: break end_at = time.perf_counter() - click.echo(click.style('Cleaned embedding cache from db success latency: {}'.format(end_at - start_at), fg='green')) + click.echo(click.style("Cleaned embedding cache from db success latency: {}".format(end_at - start_at), fg="green")) diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index b2b2f82b786f5e..3d799bfd4ef732 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -12,9 +12,9 @@ from models.dataset import Dataset, DatasetQuery, Document -@app.celery.task(queue='dataset') +@app.celery.task(queue="dataset") def clean_unused_datasets_task(): - click.echo(click.style('Start clean unused datasets indexes.', fg='green')) + click.echo(click.style("Start clean unused datasets indexes.", fg="green")) clean_days = dify_config.CLEAN_DAY_SETTING start_at = time.perf_counter() thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days) @@ -22,40 +22,44 @@ def clean_unused_datasets_task(): while True: try: # Subquery for counting new documents - document_subquery_new = db.session.query( - Document.dataset_id, - func.count(Document.id).label('document_count') - ).filter( - Document.indexing_status == 'completed', - Document.enabled == True, - Document.archived == False, - Document.updated_at > thirty_days_ago - ).group_by(Document.dataset_id).subquery() + document_subquery_new = ( + db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) + .filter( + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + Document.updated_at > thirty_days_ago, + ) + .group_by(Document.dataset_id) + .subquery() + ) # Subquery for counting old documents - document_subquery_old = db.session.query( - Document.dataset_id, - func.count(Document.id).label('document_count') - ).filter( - Document.indexing_status == 'completed', - Document.enabled == True, - Document.archived == False, - Document.updated_at < thirty_days_ago - ).group_by(Document.dataset_id).subquery() + document_subquery_old = ( + db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) + .filter( + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + Document.updated_at < thirty_days_ago, + ) + .group_by(Document.dataset_id) + .subquery() + ) # Main query with join and filter - datasets = (db.session.query(Dataset) - .outerjoin( - document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id - ).outerjoin( - document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id - ).filter( - Dataset.created_at < thirty_days_ago, - func.coalesce(document_subquery_new.c.document_count, 0) == 0, - func.coalesce(document_subquery_old.c.document_count, 0) > 0 - ).order_by( - Dataset.created_at.desc() - ).paginate(page=page, per_page=50)) + datasets = ( + db.session.query(Dataset) + .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) + .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) + .filter( + Dataset.created_at < thirty_days_ago, + func.coalesce(document_subquery_new.c.document_count, 0) == 0, + func.coalesce(document_subquery_old.c.document_count, 0) > 0, + ) + .order_by(Dataset.created_at.desc()) + .paginate(page=page, per_page=50) + ) except NotFound: break @@ -63,10 +67,11 @@ def clean_unused_datasets_task(): break page += 1 for dataset in datasets: - dataset_query = db.session.query(DatasetQuery).filter( - DatasetQuery.created_at > thirty_days_ago, - DatasetQuery.dataset_id == dataset.id - ).all() + dataset_query = ( + db.session.query(DatasetQuery) + .filter(DatasetQuery.created_at > thirty_days_ago, DatasetQuery.dataset_id == dataset.id) + .all() + ) if not dataset_query or len(dataset_query) == 0: try: # remove index @@ -74,17 +79,14 @@ def clean_unused_datasets_task(): index_processor.clean(dataset, None) # update document - update_params = { - Document.enabled: False - } + update_params = {Document.enabled: False} Document.query.filter_by(dataset_id=dataset.id).update(update_params) db.session.commit() - click.echo(click.style('Cleaned unused dataset {} from db success!'.format(dataset.id), - fg='green')) + click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green")) except Exception as e: click.echo( - click.style('clean dataset index error: {} {}'.format(e.__class__.__name__, str(e)), - fg='red')) + click.style("clean dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red") + ) end_at = time.perf_counter() - click.echo(click.style('Cleaned unused dataset from db success latency: {}'.format(end_at - start_at), fg='green')) + click.echo(click.style("Cleaned unused dataset from db success latency: {}".format(end_at - start_at), fg="green")) diff --git a/api/services/__init__.py b/api/services/__init__.py index 6891436314b299..5163862cc12781 100644 --- a/api/services/__init__.py +++ b/api/services/__init__.py @@ -1,3 +1,3 @@ from . import errors -__all__ = ['errors'] +__all__ = ["errors"] diff --git a/api/services/account_service.py b/api/services/account_service.py index d73cec2697f975..e1b70fc9ed93c7 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -39,12 +39,7 @@ class AccountService: - - reset_password_rate_limiter = RateLimiter( - prefix="reset_password_rate_limit", - max_attempts=5, - time_window=60 * 60 - ) + reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=5, time_window=60 * 60) @staticmethod def load_user(user_id: str) -> None | Account: @@ -55,12 +50,15 @@ def load_user(user_id: str) -> None | Account: if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]: raise Unauthorized("Account is banned or closed.") - current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() + current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by( + account_id=account.id, current=True + ).first() if current_tenant: account.current_tenant_id = current_tenant.tenant_id else: - available_ta = TenantAccountJoin.query.filter_by(account_id=account.id) \ - .order_by(TenantAccountJoin.id.asc()).first() + available_ta = ( + TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() + ) if not available_ta: return None @@ -74,14 +72,13 @@ def load_user(user_id: str) -> None | Account: return account - @staticmethod def get_account_jwt_token(account, *, exp: timedelta = timedelta(days=30)): payload = { "user_id": account.id, "exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp, "iss": dify_config.EDITION, - "sub": 'Console API Passport', + "sub": "Console API Passport", } token = PassportService().issue(payload) @@ -93,10 +90,10 @@ def authenticate(email: str, password: str) -> Account: account = Account.query.filter_by(email=email).first() if not account: - raise AccountLoginError('Invalid email or password.') + raise AccountLoginError("Invalid email or password.") if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: - raise AccountLoginError('Account is banned or closed.') + raise AccountLoginError("Account is banned or closed.") if account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value @@ -104,7 +101,7 @@ def authenticate(email: str, password: str) -> Account: db.session.commit() if account.password is None or not compare_password(password, account.password, account.password_salt): - raise AccountLoginError('Invalid email or password.') + raise AccountLoginError("Invalid email or password.") return account @staticmethod @@ -129,11 +126,9 @@ def update_account_password(account, password, new_password): return account @staticmethod - def create_account(email: str, - name: str, - interface_language: str, - password: Optional[str] = None, - interface_theme: str = 'light') -> Account: + def create_account( + email: str, name: str, interface_language: str, password: Optional[str] = None, interface_theme: str = "light" + ) -> Account: """create account""" account = Account() account.email = email @@ -155,7 +150,7 @@ def create_account(email: str, account.interface_theme = interface_theme # Set timezone based on language - account.timezone = language_timezone_mapping.get(interface_language, 'UTC') + account.timezone = language_timezone_mapping.get(interface_language, "UTC") db.session.add(account) db.session.commit() @@ -166,8 +161,9 @@ def link_account_integrate(provider: str, open_id: str, account: Account) -> Non """Link account integrate""" try: # Query whether there is an existing binding record for the same provider - account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(account_id=account.id, - provider=provider).first() + account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by( + account_id=account.id, provider=provider + ).first() if account_integrate: # If it exists, update the record @@ -176,15 +172,16 @@ def link_account_integrate(provider: str, open_id: str, account: Account) -> Non account_integrate.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) else: # If it does not exist, create a new record - account_integrate = AccountIntegrate(account_id=account.id, provider=provider, open_id=open_id, - encrypted_token="") + account_integrate = AccountIntegrate( + account_id=account.id, provider=provider, open_id=open_id, encrypted_token="" + ) db.session.add(account_integrate) db.session.commit() - logging.info(f'Account {account.id} linked {provider} account {open_id}.') + logging.info(f"Account {account.id} linked {provider} account {open_id}.") except Exception as e: - logging.exception(f'Failed to link {provider} account {open_id} to Account {account.id}') - raise LinkAccountIntegrateError('Failed to link account.') from e + logging.exception(f"Failed to link {provider} account {open_id} to Account {account.id}") + raise LinkAccountIntegrateError("Failed to link account.") from e @staticmethod def close_account(account: Account) -> None: @@ -218,7 +215,7 @@ def login(account: Account, *, ip_address: Optional[str] = None): AccountService.update_last_login(account, ip_address=ip_address) exp = timedelta(days=30) token = AccountService.get_account_jwt_token(account, exp=exp) - redis_client.set(_get_login_cache_key(account_id=account.id, token=token), '1', ex=int(exp.total_seconds())) + redis_client.set(_get_login_cache_key(account_id=account.id, token=token), "1", ex=int(exp.total_seconds())) return token @staticmethod @@ -236,22 +233,18 @@ def send_reset_password_email(cls, account): if cls.reset_password_rate_limiter.is_rate_limited(account.email): raise RateLimitExceededError(f"Rate limit exceeded for email: {account.email}. Please try again later.") - token = TokenManager.generate_token(account, 'reset_password') - send_reset_password_mail_task.delay( - language=account.interface_language, - to=account.email, - token=token - ) + token = TokenManager.generate_token(account, "reset_password") + send_reset_password_mail_task.delay(language=account.interface_language, to=account.email, token=token) cls.reset_password_rate_limiter.increment_rate_limit(account.email) return token @classmethod def revoke_reset_password_token(cls, token: str): - TokenManager.revoke_token(token, 'reset_password') + TokenManager.revoke_token(token, "reset_password") @classmethod def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]: - return TokenManager.get_token_data(token, 'reset_password') + return TokenManager.get_token_data(token, "reset_password") def _get_login_cache_key(*, account_id: str, token: str): @@ -259,7 +252,6 @@ def _get_login_cache_key(*, account_id: str, token: str): class TenantService: - @staticmethod def create_tenant(name: str) -> Tenant: """Create tenant""" @@ -273,33 +265,33 @@ def create_tenant(name: str) -> Tenant: return tenant @staticmethod - def create_owner_tenant_if_not_exist(account: Account): + def create_owner_tenant_if_not_exist(account: Account, name: Optional[str] = None): """Create owner tenant if not exist""" - available_ta = TenantAccountJoin.query.filter_by(account_id=account.id) \ - .order_by(TenantAccountJoin.id.asc()).first() + available_ta = ( + TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() + ) if available_ta: return - tenant = TenantService.create_tenant(f"{account.name}'s Workspace") - TenantService.create_tenant_member(tenant, account, role='owner') + if name: + tenant = TenantService.create_tenant(name) + else: + tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant db.session.commit() tenant_was_created.send(tenant) @staticmethod - def create_tenant_member(tenant: Tenant, account: Account, role: str = 'normal') -> TenantAccountJoin: + def create_tenant_member(tenant: Tenant, account: Account, role: str = "normal") -> TenantAccountJoin: """Create tenant member""" if role == TenantAccountJoinRole.OWNER.value: if TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER]): - logging.error(f'Tenant {tenant.id} has already an owner.') - raise Exception('Tenant already has an owner.') + logging.error(f"Tenant {tenant.id} has already an owner.") + raise Exception("Tenant already has an owner.") - ta = TenantAccountJoin( - tenant_id=tenant.id, - account_id=account.id, - role=role - ) + ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role) db.session.add(ta) db.session.commit() return ta @@ -307,9 +299,12 @@ def create_tenant_member(tenant: Tenant, account: Account, role: str = 'normal') @staticmethod def get_join_tenants(account: Account) -> list[Tenant]: """Get account join tenants""" - return db.session.query(Tenant).join( - TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id - ).filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL).all() + return ( + db.session.query(Tenant) + .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id) + .filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL) + .all() + ) @staticmethod def get_current_tenant_by_account(account: Account): @@ -333,16 +328,23 @@ def switch_tenant(account: Account, tenant_id: int = None) -> None: if tenant_id is None: raise ValueError("Tenant ID must be provided.") - tenant_account_join = db.session.query(TenantAccountJoin).join(Tenant, TenantAccountJoin.tenant_id == Tenant.id).filter( - TenantAccountJoin.account_id == account.id, - TenantAccountJoin.tenant_id == tenant_id, - Tenant.status == TenantStatus.NORMAL, - ).first() + tenant_account_join = ( + db.session.query(TenantAccountJoin) + .join(Tenant, TenantAccountJoin.tenant_id == Tenant.id) + .filter( + TenantAccountJoin.account_id == account.id, + TenantAccountJoin.tenant_id == tenant_id, + Tenant.status == TenantStatus.NORMAL, + ) + .first() + ) if not tenant_account_join: raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") else: - TenantAccountJoin.query.filter(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id).update({'current': False}) + TenantAccountJoin.query.filter( + TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id + ).update({"current": False}) tenant_account_join.current = True # Set the current tenant for the account account.current_tenant_id = tenant_account_join.tenant_id @@ -354,9 +356,7 @@ def get_tenant_members(tenant: Tenant) -> list[Account]: query = ( db.session.query(Account, TenantAccountJoin.role) .select_from(Account) - .join( - TenantAccountJoin, Account.id == TenantAccountJoin.account_id - ) + .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) .filter(TenantAccountJoin.tenant_id == tenant.id) ) @@ -375,11 +375,9 @@ def get_dataset_operator_members(tenant: Tenant) -> list[Account]: query = ( db.session.query(Account, TenantAccountJoin.role) .select_from(Account) - .join( - TenantAccountJoin, Account.id == TenantAccountJoin.account_id - ) + .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) .filter(TenantAccountJoin.tenant_id == tenant.id) - .filter(TenantAccountJoin.role == 'dataset_operator') + .filter(TenantAccountJoin.role == "dataset_operator") ) # Initialize an empty list to store the updated accounts @@ -395,20 +393,25 @@ def get_dataset_operator_members(tenant: Tenant) -> list[Account]: def has_roles(tenant: Tenant, roles: list[TenantAccountJoinRole]) -> bool: """Check if user has any of the given roles for a tenant""" if not all(isinstance(role, TenantAccountJoinRole) for role in roles): - raise ValueError('all roles must be TenantAccountJoinRole') + raise ValueError("all roles must be TenantAccountJoinRole") - return db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.tenant_id == tenant.id, - TenantAccountJoin.role.in_([role.value for role in roles]) - ).first() is not None + return ( + db.session.query(TenantAccountJoin) + .filter( + TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles]) + ) + .first() + is not None + ) @staticmethod def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountJoinRole]: """Get the role of the current account for a given tenant""" - join = db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.tenant_id == tenant.id, - TenantAccountJoin.account_id == account.id - ).first() + join = ( + db.session.query(TenantAccountJoin) + .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) + .first() + ) return join.role if join else None @staticmethod @@ -420,29 +423,26 @@ def get_tenant_count() -> int: def check_member_permission(tenant: Tenant, operator: Account, member: Account, action: str) -> None: """Check member permission""" perms = { - 'add': [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], - 'remove': [TenantAccountRole.OWNER], - 'update': [TenantAccountRole.OWNER] + "add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], + "remove": [TenantAccountRole.OWNER], + "update": [TenantAccountRole.OWNER], } - if action not in ['add', 'remove', 'update']: + if action not in ["add", "remove", "update"]: raise InvalidActionError("Invalid action.") if member: if operator.id == member.id: raise CannotOperateSelfError("Cannot operate self.") - ta_operator = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, - account_id=operator.id - ).first() + ta_operator = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=operator.id).first() if not ta_operator or ta_operator.role not in perms[action]: - raise NoPermissionError(f'No permission to {action} member.') + raise NoPermissionError(f"No permission to {action} member.") @staticmethod def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None: """Remove member from tenant""" - if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, 'remove'): + if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, "remove"): raise CannotOperateSelfError("Cannot operate self.") ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() @@ -455,23 +455,17 @@ def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Accoun @staticmethod def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None: """Update member role""" - TenantService.check_member_permission(tenant, operator, member, 'update') + TenantService.check_member_permission(tenant, operator, member, "update") - target_member_join = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, - account_id=member.id - ).first() + target_member_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=member.id).first() if target_member_join.role == new_role: raise RoleAlreadyAssignedError("The provided role is already assigned to the member.") - if new_role == 'owner': + if new_role == "owner": # Find the current owner and change their role to 'admin' - current_owner_join = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, - role='owner' - ).first() - current_owner_join.role = 'admin' + current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first() + current_owner_join.role = "admin" # Update the role of the target member target_member_join.role = new_role @@ -480,8 +474,8 @@ def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: @staticmethod def dissolve_tenant(tenant: Tenant, operator: Account) -> None: """Dissolve tenant""" - if not TenantService.check_member_permission(tenant, operator, operator, 'remove'): - raise NoPermissionError('No permission to dissolve tenant.') + if not TenantService.check_member_permission(tenant, operator, operator, "remove"): + raise NoPermissionError("No permission to dissolve tenant.") db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete() db.session.delete(tenant) db.session.commit() @@ -494,10 +488,9 @@ def get_custom_config(tenant_id: str) -> None: class RegisterService: - @classmethod def _get_invitation_token_key(cls, token: str) -> str: - return f'member_invite:token:{token}' + return f"member_invite:token:{token}" @classmethod def setup(cls, email: str, name: str, password: str, ip_address: str) -> None: @@ -523,9 +516,7 @@ def setup(cls, email: str, name: str, password: str, ip_address: str) -> None: TenantService.create_owner_tenant_if_not_exist(account) - dify_setup = DifySetup( - version=dify_config.CURRENT_VERSION - ) + dify_setup = DifySetup(version=dify_config.CURRENT_VERSION) db.session.add(dify_setup) db.session.commit() except Exception as e: @@ -535,34 +526,35 @@ def setup(cls, email: str, name: str, password: str, ip_address: str) -> None: db.session.query(Tenant).delete() db.session.commit() - logging.exception(f'Setup failed: {e}') - raise ValueError(f'Setup failed: {e}') + logging.exception(f"Setup failed: {e}") + raise ValueError(f"Setup failed: {e}") @classmethod - def register(cls, email, name, - password: Optional[str] = None, - open_id: Optional[str] = None, - provider: Optional[str] = None, - language: Optional[str] = None, - status: Optional[AccountStatus] = None) -> Account: + def register( + cls, + email, + name, + password: Optional[str] = None, + open_id: Optional[str] = None, + provider: Optional[str] = None, + language: Optional[str] = None, + status: Optional[AccountStatus] = None, + ) -> Account: db.session.begin_nested() """Register account""" try: account = AccountService.create_account( - email=email, - name=name, - interface_language=language if language else languages[0], - password=password + email=email, name=name, interface_language=language if language else languages[0], password=password ) account.status = AccountStatus.ACTIVE.value if not status else status.value account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) if open_id is not None or provider is not None: AccountService.link_account_integrate(provider, open_id, account) - if dify_config.EDITION != 'SELF_HOSTED': + if dify_config.EDITION != "SELF_HOSTED": tenant = TenantService.create_tenant(f"{account.name}'s Workspace") - TenantService.create_tenant_member(tenant, account, role='owner') + TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant tenant_was_created.send(tenant) @@ -570,30 +562,29 @@ def register(cls, email, name, db.session.commit() except Exception as e: db.session.rollback() - logging.error(f'Register failed: {e}') - raise AccountRegisterError(f'Registration failed: {e}') from e + logging.error(f"Register failed: {e}") + raise AccountRegisterError(f"Registration failed: {e}") from e return account @classmethod - def invite_new_member(cls, tenant: Tenant, email: str, language: str, role: str = 'normal', inviter: Account = None) -> str: + def invite_new_member( + cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account = None + ) -> str: """Invite new member""" account = Account.query.filter_by(email=email).first() if not account: - TenantService.check_member_permission(tenant, inviter, None, 'add') - name = email.split('@')[0] + TenantService.check_member_permission(tenant, inviter, None, "add") + name = email.split("@")[0] account = cls.register(email=email, name=name, language=language, status=AccountStatus.PENDING) # Create new tenant member for invited tenant TenantService.create_tenant_member(tenant, account, role) TenantService.switch_tenant(account, tenant.id) else: - TenantService.check_member_permission(tenant, inviter, account, 'add') - ta = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, - account_id=account.id - ).first() + TenantService.check_member_permission(tenant, inviter, account, "add") + ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() if not ta: TenantService.create_tenant_member(tenant, account, role) @@ -609,7 +600,7 @@ def invite_new_member(cls, tenant: Tenant, email: str, language: str, role: str language=account.interface_language, to=email, token=token, - inviter_name=inviter.name if inviter else 'Dify', + inviter_name=inviter.name if inviter else "Dify", workspace_name=tenant.name, ) @@ -619,23 +610,19 @@ def invite_new_member(cls, tenant: Tenant, email: str, language: str, role: str def generate_invite_token(cls, tenant: Tenant, account: Account) -> str: token = str(uuid.uuid4()) invitation_data = { - 'account_id': account.id, - 'email': account.email, - 'workspace_id': tenant.id, + "account_id": account.id, + "email": account.email, + "workspace_id": tenant.id, } expiryHours = dify_config.INVITE_EXPIRY_HOURS - redis_client.setex( - cls._get_invitation_token_key(token), - expiryHours * 60 * 60, - json.dumps(invitation_data) - ) + redis_client.setex(cls._get_invitation_token_key(token), expiryHours * 60 * 60, json.dumps(invitation_data)) return token @classmethod def revoke_token(cls, workspace_id: str, email: str, token: str): if workspace_id and email: email_hash = sha256(email.encode()).hexdigest() - cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token) + cache_key = "member_invite_token:{}, {}:{}".format(workspace_id, email_hash, token) redis_client.delete(cache_key) else: redis_client.delete(cls._get_invitation_token_key(token)) @@ -646,17 +633,21 @@ def get_invitation_if_token_valid(cls, workspace_id: str, email: str, token: str if not invitation_data: return None - tenant = db.session.query(Tenant).filter( - Tenant.id == invitation_data['workspace_id'], - Tenant.status == 'normal' - ).first() + tenant = ( + db.session.query(Tenant) + .filter(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal") + .first() + ) if not tenant: return None - tenant_account = db.session.query(Account, TenantAccountJoin.role).join( - TenantAccountJoin, Account.id == TenantAccountJoin.account_id - ).filter(Account.email == invitation_data['email'], TenantAccountJoin.tenant_id == tenant.id).first() + tenant_account = ( + db.session.query(Account, TenantAccountJoin.role) + .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) + .filter(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id) + .first() + ) if not tenant_account: return None @@ -665,29 +656,29 @@ def get_invitation_if_token_valid(cls, workspace_id: str, email: str, token: str if not account: return None - if invitation_data['account_id'] != str(account.id): + if invitation_data["account_id"] != str(account.id): return None return { - 'account': account, - 'data': invitation_data, - 'tenant': tenant, + "account": account, + "data": invitation_data, + "tenant": tenant, } @classmethod def _get_invitation_by_token(cls, token: str, workspace_id: str, email: str) -> Optional[dict[str, str]]: if workspace_id is not None and email is not None: email_hash = sha256(email.encode()).hexdigest() - cache_key = f'member_invite_token:{workspace_id}, {email_hash}:{token}' + cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}" account_id = redis_client.get(cache_key) if not account_id: return None return { - 'account_id': account_id.decode('utf-8'), - 'email': email, - 'workspace_id': workspace_id, + "account_id": account_id.decode("utf-8"), + "email": email, + "workspace_id": workspace_id, } else: data = redis_client.get(cls._get_invitation_token_key(token)) diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index 213df262223d8a..d2cd7bea67c5b6 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -1,4 +1,3 @@ - import copy from core.prompt.prompt_templates.advanced_prompt_templates import ( @@ -17,59 +16,78 @@ class AdvancedPromptTemplateService: - @classmethod def get_prompt(cls, args: dict) -> dict: - app_mode = args['app_mode'] - model_mode = args['model_mode'] - model_name = args['model_name'] - has_context = args['has_context'] + app_mode = args["app_mode"] + model_mode = args["model_mode"] + model_name = args["model_name"] + has_context = args["has_context"] - if 'baichuan' in model_name.lower(): + if "baichuan" in model_name.lower(): return cls.get_baichuan_prompt(app_mode, model_mode, has_context) else: return cls.get_common_prompt(app_mode, model_mode, has_context) @classmethod - def get_common_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict: + def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict: context_prompt = copy.deepcopy(CONTEXT) if app_mode == AppMode.CHAT.value: if model_mode == "completion": - return cls.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt) + return cls.get_completion_prompt( + copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt + ) elif model_mode == "chat": return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt) elif app_mode == AppMode.COMPLETION.value: if model_mode == "completion": - return cls.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt) + return cls.get_completion_prompt( + copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt + ) elif model_mode == "chat": - return cls.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt) - + return cls.get_chat_prompt( + copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt + ) + @classmethod def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict: - if has_context == 'true': - prompt_template['completion_prompt_config']['prompt']['text'] = context + prompt_template['completion_prompt_config']['prompt']['text'] - + if has_context == "true": + prompt_template["completion_prompt_config"]["prompt"]["text"] = ( + context + prompt_template["completion_prompt_config"]["prompt"]["text"] + ) + return prompt_template @classmethod def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict: - if has_context == 'true': - prompt_template['chat_prompt_config']['prompt'][0]['text'] = context + prompt_template['chat_prompt_config']['prompt'][0]['text'] - + if has_context == "true": + prompt_template["chat_prompt_config"]["prompt"][0]["text"] = ( + context + prompt_template["chat_prompt_config"]["prompt"][0]["text"] + ) + return prompt_template @classmethod - def get_baichuan_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict: + def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict: baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT) if app_mode == AppMode.CHAT.value: if model_mode == "completion": - return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt) + return cls.get_completion_prompt( + copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt + ) elif model_mode == "chat": - return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt) + return cls.get_chat_prompt( + copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt + ) elif app_mode == AppMode.COMPLETION.value: if model_mode == "completion": - return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt) + return cls.get_completion_prompt( + copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), + has_context, + baichuan_context_prompt, + ) elif model_mode == "chat": - return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt) \ No newline at end of file + return cls.get_chat_prompt( + copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt + ) diff --git a/api/services/agent_service.py b/api/services/agent_service.py index ba5fd93326a96f..887fb878b92db1 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -10,59 +10,65 @@ class AgentService: @classmethod - def get_agent_logs(cls, app_model: App, - conversation_id: str, - message_id: str) -> dict: + def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) -> dict: """ Service to get agent logs """ - conversation: Conversation = db.session.query(Conversation).filter( - Conversation.id == conversation_id, - Conversation.app_id == app_model.id, - ).first() + conversation: Conversation = ( + db.session.query(Conversation) + .filter( + Conversation.id == conversation_id, + Conversation.app_id == app_model.id, + ) + .first() + ) if not conversation: raise ValueError(f"Conversation not found: {conversation_id}") - message: Message = db.session.query(Message).filter( - Message.id == message_id, - Message.conversation_id == conversation_id, - ).first() + message: Message = ( + db.session.query(Message) + .filter( + Message.id == message_id, + Message.conversation_id == conversation_id, + ) + .first() + ) if not message: raise ValueError(f"Message not found: {message_id}") - + agent_thoughts: list[MessageAgentThought] = message.agent_thoughts if conversation.from_end_user_id: # only select name field - executor = db.session.query(EndUser, EndUser.name).filter( - EndUser.id == conversation.from_end_user_id - ).first() + executor = ( + db.session.query(EndUser, EndUser.name).filter(EndUser.id == conversation.from_end_user_id).first() + ) else: - executor = db.session.query(Account, Account.name).filter( - Account.id == conversation.from_account_id - ).first() - + executor = ( + db.session.query(Account, Account.name).filter(Account.id == conversation.from_account_id).first() + ) + if executor: executor = executor.name else: - executor = 'Unknown' + executor = "Unknown" timezone = pytz.timezone(current_user.timezone) result = { - 'meta': { - 'status': 'success', - 'executor': executor, - 'start_time': message.created_at.astimezone(timezone).isoformat(), - 'elapsed_time': message.provider_response_latency, - 'total_tokens': message.answer_tokens + message.message_tokens, - 'agent_mode': app_model.app_model_config.agent_mode_dict.get('strategy', 'react'), - 'iterations': len(agent_thoughts), + "meta": { + "status": "success", + "executor": executor, + "start_time": message.created_at.astimezone(timezone).isoformat(), + "elapsed_time": message.provider_response_latency, + "total_tokens": message.answer_tokens + message.message_tokens, + "agent_mode": app_model.app_model_config.agent_mode_dict.get("strategy", "react"), + "iterations": len(agent_thoughts), }, - 'iterations': [], - 'files': message.files, + "iterations": [], + "files": message.files, } agent_config = AgentConfigManager.convert(app_model.app_model_config.to_dict()) @@ -86,12 +92,12 @@ def find_agent_tool(tool_name: str): tool_input = tool_inputs.get(tool_name, {}) tool_output = tool_outputs.get(tool_name, {}) tool_meta_data = tool_meta.get(tool_name, {}) - tool_config = tool_meta_data.get('tool_config', {}) - if tool_config.get('tool_provider_type', '') != 'dataset-retrieval': + tool_config = tool_meta_data.get("tool_config", {}) + if tool_config.get("tool_provider_type", "") != "dataset-retrieval": tool_icon = ToolManager.get_tool_icon( tenant_id=app_model.tenant_id, - provider_type=tool_config.get('tool_provider_type', ''), - provider_id=tool_config.get('tool_provider', ''), + provider_type=tool_config.get("tool_provider_type", ""), + provider_id=tool_config.get("tool_provider", ""), ) if not tool_icon: tool_entity = find_agent_tool(tool_name) @@ -102,30 +108,34 @@ def find_agent_tool(tool_name: str): provider_id=tool_entity.provider_id, ) else: - tool_icon = '' - - tool_calls.append({ - 'status': 'success' if not tool_meta_data.get('error') else 'error', - 'error': tool_meta_data.get('error'), - 'time_cost': tool_meta_data.get('time_cost', 0), - 'tool_name': tool_name, - 'tool_label': tool_label, - 'tool_input': tool_input, - 'tool_output': tool_output, - 'tool_parameters': tool_meta_data.get('tool_parameters', {}), - 'tool_icon': tool_icon, - }) - - result['iterations'].append({ - 'tokens': agent_thought.tokens, - 'tool_calls': tool_calls, - 'tool_raw': { - 'inputs': agent_thought.tool_input, - 'outputs': agent_thought.observation, - }, - 'thought': agent_thought.thought, - 'created_at': agent_thought.created_at.isoformat(), - 'files': agent_thought.files, - }) - - return result \ No newline at end of file + tool_icon = "" + + tool_calls.append( + { + "status": "success" if not tool_meta_data.get("error") else "error", + "error": tool_meta_data.get("error"), + "time_cost": tool_meta_data.get("time_cost", 0), + "tool_name": tool_name, + "tool_label": tool_label, + "tool_input": tool_input, + "tool_output": tool_output, + "tool_parameters": tool_meta_data.get("tool_parameters", {}), + "tool_icon": tool_icon, + } + ) + + result["iterations"].append( + { + "tokens": agent_thought.tokens, + "tool_calls": tool_calls, + "tool_raw": { + "inputs": agent_thought.tool_input, + "outputs": agent_thought.observation, + }, + "thought": agent_thought.thought, + "created_at": agent_thought.created_at.isoformat(), + "files": agent_thought.files, + } + ) + + return result diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index addcde44ed9a4a..3cc6c51c2d323f 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -23,21 +23,18 @@ class AppAnnotationService: @classmethod def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation: # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") - if args.get('message_id'): - message_id = str(args['message_id']) + if args.get("message_id"): + message_id = str(args["message_id"]) # get message info - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app.id - ).first() + message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app.id).first() if not message: raise NotFound("Message Not Exists.") @@ -45,159 +42,166 @@ def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> Messa annotation = message.annotation # save the message annotation if annotation: - annotation.content = args['answer'] - annotation.question = args['question'] + annotation.content = args["answer"] + annotation.question = args["question"] else: annotation = MessageAnnotation( app_id=app.id, conversation_id=message.conversation_id, message_id=message.id, - content=args['answer'], - question=args['question'], - account_id=current_user.id + content=args["answer"], + question=args["question"], + account_id=current_user.id, ) else: annotation = MessageAnnotation( - app_id=app.id, - content=args['answer'], - question=args['question'], - account_id=current_user.id + app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id ) db.session.add(annotation) db.session.commit() # if annotation reply is enabled , add annotation to index - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if annotation_setting: - add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id, - app_id, annotation_setting.collection_binding_id) + add_annotation_to_index_task.delay( + annotation.id, + args["question"], + current_user.current_tenant_id, + app_id, + annotation_setting.collection_binding_id, + ) return annotation @classmethod def enable_app_annotation(cls, args: dict, app_id: str) -> dict: - enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id)) + enable_app_annotation_key = "enable_app_annotation_{}".format(str(app_id)) cache_result = redis_client.get(enable_app_annotation_key) if cache_result is not None: - return { - 'job_id': cache_result, - 'job_status': 'processing' - } + return {"job_id": cache_result, "job_status": "processing"} # async job job_id = str(uuid.uuid4()) - enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id)) + enable_app_annotation_job_key = "enable_app_annotation_job_{}".format(str(job_id)) # send batch add segments task - redis_client.setnx(enable_app_annotation_job_key, 'waiting') - enable_annotation_reply_task.delay(str(job_id), app_id, current_user.id, current_user.current_tenant_id, - args['score_threshold'], - args['embedding_provider_name'], args['embedding_model_name']) - return { - 'job_id': job_id, - 'job_status': 'waiting' - } + redis_client.setnx(enable_app_annotation_job_key, "waiting") + enable_annotation_reply_task.delay( + str(job_id), + app_id, + current_user.id, + current_user.current_tenant_id, + args["score_threshold"], + args["embedding_provider_name"], + args["embedding_model_name"], + ) + return {"job_id": job_id, "job_status": "waiting"} @classmethod def disable_app_annotation(cls, app_id: str) -> dict: - disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id)) + disable_app_annotation_key = "disable_app_annotation_{}".format(str(app_id)) cache_result = redis_client.get(disable_app_annotation_key) if cache_result is not None: - return { - 'job_id': cache_result, - 'job_status': 'processing' - } + return {"job_id": cache_result, "job_status": "processing"} # async job job_id = str(uuid.uuid4()) - disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id)) + disable_app_annotation_job_key = "disable_app_annotation_job_{}".format(str(job_id)) # send batch add segments task - redis_client.setnx(disable_app_annotation_job_key, 'waiting') + redis_client.setnx(disable_app_annotation_job_key, "waiting") disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id) - return { - 'job_id': job_id, - 'job_status': 'waiting' - } + return {"job_id": job_id, "job_status": "waiting"} @classmethod def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") if keyword: - annotations = (db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) - .filter( - or_( - MessageAnnotation.question.ilike('%{}%'.format(keyword)), - MessageAnnotation.content.ilike('%{}%'.format(keyword)) + annotations = ( + db.session.query(MessageAnnotation) + .filter(MessageAnnotation.app_id == app_id) + .filter( + or_( + MessageAnnotation.question.ilike("%{}%".format(keyword)), + MessageAnnotation.content.ilike("%{}%".format(keyword)), + ) ) + .order_by(MessageAnnotation.created_at.desc()) + .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) ) - .order_by(MessageAnnotation.created_at.desc()) - .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)) else: - annotations = (db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) - .order_by(MessageAnnotation.created_at.desc()) - .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)) + annotations = ( + db.session.query(MessageAnnotation) + .filter(MessageAnnotation.app_id == app_id) + .order_by(MessageAnnotation.created_at.desc()) + .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + ) return annotations.items, annotations.total @classmethod def export_annotation_list_by_app_id(cls, app_id: str): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") - annotations = (db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) - .order_by(MessageAnnotation.created_at.desc()).all()) + annotations = ( + db.session.query(MessageAnnotation) + .filter(MessageAnnotation.app_id == app_id) + .order_by(MessageAnnotation.created_at.desc()) + .all() + ) return annotations @classmethod def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation: # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") annotation = MessageAnnotation( - app_id=app.id, - content=args['answer'], - question=args['question'], - account_id=current_user.id + app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id ) db.session.add(annotation) db.session.commit() # if annotation reply is enabled , add annotation to index - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if annotation_setting: - add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id, - app_id, annotation_setting.collection_binding_id) + add_annotation_to_index_task.delay( + annotation.id, + args["question"], + current_user.current_tenant_id, + app_id, + annotation_setting.collection_binding_id, + ) return annotation @classmethod def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") @@ -207,30 +211,34 @@ def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: if not annotation: raise NotFound("Annotation not found") - annotation.content = args['answer'] - annotation.question = args['question'] + annotation.content = args["answer"] + annotation.question = args["question"] db.session.commit() # if annotation reply is enabled , add annotation to index - app_annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id - ).first() + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if app_annotation_setting: - update_annotation_to_index_task.delay(annotation.id, annotation.question, - current_user.current_tenant_id, - app_id, app_annotation_setting.collection_binding_id) + update_annotation_to_index_task.delay( + annotation.id, + annotation.question, + current_user.current_tenant_id, + app_id, + app_annotation_setting.collection_binding_id, + ) return annotation @classmethod def delete_app_annotation(cls, app_id: str, annotation_id: str): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") @@ -242,33 +250,34 @@ def delete_app_annotation(cls, app_id: str, annotation_id: str): db.session.delete(annotation) - annotation_hit_histories = (db.session.query(AppAnnotationHitHistory) - .filter(AppAnnotationHitHistory.annotation_id == annotation_id) - .all() - ) + annotation_hit_histories = ( + db.session.query(AppAnnotationHitHistory) + .filter(AppAnnotationHitHistory.annotation_id == annotation_id) + .all() + ) if annotation_hit_histories: for annotation_hit_history in annotation_hit_histories: db.session.delete(annotation_hit_history) db.session.commit() # if annotation reply is enabled , delete annotation index - app_annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id - ).first() + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if app_annotation_setting: - delete_annotation_index_task.delay(annotation.id, app_id, - current_user.current_tenant_id, - app_annotation_setting.collection_binding_id) + delete_annotation_index_task.delay( + annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id + ) @classmethod def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict: # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") @@ -278,10 +287,7 @@ def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict: df = pd.read_csv(file) result = [] for index, row in df.iterrows(): - content = { - 'question': row[0], - 'answer': row[1] - } + content = {"question": row[0], "answer": row[1]} result.append(content) if len(result) == 0: raise ValueError("The CSV file is empty.") @@ -293,28 +299,24 @@ def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict: raise ValueError("The number of annotations exceeds the limit of your subscription.") # async job job_id = str(uuid.uuid4()) - indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id)) + indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) # send batch add segments task - redis_client.setnx(indexing_cache_key, 'waiting') - batch_import_annotations_task.delay(str(job_id), result, app_id, - current_user.current_tenant_id, current_user.id) + redis_client.setnx(indexing_cache_key, "waiting") + batch_import_annotations_task.delay( + str(job_id), result, app_id, current_user.current_tenant_id, current_user.id + ) except Exception as e: - return { - 'error_msg': str(e) - } - return { - 'job_id': job_id, - 'job_status': 'waiting' - } + return {"error_msg": str(e)} + return {"job_id": job_id, "job_status": "waiting"} @classmethod def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") @@ -324,12 +326,15 @@ def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, lim if not annotation: raise NotFound("Annotation not found") - annotation_hit_histories = (db.session.query(AppAnnotationHitHistory) - .filter(AppAnnotationHitHistory.app_id == app_id, - AppAnnotationHitHistory.annotation_id == annotation_id, - ) - .order_by(AppAnnotationHitHistory.created_at.desc()) - .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)) + annotation_hit_histories = ( + db.session.query(AppAnnotationHitHistory) + .filter( + AppAnnotationHitHistory.app_id == app_id, + AppAnnotationHitHistory.annotation_id == annotation_id, + ) + .order_by(AppAnnotationHitHistory.created_at.desc()) + .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + ) return annotation_hit_histories.items, annotation_hit_histories.total @classmethod @@ -341,15 +346,21 @@ def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None: return annotation @classmethod - def add_annotation_history(cls, annotation_id: str, app_id: str, annotation_question: str, - annotation_content: str, query: str, user_id: str, - message_id: str, from_source: str, score: float): + def add_annotation_history( + cls, + annotation_id: str, + app_id: str, + annotation_question: str, + annotation_content: str, + query: str, + user_id: str, + message_id: str, + from_source: str, + score: float, + ): # add hit count to annotation - db.session.query(MessageAnnotation).filter( - MessageAnnotation.id == annotation_id - ).update( - {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, - synchronize_session=False + db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).update( + {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, synchronize_session=False ) annotation_hit_history = AppAnnotationHitHistory( @@ -361,7 +372,7 @@ def add_annotation_history(cls, annotation_id: str, app_id: str, annotation_ques score=score, message_id=message_id, annotation_question=annotation_question, - annotation_content=annotation_content + annotation_content=annotation_content, ) db.session.add(annotation_hit_history) db.session.commit() @@ -369,17 +380,18 @@ def add_annotation_history(cls, annotation_id: str, app_id: str, annotation_ques @classmethod def get_app_annotation_setting_by_app_id(cls, app_id: str): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if annotation_setting: collection_binding_detail = annotation_setting.collection_binding_detail return { @@ -388,32 +400,34 @@ def get_app_annotation_setting_by_app_id(cls, app_id: str): "score_threshold": annotation_setting.score_threshold, "embedding_model": { "embedding_provider_name": collection_binding_detail.provider_name, - "embedding_model_name": collection_binding_detail.model_name - } + "embedding_model_name": collection_binding_detail.model_name, + }, } - return { - "enabled": False - } + return {"enabled": False} @classmethod def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id, - AppAnnotationSetting.id == annotation_setting_id, - ).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting) + .filter( + AppAnnotationSetting.app_id == app_id, + AppAnnotationSetting.id == annotation_setting_id, + ) + .first() + ) if not annotation_setting: raise NotFound("App annotation not found") - annotation_setting.score_threshold = args['score_threshold'] + annotation_setting.score_threshold = args["score_threshold"] annotation_setting.updated_user_id = current_user.id annotation_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(annotation_setting) @@ -427,6 +441,6 @@ def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, "score_threshold": annotation_setting.score_threshold, "embedding_model": { "embedding_provider_name": collection_binding_detail.provider_name, - "embedding_model_name": collection_binding_detail.model_name - } + "embedding_model_name": collection_binding_detail.model_name, + }, } diff --git a/api/services/api_based_extension_service.py b/api/services/api_based_extension_service.py index 8441bbedb344bc..601d67d2fba4e3 100644 --- a/api/services/api_based_extension_service.py +++ b/api/services/api_based_extension_service.py @@ -5,13 +5,14 @@ class APIBasedExtensionService: - @staticmethod def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]: - extension_list = db.session.query(APIBasedExtension) \ - .filter_by(tenant_id=tenant_id) \ - .order_by(APIBasedExtension.created_at.desc()) \ - .all() + extension_list = ( + db.session.query(APIBasedExtension) + .filter_by(tenant_id=tenant_id) + .order_by(APIBasedExtension.created_at.desc()) + .all() + ) for extension in extension_list: extension.api_key = decrypt_token(extension.tenant_id, extension.api_key) @@ -35,10 +36,12 @@ def delete(extension_data: APIBasedExtension) -> None: @staticmethod def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: - extension = db.session.query(APIBasedExtension) \ - .filter_by(tenant_id=tenant_id) \ - .filter_by(id=api_based_extension_id) \ + extension = ( + db.session.query(APIBasedExtension) + .filter_by(tenant_id=tenant_id) + .filter_by(id=api_based_extension_id) .first() + ) if not extension: raise ValueError("API based extension is not found") @@ -55,20 +58,24 @@ def _validation(cls, extension_data: APIBasedExtension) -> None: if not extension_data.id: # case one: check new data, name must be unique - is_name_existed = db.session.query(APIBasedExtension) \ - .filter_by(tenant_id=extension_data.tenant_id) \ - .filter_by(name=extension_data.name) \ + is_name_existed = ( + db.session.query(APIBasedExtension) + .filter_by(tenant_id=extension_data.tenant_id) + .filter_by(name=extension_data.name) .first() + ) if is_name_existed: raise ValueError("name must be unique, it is already existed") else: # case two: check existing data, name must be unique - is_name_existed = db.session.query(APIBasedExtension) \ - .filter_by(tenant_id=extension_data.tenant_id) \ - .filter_by(name=extension_data.name) \ - .filter(APIBasedExtension.id != extension_data.id) \ + is_name_existed = ( + db.session.query(APIBasedExtension) + .filter_by(tenant_id=extension_data.tenant_id) + .filter_by(name=extension_data.name) + .filter(APIBasedExtension.id != extension_data.id) .first() + ) if is_name_existed: raise ValueError("name must be unique, it is already existed") @@ -92,7 +99,7 @@ def _ping_connection(extension_data: APIBasedExtension) -> None: try: client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key) resp = client.request(point=APIBasedExtensionPoint.PING, params={}) - if resp.get('result') != 'pong': + if resp.get("result") != "pong": raise ValueError(resp) except Exception as e: raise ValueError("connection error: {}".format(e)) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 3764166333255d..a2aa15ed4bd8b2 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -13,9 +13,9 @@ logger = logging.getLogger(__name__) -current_dsl_version = "0.1.0" +current_dsl_version = "0.1.1" dsl_to_dify_version_mapping: dict[str, str] = { - "0.1.0": "0.6.0", # dsl version -> from dify version + "0.1.1": "0.6.0", # dsl version -> from dify version } @@ -75,40 +75,44 @@ def import_and_create_new_app(cls, tenant_id: str, data: str, args: dict, accoun # check or repair dsl version import_data = cls._check_or_fix_dsl(import_data) - app_data = import_data.get('app') + app_data = import_data.get("app") if not app_data: raise ValueError("Missing app in data argument") # get app basic info - name = args.get("name") if args.get("name") else app_data.get('name') - description = args.get("description") if args.get("description") else app_data.get('description', '') - icon = args.get("icon") if args.get("icon") else app_data.get('icon') - icon_background = args.get("icon_background") if args.get("icon_background") \ - else app_data.get('icon_background') + name = args.get("name") if args.get("name") else app_data.get("name") + description = args.get("description") if args.get("description") else app_data.get("description", "") + icon_type = args.get("icon_type") if args.get("icon_type") else app_data.get("icon_type") + icon = args.get("icon") if args.get("icon") else app_data.get("icon") + icon_background = ( + args.get("icon_background") if args.get("icon_background") else app_data.get("icon_background") + ) # import dsl and create app - app_mode = AppMode.value_of(app_data.get('mode')) + app_mode = AppMode.value_of(app_data.get("mode")) if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: app = cls._import_and_create_new_workflow_based_app( tenant_id=tenant_id, app_mode=app_mode, - workflow_data=import_data.get('workflow'), + workflow_data=import_data.get("workflow"), account=account, name=name, description=description, + icon_type=icon_type, icon=icon, - icon_background=icon_background + icon_background=icon_background, ) elif app_mode in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]: app = cls._import_and_create_new_model_config_based_app( tenant_id=tenant_id, app_mode=app_mode, - model_config_data=import_data.get('model_config'), + model_config_data=import_data.get("model_config"), account=account, name=name, description=description, + icon_type=icon_type, icon=icon, - icon_background=icon_background + icon_background=icon_background, ) else: raise ValueError("Invalid app mode") @@ -131,27 +135,26 @@ def import_and_overwrite_workflow(cls, app_model: App, data: str, account: Accou # check or repair dsl version import_data = cls._check_or_fix_dsl(import_data) - app_data = import_data.get('app') + app_data = import_data.get("app") if not app_data: raise ValueError("Missing app in data argument") # import dsl and overwrite app - app_mode = AppMode.value_of(app_data.get('mode')) + app_mode = AppMode.value_of(app_data.get("mode")) if app_mode not in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: raise ValueError("Only support import workflow in advanced-chat or workflow app.") - if app_data.get('mode') != app_model.mode: - raise ValueError( - f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}") + if app_data.get("mode") != app_model.mode: + raise ValueError(f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}") return cls._import_and_overwrite_workflow_based_app( app_model=app_model, - workflow_data=import_data.get('workflow'), + workflow_data=import_data.get("workflow"), account=account, ) @classmethod - def export_dsl(cls, app_model: App, include_secret:bool = False) -> str: + def export_dsl(cls, app_model: App, include_secret: bool = False) -> str: """ Export app :param app_model: App instance @@ -165,18 +168,20 @@ def export_dsl(cls, app_model: App, include_secret:bool = False) -> str: "app": { "name": app_model.name, "mode": app_model.mode, - "icon": app_model.icon, - "icon_background": app_model.icon_background, - "description": app_model.description - } + "icon": "🤖" if app_model.icon_type == "image" else app_model.icon, + "icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background, + "description": app_model.description, + }, } if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: - cls._append_workflow_export_data(export_data=export_data, app_model=app_model, include_secret=include_secret) + cls._append_workflow_export_data( + export_data=export_data, app_model=app_model, include_secret=include_secret + ) else: cls._append_model_config_export_data(export_data, app_model) - return yaml.dump(export_data) + return yaml.dump(export_data, allow_unicode=True) @classmethod def _check_or_fix_dsl(cls, import_data: dict) -> dict: @@ -185,30 +190,35 @@ def _check_or_fix_dsl(cls, import_data: dict) -> dict: :param import_data: import data """ - if not import_data.get('version'): - import_data['version'] = "0.1.0" + if not import_data.get("version"): + import_data["version"] = "0.1.0" - if not import_data.get('kind') or import_data.get('kind') != "app": - import_data['kind'] = "app" + if not import_data.get("kind") or import_data.get("kind") != "app": + import_data["kind"] = "app" - if import_data.get('version') != current_dsl_version: + if import_data.get("version") != current_dsl_version: # Currently only one DSL version, so no difference checks or compatibility fixes will be performed. - logger.warning(f"DSL version {import_data.get('version')} is not compatible " - f"with current version {current_dsl_version}, related to " - f"Dify version {dsl_to_dify_version_mapping.get(current_dsl_version)}.") + logger.warning( + f"DSL version {import_data.get('version')} is not compatible " + f"with current version {current_dsl_version}, related to " + f"Dify version {dsl_to_dify_version_mapping.get(current_dsl_version)}." + ) return import_data @classmethod - def _import_and_create_new_workflow_based_app(cls, - tenant_id: str, - app_mode: AppMode, - workflow_data: dict, - account: Account, - name: str, - description: str, - icon: str, - icon_background: str) -> App: + def _import_and_create_new_workflow_based_app( + cls, + tenant_id: str, + app_mode: AppMode, + workflow_data: dict, + account: Account, + name: str, + description: str, + icon_type: str, + icon: str, + icon_background: str, + ) -> App: """ Import app dsl and create new workflow based app @@ -218,12 +228,12 @@ def _import_and_create_new_workflow_based_app(cls, :param account: Account instance :param name: app name :param description: app description + :param icon_type: app icon type, "emoji" or "image" :param icon: app icon :param icon_background: app icon background """ if not workflow_data: - raise ValueError("Missing workflow in data argument " - "when app mode is advanced-chat or workflow") + raise ValueError("Missing workflow in data argument " "when app mode is advanced-chat or workflow") app = cls._create_app( tenant_id=tenant_id, @@ -231,35 +241,34 @@ def _import_and_create_new_workflow_based_app(cls, account=account, name=name, description=description, + icon_type=icon_type, icon=icon, - icon_background=icon_background + icon_background=icon_background, ) # init draft workflow - environment_variables_list = workflow_data.get('environment_variables') or [] + environment_variables_list = workflow_data.get("environment_variables") or [] environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] + conversation_variables_list = workflow_data.get("conversation_variables") or [] + conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] workflow_service = WorkflowService() draft_workflow = workflow_service.sync_draft_workflow( app_model=app, - graph=workflow_data.get('graph', {}), - features=workflow_data.get('../core/app/features', {}), + graph=workflow_data.get("graph", {}), + features=workflow_data.get("../core/app/features", {}), unique_hash=None, account=account, environment_variables=environment_variables, + conversation_variables=conversation_variables, ) - workflow_service.publish_workflow( - app_model=app, - account=account, - draft_workflow=draft_workflow - ) + workflow_service.publish_workflow(app_model=app, account=account, draft_workflow=draft_workflow) return app @classmethod - def _import_and_overwrite_workflow_based_app(cls, - app_model: App, - workflow_data: dict, - account: Account) -> Workflow: + def _import_and_overwrite_workflow_based_app( + cls, app_model: App, workflow_data: dict, account: Account + ) -> Workflow: """ Import app dsl and overwrite workflow based app @@ -268,8 +277,7 @@ def _import_and_overwrite_workflow_based_app(cls, :param account: Account instance """ if not workflow_data: - raise ValueError("Missing workflow in data argument " - "when app mode is advanced-chat or workflow") + raise ValueError("Missing workflow in data argument " "when app mode is advanced-chat or workflow") # fetch draft workflow by app_model workflow_service = WorkflowService() @@ -280,29 +288,35 @@ def _import_and_overwrite_workflow_based_app(cls, unique_hash = None # sync draft workflow - environment_variables_list = workflow_data.get('environment_variables') or [] + environment_variables_list = workflow_data.get("environment_variables") or [] environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] + conversation_variables_list = workflow_data.get("conversation_variables") or [] + conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] draft_workflow = workflow_service.sync_draft_workflow( app_model=app_model, - graph=workflow_data.get('graph', {}), - features=workflow_data.get('features', {}), + graph=workflow_data.get("graph", {}), + features=workflow_data.get("features", {}), unique_hash=unique_hash, account=account, environment_variables=environment_variables, + conversation_variables=conversation_variables, ) return draft_workflow @classmethod - def _import_and_create_new_model_config_based_app(cls, - tenant_id: str, - app_mode: AppMode, - model_config_data: dict, - account: Account, - name: str, - description: str, - icon: str, - icon_background: str) -> App: + def _import_and_create_new_model_config_based_app( + cls, + tenant_id: str, + app_mode: AppMode, + model_config_data: dict, + account: Account, + name: str, + description: str, + icon_type: str, + icon: str, + icon_background: str, + ) -> App: """ Import app dsl and create new model config based app @@ -316,8 +330,7 @@ def _import_and_create_new_model_config_based_app(cls, :param icon_background: app icon background """ if not model_config_data: - raise ValueError("Missing model_config in data argument " - "when app mode is chat, agent-chat or completion") + raise ValueError("Missing model_config in data argument " "when app mode is chat, agent-chat or completion") app = cls._create_app( tenant_id=tenant_id, @@ -325,35 +338,38 @@ def _import_and_create_new_model_config_based_app(cls, account=account, name=name, description=description, + icon_type=icon_type, icon=icon, - icon_background=icon_background + icon_background=icon_background, ) app_model_config = AppModelConfig() app_model_config = app_model_config.from_model_config_dict(model_config_data) app_model_config.app_id = app.id + app_model_config.created_by = account.id + app_model_config.updated_by = account.id db.session.add(app_model_config) db.session.commit() app.app_model_config_id = app_model_config.id - app_model_config_was_updated.send( - app, - app_model_config=app_model_config - ) + app_model_config_was_updated.send(app, app_model_config=app_model_config) return app @classmethod - def _create_app(cls, - tenant_id: str, - app_mode: AppMode, - account: Account, - name: str, - description: str, - icon: str, - icon_background: str) -> App: + def _create_app( + cls, + tenant_id: str, + app_mode: AppMode, + account: Account, + name: str, + description: str, + icon_type: str, + icon: str, + icon_background: str, + ) -> App: """ Create new app @@ -362,6 +378,7 @@ def _create_app(cls, :param account: Account instance :param name: app name :param description: app description + :param icon_type: app icon type, "emoji" or "image" :param icon: app icon :param icon_background: app icon background """ @@ -370,10 +387,13 @@ def _create_app(cls, mode=app_mode.value, name=name, description=description, + icon_type=icon_type, icon=icon, icon_background=icon_background, enable_site=True, - enable_api=True + enable_api=True, + created_by=account.id, + updated_by=account.id, ) db.session.add(app) @@ -395,7 +415,7 @@ def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, incl if not workflow: raise ValueError("Missing draft workflow configuration, please check.") - export_data['workflow'] = workflow.to_dict(include_secret=include_secret) + export_data["workflow"] = workflow.to_dict(include_secret=include_secret) @classmethod def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None: @@ -408,4 +428,4 @@ def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> if not app_model_config: raise ValueError("Missing app configuration, please check.") - export_data['model_config'] = app_model_config.to_dict() + export_data["model_config"] = app_model_config.to_dict() diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index cff4ba8af9dd46..747505977fe282 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -1,6 +1,8 @@ from collections.abc import Generator from typing import Any, Union +from openai._exceptions import RateLimitError + from configs import dify_config from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator @@ -10,18 +12,20 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.app.features.rate_limiting import RateLimit from models.model import Account, App, AppMode, EndUser +from services.errors.llm import InvokeRateLimitError from services.workflow_service import WorkflowService class AppGenerateService: - @classmethod - def generate(cls, app_model: App, - user: Union[Account, EndUser], - args: Any, - invoke_from: InvokeFrom, - streaming: bool = True, - ): + def generate( + cls, + app_model: App, + user: Union[Account, EndUser], + args: Any, + invoke_from: InvokeFrom, + streaming: bool = True, + ): """ App Content Generate :param app_model: app model @@ -37,51 +41,56 @@ def generate(cls, app_model: App, try: request_id = rate_limit.enter(request_id) if app_model.mode == AppMode.COMPLETION.value: - return rate_limit.generate(CompletionAppGenerator().generate( - app_model=app_model, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + CompletionAppGenerator().generate( + app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming + ), + request_id, + ) elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: - return rate_limit.generate(AgentChatAppGenerator().generate( - app_model=app_model, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + AgentChatAppGenerator().generate( + app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming + ), + request_id, + ) elif app_model.mode == AppMode.CHAT.value: - return rate_limit.generate(ChatAppGenerator().generate( - app_model=app_model, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + ChatAppGenerator().generate( + app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming + ), + request_id, + ) elif app_model.mode == AppMode.ADVANCED_CHAT.value: workflow = cls._get_workflow(app_model, invoke_from) - return rate_limit.generate(AdvancedChatAppGenerator().generate( - app_model=app_model, - workflow=workflow, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + AdvancedChatAppGenerator().generate( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + stream=streaming, + ), + request_id, + ) elif app_model.mode == AppMode.WORKFLOW.value: workflow = cls._get_workflow(app_model, invoke_from) - return rate_limit.generate(WorkflowAppGenerator().generate( - app_model=app_model, - workflow=workflow, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + WorkflowAppGenerator().generate( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + stream=streaming, + ), + request_id, + ) else: - raise ValueError(f'Invalid app mode {app_model.mode}') + raise ValueError(f"Invalid app mode {app_model.mode}") + except RateLimitError as e: + raise InvokeRateLimitError(str(e)) finally: if not streaming: rate_limit.exit(request_id) @@ -94,38 +103,31 @@ def _get_max_active_requests(app_model: App) -> int: return max_active_requests @classmethod - def generate_single_iteration(cls, app_model: App, - user: Union[Account, EndUser], - node_id: str, - args: Any, - streaming: bool = True): + def generate_single_iteration( + cls, app_model: App, user: Union[Account, EndUser], node_id: str, args: Any, streaming: bool = True + ): if app_model.mode == AppMode.ADVANCED_CHAT.value: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return AdvancedChatAppGenerator().single_iteration_generate( - app_model=app_model, - workflow=workflow, - node_id=node_id, - user=user, - args=args, - stream=streaming + app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming ) elif app_model.mode == AppMode.WORKFLOW.value: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return WorkflowAppGenerator().single_iteration_generate( - app_model=app_model, - workflow=workflow, - node_id=node_id, - user=user, - args=args, - stream=streaming + app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming ) else: - raise ValueError(f'Invalid app mode {app_model.mode}') + raise ValueError(f"Invalid app mode {app_model.mode}") @classmethod - def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], - message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \ - -> Union[dict, Generator]: + def generate_more_like_this( + cls, + app_model: App, + user: Union[Account, EndUser], + message_id: str, + invoke_from: InvokeFrom, + streaming: bool = True, + ) -> Union[dict, Generator]: """ Generate more like this :param app_model: app model @@ -136,11 +138,7 @@ def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], :return: """ return CompletionAppGenerator().generate_more_like_this( - app_model=app_model, - message_id=message_id, - user=user, - invoke_from=invoke_from, - stream=streaming + app_model=app_model, message_id=message_id, user=user, invoke_from=invoke_from, stream=streaming ) @classmethod @@ -157,12 +155,12 @@ def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Any: workflow = workflow_service.get_draft_workflow(app_model=app_model) if not workflow: - raise ValueError('Workflow not initialized') + raise ValueError("Workflow not initialized") else: # fetch published workflow by app_model workflow = workflow_service.get_published_workflow(app_model=app_model) if not workflow: - raise ValueError('Workflow not published') + raise ValueError("Workflow not published") return workflow diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index c84f6fbf454daf..a1ad2710534a84 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -5,7 +5,6 @@ class AppModelConfigService: - @classmethod def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict: if app_mode == AppMode.CHAT: diff --git a/api/services/app_service.py b/api/services/app_service.py index e433bb59bbe994..462613fb7d17ba 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -33,27 +33,22 @@ def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination | None: :param args: request args :return: """ - filters = [ - App.tenant_id == tenant_id, - App.is_universal == False - ] + filters = [App.tenant_id == tenant_id, App.is_universal == False] - if args['mode'] == 'workflow': + if args["mode"] == "workflow": filters.append(App.mode.in_([AppMode.WORKFLOW.value, AppMode.COMPLETION.value])) - elif args['mode'] == 'chat': + elif args["mode"] == "chat": filters.append(App.mode.in_([AppMode.CHAT.value, AppMode.ADVANCED_CHAT.value])) - elif args['mode'] == 'agent-chat': + elif args["mode"] == "agent-chat": filters.append(App.mode == AppMode.AGENT_CHAT.value) - elif args['mode'] == 'channel': + elif args["mode"] == "channel": filters.append(App.mode == AppMode.CHANNEL.value) - if args.get('name'): - name = args['name'][:30] - filters.append(App.name.ilike(f'%{name}%')) - if args.get('tag_ids'): - target_ids = TagService.get_target_ids_by_tag_ids('app', - tenant_id, - args['tag_ids']) + if args.get("name"): + name = args["name"][:30] + filters.append(App.name.ilike(f"%{name}%")) + if args.get("tag_ids"): + target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"]) if target_ids: filters.append(App.id.in_(target_ids)) else: @@ -61,9 +56,9 @@ def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination | None: app_models = db.paginate( db.select(App).where(*filters).order_by(App.created_at.desc()), - page=args['page'], - per_page=args['limit'], - error_out=False + page=args["page"], + per_page=args["limit"], + error_out=False, ) return app_models @@ -75,21 +70,20 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App: :param args: request args :param account: Account instance """ - app_mode = AppMode.value_of(args['mode']) + app_mode = AppMode.value_of(args["mode"]) app_template = default_app_templates[app_mode] # get model config - default_model_config = app_template.get('model_config') + default_model_config = app_template.get("model_config") default_model_config = default_model_config.copy() if default_model_config else None - if default_model_config and 'model' in default_model_config: + if default_model_config and "model" in default_model_config: # get model provider model_manager = ModelManager() # get default model instance try: model_instance = model_manager.get_default_model_instance( - tenant_id=account.current_tenant_id, - model_type=ModelType.LLM + tenant_id=account.current_tenant_id, model_type=ModelType.LLM ) except (ProviderTokenNotInitError, LLMBadRequestError): model_instance = None @@ -98,32 +92,43 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App: model_instance = None if model_instance: - if model_instance.model == default_model_config['model']['name'] and model_instance.provider == default_model_config['model']['provider']: - default_model_dict = default_model_config['model'] + if ( + model_instance.model == default_model_config["model"]["name"] + and model_instance.provider == default_model_config["model"]["provider"] + ): + default_model_dict = default_model_config["model"] else: llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) default_model_dict = { - 'provider': model_instance.provider, - 'name': model_instance.model, - 'mode': model_schema.model_properties.get(ModelPropertyKey.MODE), - 'completion_params': {} + "provider": model_instance.provider, + "name": model_instance.model, + "mode": model_schema.model_properties.get(ModelPropertyKey.MODE), + "completion_params": {}, } else: - default_model_dict = default_model_config['model'] - - default_model_config['model'] = json.dumps(default_model_dict) - - app = App(**app_template['app']) - app.name = args['name'] - app.description = args.get('description', '') - app.mode = args['mode'] - app.icon = args['icon'] - app.icon_background = args['icon_background'] + provider, model = model_manager.get_default_provider_model_name( + tenant_id=account.current_tenant_id, model_type=ModelType.LLM + ) + default_model_config["model"]["provider"] = provider + default_model_config["model"]["name"] = model + default_model_dict = default_model_config["model"] + + default_model_config["model"] = json.dumps(default_model_dict) + + app = App(**app_template["app"]) + app.name = args["name"] + app.description = args.get("description", "") + app.mode = args["mode"] + app.icon_type = args.get("icon_type", "emoji") + app.icon = args["icon"] + app.icon_background = args["icon_background"] app.tenant_id = tenant_id - app.api_rph = args.get('api_rph', 0) - app.api_rpm = args.get('api_rpm', 0) + app.api_rph = args.get("api_rph", 0) + app.api_rpm = args.get("api_rpm", 0) + app.created_by = account.id + app.updated_by = account.id db.session.add(app) db.session.flush() @@ -131,6 +136,8 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App: if default_model_config: app_model_config = AppModelConfig(**default_model_config) app_model_config.app_id = app.id + app_model_config.created_by = account.id + app_model_config.updated_by = account.id db.session.add(app_model_config) db.session.flush() @@ -151,7 +158,7 @@ def get_app(self, app: App) -> App: model_config: AppModelConfig = app.app_model_config agent_mode = model_config.agent_mode_dict # decrypt agent tool parameters if it's secret-input - for tool in agent_mode.get('tools') or []: + for tool in agent_mode.get("tools") or []: if not isinstance(tool, dict) or len(tool.keys()) <= 3: continue agent_tool_entity = AgentToolEntity(**tool) @@ -167,7 +174,7 @@ def get_app(self, app: App) -> App: tool_runtime=tool_runtime, provider_name=agent_tool_entity.provider_id, provider_type=agent_tool_entity.provider_type, - identity_id=f'AGENT.{app.id}' + identity_id=f"AGENT.{app.id}", ) # get decrypted parameters @@ -178,7 +185,7 @@ def get_app(self, app: App) -> App: masked_parameter = {} # override tool parameters - tool['tool_parameters'] = masked_parameter + tool["tool_parameters"] = masked_parameter except Exception as e: pass @@ -189,13 +196,14 @@ class ModifiedApp(App): """ Modified App class """ + def __init__(self, app): self.__dict__.update(app.__dict__) @property def app_model_config(self): return model_config - + app = ModifiedApp(app) return app @@ -207,11 +215,13 @@ def update_app(self, app: App, args: dict) -> App: :param args: request args :return: App instance """ - app.name = args.get('name') - app.description = args.get('description', '') - app.max_active_requests = args.get('max_active_requests') - app.icon = args.get('icon') - app.icon_background = args.get('icon_background') + app.name = args.get("name") + app.description = args.get("description", "") + app.max_active_requests = args.get("max_active_requests") + app.icon_type = args.get("icon_type", "emoji") + app.icon = args.get("icon") + app.icon_background = args.get("icon_background") + app.updated_by = current_user.id app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() @@ -228,6 +238,7 @@ def update_app_name(self, app: App, name: str) -> App: :return: App instance """ app.name = name + app.updated_by = current_user.id app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() @@ -243,6 +254,7 @@ def update_app_icon(self, app: App, icon: str, icon_background: str) -> App: """ app.icon = icon app.icon_background = icon_background + app.updated_by = current_user.id app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() @@ -259,6 +271,7 @@ def update_app_site_status(self, app: App, enable_site: bool) -> App: return app app.enable_site = enable_site + app.updated_by = current_user.id app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() @@ -275,6 +288,7 @@ def update_app_api_status(self, app: App, enable_api: bool) -> App: return app app.enable_api = enable_api + app.updated_by = current_user.id app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() @@ -289,10 +303,7 @@ def delete_app(self, app: App) -> None: db.session.commit() # Trigger asynchronous deletion of app and related data - remove_app_and_related_data_task.delay( - tenant_id=app.tenant_id, - app_id=app.id - ) + remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id) def get_app_meta(self, app_model: App) -> dict: """ @@ -302,9 +313,7 @@ def get_app_meta(self, app_model: App) -> dict: """ app_mode = AppMode.value_of(app_model.mode) - meta = { - 'tool_icons': {} - } + meta = {"tool_icons": {}} if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: workflow = app_model.workflow @@ -312,17 +321,19 @@ def get_app_meta(self, app_model: App) -> dict: return meta graph = workflow.graph_dict - nodes = graph.get('nodes', []) + nodes = graph.get("nodes", []) tools = [] for node in nodes: - if node.get('data', {}).get('type') == 'tool': - node_data = node.get('data', {}) - tools.append({ - 'provider_type': node_data.get('provider_type'), - 'provider_id': node_data.get('provider_id'), - 'tool_name': node_data.get('tool_name'), - 'tool_parameters': {} - }) + if node.get("data", {}).get("type") == "tool": + node_data = node.get("data", {}) + tools.append( + { + "provider_type": node_data.get("provider_type"), + "provider_id": node_data.get("provider_id"), + "tool_name": node_data.get("tool_name"), + "tool_parameters": {}, + } + ) else: app_model_config: AppModelConfig = app_model.app_model_config @@ -332,30 +343,26 @@ def get_app_meta(self, app_model: App) -> dict: agent_config = app_model_config.agent_mode_dict or {} # get all tools - tools = agent_config.get('tools', []) + tools = agent_config.get("tools", []) - url_prefix = (dify_config.CONSOLE_API_URL - + "/console/api/workspaces/current/tool-provider/builtin/") + url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/" for tool in tools: keys = list(tool.keys()) if len(keys) >= 4: # current tool standard - provider_type = tool.get('provider_type') - provider_id = tool.get('provider_id') - tool_name = tool.get('tool_name') - if provider_type == 'builtin': - meta['tool_icons'][tool_name] = url_prefix + provider_id + '/icon' - elif provider_type == 'api': + provider_type = tool.get("provider_type") + provider_id = tool.get("provider_id") + tool_name = tool.get("tool_name") + if provider_type == "builtin": + meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon" + elif provider_type == "api": try: - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.id == provider_id - ).first() - meta['tool_icons'][tool_name] = json.loads(provider.icon) + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first() + ) + meta["tool_icons"][tool_name] = json.loads(provider.icon) except: - meta['tool_icons'][tool_name] = { - "background": "#252525", - "content": "\ud83d\ude01" - } + meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"} return meta diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 58c950816fdb86..05cd1c96a1d6a6 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -17,7 +17,7 @@ FILE_SIZE = 30 FILE_SIZE_LIMIT = FILE_SIZE * 1024 * 1024 -ALLOWED_EXTENSIONS = ['mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'wav', 'webm', 'amr'] +ALLOWED_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm", "amr"] logger = logging.getLogger(__name__) @@ -31,19 +31,19 @@ def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[st raise ValueError("Speech to text is not enabled") features_dict = workflow.features_dict - if 'speech_to_text' not in features_dict or not features_dict['speech_to_text'].get('enabled'): + if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"): raise ValueError("Speech to text is not enabled") else: app_model_config: AppModelConfig = app_model.app_model_config - if not app_model_config.speech_to_text_dict['enabled']: + if not app_model_config.speech_to_text_dict["enabled"]: raise ValueError("Speech to text is not enabled") if file is None: raise NoAudioUploadedServiceError() extension = file.mimetype - if extension not in [f'audio/{ext}' for ext in ALLOWED_EXTENSIONS]: + if extension not in [f"audio/{ext}" for ext in ALLOWED_EXTENSIONS]: raise UnsupportedAudioTypeServiceError() file_content = file.read() @@ -55,20 +55,25 @@ def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[st model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( - tenant_id=app_model.tenant_id, - model_type=ModelType.SPEECH2TEXT + tenant_id=app_model.tenant_id, model_type=ModelType.SPEECH2TEXT ) if model_instance is None: raise ProviderNotSupportSpeechToTextServiceError() buffer = io.BytesIO(file_content) - buffer.name = 'temp.mp3' + buffer.name = "temp.mp3" return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)} @classmethod - def transcript_tts(cls, app_model: App, text: Optional[str] = None, - voice: Optional[str] = None, end_user: Optional[str] = None, message_id: Optional[str] = None): + def transcript_tts( + cls, + app_model: App, + text: Optional[str] = None, + voice: Optional[str] = None, + end_user: Optional[str] = None, + message_id: Optional[str] = None, + ): from collections.abc import Generator from flask import Response, stream_with_context @@ -84,65 +89,56 @@ def invoke_tts(text_content: str, app_model, voice: Optional[str] = None): raise ValueError("TTS is not enabled") features_dict = workflow.features_dict - if 'text_to_speech' not in features_dict or not features_dict['text_to_speech'].get('enabled'): + if "text_to_speech" not in features_dict or not features_dict["text_to_speech"].get("enabled"): raise ValueError("TTS is not enabled") - voice = features_dict['text_to_speech'].get('voice') if voice is None else voice + voice = features_dict["text_to_speech"].get("voice") if voice is None else voice else: text_to_speech_dict = app_model.app_model_config.text_to_speech_dict - if not text_to_speech_dict.get('enabled'): + if not text_to_speech_dict.get("enabled"): raise ValueError("TTS is not enabled") - voice = text_to_speech_dict.get('voice') if voice is None else voice + voice = text_to_speech_dict.get("voice") if voice is None else voice model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( - tenant_id=app_model.tenant_id, - model_type=ModelType.TTS + tenant_id=app_model.tenant_id, model_type=ModelType.TTS ) try: if not voice: voices = model_instance.get_tts_voices() if voices: - voice = voices[0].get('value') + voice = voices[0].get("value") else: raise ValueError("Sorry, no voice available.") return model_instance.invoke_tts( - content_text=text_content.strip(), - user=end_user, - tenant_id=app_model.tenant_id, - voice=voice + content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice ) except Exception as e: raise e if message_id: - message = db.session.query(Message).filter( - Message.id == message_id - ).first() - if message.answer == '' and message.status == 'normal': + message = db.session.query(Message).filter(Message.id == message_id).first() + if message.answer == "" and message.status == "normal": return None else: response = invoke_tts(message.answer, app_model=app_model, voice=voice) if isinstance(response, Generator): - return Response(stream_with_context(response), content_type='audio/mpeg') + return Response(stream_with_context(response), content_type="audio/mpeg") return response else: response = invoke_tts(text, app_model, voice) if isinstance(response, Generator): - return Response(stream_with_context(response), content_type='audio/mpeg') + return Response(stream_with_context(response), content_type="audio/mpeg") return response @classmethod def transcript_tts_voices(cls, tenant_id: str, language: str): model_manager = ModelManager() - model_instance = model_manager.get_default_model_instance( - tenant_id=tenant_id, - model_type=ModelType.TTS - ) + model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.TTS) if model_instance is None: raise ProviderNotSupportTextToSpeechServiceError() diff --git a/api/services/auth/api_key_auth_factory.py b/api/services/auth/api_key_auth_factory.py index ccd0023c44d84d..ae5b953b47f589 100644 --- a/api/services/auth/api_key_auth_factory.py +++ b/api/services/auth/api_key_auth_factory.py @@ -1,14 +1,12 @@ - from services.auth.firecrawl import FirecrawlAuth class ApiKeyAuthFactory: - def __init__(self, provider: str, credentials: dict): - if provider == 'firecrawl': + if provider == "firecrawl": self.auth = FirecrawlAuth(credentials) else: - raise ValueError('Invalid provider') + raise ValueError("Invalid provider") def validate_credentials(self): return self.auth.validate_credentials() diff --git a/api/services/auth/api_key_auth_service.py b/api/services/auth/api_key_auth_service.py index 43d0fbf98f2df7..e5f4a3ef6e12d3 100644 --- a/api/services/auth/api_key_auth_service.py +++ b/api/services/auth/api_key_auth_service.py @@ -7,39 +7,43 @@ class ApiKeyAuthService: - @staticmethod def get_provider_auth_list(tenant_id: str) -> list: - data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter( - DataSourceApiKeyAuthBinding.tenant_id == tenant_id, - DataSourceApiKeyAuthBinding.disabled.is_(False) - ).all() + data_source_api_key_bindings = ( + db.session.query(DataSourceApiKeyAuthBinding) + .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)) + .all() + ) return data_source_api_key_bindings @staticmethod def create_provider_auth(tenant_id: str, args: dict): - auth_result = ApiKeyAuthFactory(args['provider'], args['credentials']).validate_credentials() + auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials() if auth_result: # Encrypt the api key - api_key = encrypter.encrypt_token(tenant_id, args['credentials']['config']['api_key']) - args['credentials']['config']['api_key'] = api_key + api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"]) + args["credentials"]["config"]["api_key"] = api_key data_source_api_key_binding = DataSourceApiKeyAuthBinding() data_source_api_key_binding.tenant_id = tenant_id - data_source_api_key_binding.category = args['category'] - data_source_api_key_binding.provider = args['provider'] - data_source_api_key_binding.credentials = json.dumps(args['credentials'], ensure_ascii=False) + data_source_api_key_binding.category = args["category"] + data_source_api_key_binding.provider = args["provider"] + data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False) db.session.add(data_source_api_key_binding) db.session.commit() @staticmethod def get_auth_credentials(tenant_id: str, category: str, provider: str): - data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter( - DataSourceApiKeyAuthBinding.tenant_id == tenant_id, - DataSourceApiKeyAuthBinding.category == category, - DataSourceApiKeyAuthBinding.provider == provider, - DataSourceApiKeyAuthBinding.disabled.is_(False) - ).first() + data_source_api_key_bindings = ( + db.session.query(DataSourceApiKeyAuthBinding) + .filter( + DataSourceApiKeyAuthBinding.tenant_id == tenant_id, + DataSourceApiKeyAuthBinding.category == category, + DataSourceApiKeyAuthBinding.provider == provider, + DataSourceApiKeyAuthBinding.disabled.is_(False), + ) + .first() + ) if not data_source_api_key_bindings: return None credentials = json.loads(data_source_api_key_bindings.credentials) @@ -47,24 +51,24 @@ def get_auth_credentials(tenant_id: str, category: str, provider: str): @staticmethod def delete_provider_auth(tenant_id: str, binding_id: str): - data_source_api_key_binding = db.session.query(DataSourceApiKeyAuthBinding).filter( - DataSourceApiKeyAuthBinding.tenant_id == tenant_id, - DataSourceApiKeyAuthBinding.id == binding_id - ).first() + data_source_api_key_binding = ( + db.session.query(DataSourceApiKeyAuthBinding) + .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id) + .first() + ) if data_source_api_key_binding: db.session.delete(data_source_api_key_binding) db.session.commit() @classmethod def validate_api_key_auth_args(cls, args): - if 'category' not in args or not args['category']: - raise ValueError('category is required') - if 'provider' not in args or not args['provider']: - raise ValueError('provider is required') - if 'credentials' not in args or not args['credentials']: - raise ValueError('credentials is required') - if not isinstance(args['credentials'], dict): - raise ValueError('credentials must be a dictionary') - if 'auth_type' not in args['credentials'] or not args['credentials']['auth_type']: - raise ValueError('auth_type is required') - + if "category" not in args or not args["category"]: + raise ValueError("category is required") + if "provider" not in args or not args["provider"]: + raise ValueError("provider is required") + if "credentials" not in args or not args["credentials"]: + raise ValueError("credentials is required") + if not isinstance(args["credentials"], dict): + raise ValueError("credentials must be a dictionary") + if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]: + raise ValueError("auth_type is required") diff --git a/api/services/auth/firecrawl.py b/api/services/auth/firecrawl.py index 69e3fb43c79dab..30e4ee57c06b45 100644 --- a/api/services/auth/firecrawl.py +++ b/api/services/auth/firecrawl.py @@ -8,49 +8,40 @@ class FirecrawlAuth(ApiKeyAuthBase): def __init__(self, credentials: dict): super().__init__(credentials) - auth_type = credentials.get('auth_type') - if auth_type != 'bearer': - raise ValueError('Invalid auth type, Firecrawl auth type must be Bearer') - self.api_key = credentials.get('config').get('api_key', None) - self.base_url = credentials.get('config').get('base_url', 'https://api.firecrawl.dev') + auth_type = credentials.get("auth_type") + if auth_type != "bearer": + raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer") + self.api_key = credentials.get("config").get("api_key", None) + self.base_url = credentials.get("config").get("base_url", "https://api.firecrawl.dev") if not self.api_key: - raise ValueError('No API key provided') + raise ValueError("No API key provided") def validate_credentials(self): headers = self._prepare_headers() options = { - 'url': 'https://example.com', - 'crawlerOptions': { - 'excludes': [], - 'includes': [], - 'limit': 1 - }, - 'pageOptions': { - 'onlyMainContent': True - } + "url": "https://example.com", + "crawlerOptions": {"excludes": [], "includes": [], "limit": 1}, + "pageOptions": {"onlyMainContent": True}, } - response = self._post_request(f'{self.base_url}/v0/crawl', options, headers) + response = self._post_request(f"{self.base_url}/v0/crawl", options, headers) if response.status_code == 200: return True else: self._handle_error(response) def _prepare_headers(self): - return { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } + return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} def _post_request(self, url, data, headers): return requests.post(url, headers=headers, json=data) def _handle_error(self, response): if response.status_code in [402, 409, 500]: - error_message = response.json().get('error', 'Unknown error occurred') - raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}') + error_message = response.json().get("error", "Unknown error occurred") + raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") else: if response.text: - error_message = json.loads(response.text).get('error', 'Unknown error occurred') - raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}') - raise Exception(f'Unexpected error occurred while trying to authorize. Status code: {response.status_code}') + error_message = json.loads(response.text).get("error", "Unknown error occurred") + raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") + raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}") diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 539f2712bb9f38..911d2346415ce5 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -7,58 +7,40 @@ class BillingService: - base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL') - secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY') + base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL") + secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY") @classmethod def get_info(cls, tenant_id: str): - params = {'tenant_id': tenant_id} + params = {"tenant_id": tenant_id} - billing_info = cls._send_request('GET', '/subscription/info', params=params) + billing_info = cls._send_request("GET", "/subscription/info", params=params) return billing_info @classmethod - def get_subscription(cls, plan: str, - interval: str, - prefilled_email: str = '', - tenant_id: str = ''): - params = { - 'plan': plan, - 'interval': interval, - 'prefilled_email': prefilled_email, - 'tenant_id': tenant_id - } - return cls._send_request('GET', '/subscription/payment-link', params=params) + def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""): + params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id} + return cls._send_request("GET", "/subscription/payment-link", params=params) @classmethod - def get_model_provider_payment_link(cls, - provider_name: str, - tenant_id: str, - account_id: str, - prefilled_email: str): + def get_model_provider_payment_link(cls, provider_name: str, tenant_id: str, account_id: str, prefilled_email: str): params = { - 'provider_name': provider_name, - 'tenant_id': tenant_id, - 'account_id': account_id, - 'prefilled_email': prefilled_email + "provider_name": provider_name, + "tenant_id": tenant_id, + "account_id": account_id, + "prefilled_email": prefilled_email, } - return cls._send_request('GET', '/model-provider/payment-link', params=params) + return cls._send_request("GET", "/model-provider/payment-link", params=params) @classmethod - def get_invoices(cls, prefilled_email: str = '', tenant_id: str = ''): - params = { - 'prefilled_email': prefilled_email, - 'tenant_id': tenant_id - } - return cls._send_request('GET', '/invoices', params=params) + def get_invoices(cls, prefilled_email: str = "", tenant_id: str = ""): + params = {"prefilled_email": prefilled_email, "tenant_id": tenant_id} + return cls._send_request("GET", "/invoices", params=params) @classmethod def _send_request(cls, method, endpoint, json=None, params=None): - headers = { - "Content-Type": "application/json", - "Billing-Api-Secret-Key": cls.secret_key - } + headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" response = requests.request(method, url, json=json, params=params, headers=headers) @@ -69,10 +51,11 @@ def _send_request(cls, method, endpoint, json=None, params=None): def is_tenant_owner_or_admin(current_user): tenant_id = current_user.current_tenant_id - join = db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.tenant_id == tenant_id, - TenantAccountJoin.account_id == current_user.id - ).first() + join = ( + db.session.query(TenantAccountJoin) + .filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id) + .first() + ) if not TenantAccountRole.is_privileged_role(join.role): - raise ValueError('Only team owner or team admin can perform this action') + raise ValueError("Only team owner or team admin can perform this action") diff --git a/api/services/code_based_extension_service.py b/api/services/code_based_extension_service.py index 7b0d50a835dbf8..f7597b7f1fcd45 100644 --- a/api/services/code_based_extension_service.py +++ b/api/services/code_based_extension_service.py @@ -2,12 +2,15 @@ class CodeBasedExtensionService: - @staticmethod def get_code_based_extension(module: str) -> list[dict]: module_extensions = code_based_extension.module_extensions(module) - return [{ - 'name': module_extension.name, - 'label': module_extension.label, - 'form_schema': module_extension.form_schema - } for module_extension in module_extensions if not module_extension.builtin] + return [ + { + "name": module_extension.name, + "label": module_extension.label, + "form_schema": module_extension.form_schema, + } + for module_extension in module_extensions + if not module_extension.builtin + ] diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 82ee10ee78f095..7bfe59afa0ead7 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,6 +1,7 @@ +from datetime import datetime, timezone from typing import Optional, Union -from sqlalchemy import or_ +from sqlalchemy import asc, desc, or_ from core.app.entities.app_invoke_entities import InvokeFrom from core.llm_generator.llm_generator import LLMGenerator @@ -14,21 +15,27 @@ class ConversationService: @classmethod - def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - last_id: Optional[str], limit: int, - invoke_from: InvokeFrom, - include_ids: Optional[list] = None, - exclude_ids: Optional[list] = None) -> InfiniteScrollPagination: + def pagination_by_last_id( + cls, + app_model: App, + user: Optional[Union[Account, EndUser]], + last_id: Optional[str], + limit: int, + invoke_from: InvokeFrom, + include_ids: Optional[list] = None, + exclude_ids: Optional[list] = None, + sort_by: str = "-updated_at", + ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) base_query = db.session.query(Conversation).filter( Conversation.is_deleted == False, Conversation.app_id == app_model.id, - Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'), + Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"), Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None), Conversation.from_account_id == (user.id if isinstance(user, Account) else None), - or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value) + or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value), ) if include_ids is not None: @@ -37,47 +44,67 @@ def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, End if exclude_ids is not None: base_query = base_query.filter(~Conversation.id.in_(exclude_ids)) - if last_id: - last_conversation = base_query.filter( - Conversation.id == last_id, - ).first() + # define sort fields and directions + sort_field, sort_direction = cls._get_sort_params(sort_by) + if last_id: + last_conversation = base_query.filter(Conversation.id == last_id).first() if not last_conversation: raise LastConversationNotExistsError() - conversations = base_query.filter( - Conversation.created_at < last_conversation.created_at, - Conversation.id != last_conversation.id - ).order_by(Conversation.created_at.desc()).limit(limit).all() - else: - conversations = base_query.order_by(Conversation.created_at.desc()).limit(limit).all() + # build filters based on sorting + filter_condition = cls._build_filter_condition(sort_field, sort_direction, last_conversation) + base_query = base_query.filter(filter_condition) + + base_query = base_query.order_by(sort_direction(getattr(Conversation, sort_field))) + + conversations = base_query.limit(limit).all() has_more = False if len(conversations) == limit: - current_page_first_conversation = conversations[-1] - rest_count = base_query.filter( - Conversation.created_at < current_page_first_conversation.created_at, - Conversation.id != current_page_first_conversation.id - ).count() + current_page_last_conversation = conversations[-1] + rest_filter_condition = cls._build_filter_condition( + sort_field, sort_direction, current_page_last_conversation, is_next_page=True + ) + rest_count = base_query.filter(rest_filter_condition).count() if rest_count > 0: has_more = True - return InfiniteScrollPagination( - data=conversations, - limit=limit, - has_more=has_more - ) + return InfiniteScrollPagination(data=conversations, limit=limit, has_more=has_more) @classmethod - def rename(cls, app_model: App, conversation_id: str, - user: Optional[Union[Account, EndUser]], name: str, auto_generate: bool): + def _get_sort_params(cls, sort_by: str) -> tuple[str, callable]: + if sort_by.startswith("-"): + return sort_by[1:], desc + return sort_by, asc + + @classmethod + def _build_filter_condition( + cls, sort_field: str, sort_direction: callable, reference_conversation: Conversation, is_next_page: bool = False + ): + field_value = getattr(reference_conversation, sort_field) + if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page): + return getattr(Conversation, sort_field) < field_value + else: + return getattr(Conversation, sort_field) > field_value + + @classmethod + def rename( + cls, + app_model: App, + conversation_id: str, + user: Optional[Union[Account, EndUser]], + name: str, + auto_generate: bool, + ): conversation = cls.get_conversation(app_model, conversation_id, user) if auto_generate: return cls.auto_generate_name(app_model, conversation) else: conversation.name = name + conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() return conversation @@ -85,11 +112,12 @@ def rename(cls, app_model: App, conversation_id: str, @classmethod def auto_generate_name(cls, app_model: App, conversation: Conversation): # get conversation first message - message = db.session.query(Message) \ - .filter( - Message.app_id == app_model.id, - Message.conversation_id == conversation.id - ).order_by(Message.created_at.asc()).first() + message = ( + db.session.query(Message) + .filter(Message.app_id == app_model.id, Message.conversation_id == conversation.id) + .order_by(Message.created_at.asc()) + .first() + ) if not message: raise MessageNotExistsError() @@ -109,15 +137,18 @@ def auto_generate_name(cls, app_model: App, conversation: Conversation): @classmethod def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): - conversation = db.session.query(Conversation) \ + conversation = ( + db.session.query(Conversation) .filter( - Conversation.id == conversation_id, - Conversation.app_id == app_model.id, - Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'), - Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None), - Conversation.from_account_id == (user.id if isinstance(user, Account) else None), - Conversation.is_deleted == False - ).first() + Conversation.id == conversation_id, + Conversation.app_id == app_model.id, + Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"), + Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Conversation.from_account_id == (user.id if isinstance(user, Account) else None), + Conversation.is_deleted == False, + ) + .first() + ) if not conversation: raise ConversationNotExistsError() diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 9052a0b7857aa2..d10df3f600beab 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -27,6 +27,7 @@ Dataset, DatasetCollectionBinding, DatasetPermission, + DatasetPermissionEnum, DatasetProcessRule, DatasetQuery, Document, @@ -54,7 +55,6 @@ class DatasetService: - @staticmethod def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None, search=None, tag_ids=None): query = Dataset.query.filter(Dataset.provider == provider, Dataset.tenant_id == tenant_id).order_by( @@ -63,10 +63,7 @@ def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None, s if user: # get permitted dataset ids - dataset_permission = DatasetPermission.query.filter_by( - account_id=user.id, - tenant_id=tenant_id - ).all() + dataset_permission = DatasetPermission.query.filter_by(account_id=user.id, tenant_id=tenant_id).all() permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None if user.current_role == TenantAccountRole.DATASET_OPERATOR: @@ -80,83 +77,76 @@ def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None, s if permitted_dataset_ids: query = query.filter( db.or_( - Dataset.permission == 'all_team_members', - db.and_(Dataset.permission == 'only_me', Dataset.created_by == user.id), - db.and_(Dataset.permission == 'partial_members', Dataset.id.in_(permitted_dataset_ids)) + Dataset.permission == DatasetPermissionEnum.ALL_TEAM, + db.and_(Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id), + db.and_( + Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM, + Dataset.id.in_(permitted_dataset_ids), + ), ) ) else: query = query.filter( db.or_( - Dataset.permission == 'all_team_members', - db.and_(Dataset.permission == 'only_me', Dataset.created_by == user.id) + Dataset.permission == DatasetPermissionEnum.ALL_TEAM, + db.and_(Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id), ) ) else: # if no user, only show datasets that are shared with all team members - query = query.filter(Dataset.permission == 'all_team_members') + query = query.filter(Dataset.permission == DatasetPermissionEnum.ALL_TEAM) if search: - query = query.filter(Dataset.name.ilike(f'%{search}%')) + query = query.filter(Dataset.name.ilike(f"%{search}%")) if tag_ids: - target_ids = TagService.get_target_ids_by_tag_ids('knowledge', tenant_id, tag_ids) + target_ids = TagService.get_target_ids_by_tag_ids("knowledge", tenant_id, tag_ids) if target_ids: query = query.filter(Dataset.id.in_(target_ids)) else: return [], 0 - datasets = query.paginate( - page=page, - per_page=per_page, - max_per_page=100, - error_out=False - ) + datasets = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) return datasets.items, datasets.total @staticmethod def get_process_rules(dataset_id): # get the latest process rule - dataset_process_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.dataset_id == dataset_id). \ - order_by(DatasetProcessRule.created_at.desc()). \ - limit(1). \ - one_or_none() + dataset_process_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.dataset_id == dataset_id) + .order_by(DatasetProcessRule.created_at.desc()) + .limit(1) + .one_or_none() + ) if dataset_process_rule: mode = dataset_process_rule.mode rules = dataset_process_rule.rules_dict else: - mode = DocumentService.DEFAULT_RULES['mode'] - rules = DocumentService.DEFAULT_RULES['rules'] - return { - 'mode': mode, - 'rules': rules - } + mode = DocumentService.DEFAULT_RULES["mode"] + rules = DocumentService.DEFAULT_RULES["rules"] + return {"mode": mode, "rules": rules} @staticmethod def get_datasets_by_ids(ids, tenant_id): - datasets = Dataset.query.filter( - Dataset.id.in_(ids), - Dataset.tenant_id == tenant_id - ).paginate( + datasets = Dataset.query.filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id).paginate( page=1, per_page=len(ids), max_per_page=len(ids), error_out=False ) return datasets.items, datasets.total @staticmethod - def create_empty_dataset(tenant_id: str, name: str, indexing_technique: Optional[str], account: Account): + def create_empty_dataset( + tenant_id: str, name: str, indexing_technique: Optional[str], account: Account, permission: Optional[str] = None + ): # check if dataset name already exists if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first(): - raise DatasetNameDuplicateError( - f'Dataset with name {name} already exists.' - ) + raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") embedding_model = None - if indexing_technique == 'high_quality': + if indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_default_model_instance( - tenant_id=tenant_id, - model_type=ModelType.TEXT_EMBEDDING + tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING ) dataset = Dataset(name=name, indexing_technique=indexing_technique) # dataset = Dataset(name=name, provider=provider, config=config) @@ -165,26 +155,25 @@ def create_empty_dataset(tenant_id: str, name: str, indexing_technique: Optional dataset.tenant_id = tenant_id dataset.embedding_model_provider = embedding_model.provider if embedding_model else None dataset.embedding_model = embedding_model.model if embedding_model else None + dataset.permission = permission if permission else DatasetPermissionEnum.ONLY_ME db.session.add(dataset) db.session.commit() return dataset @staticmethod def get_dataset(dataset_id): - return Dataset.query.filter_by( - id=dataset_id - ).first() + return Dataset.query.filter_by(id=dataset_id).first() @staticmethod def check_dataset_model_setting(dataset): - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": try: model_manager = ModelManager() model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ValueError( @@ -192,65 +181,56 @@ def check_dataset_model_setting(dataset): "in the Settings -> Model Provider." ) except ProviderTokenNotInitError as ex: - raise ValueError( - f"The dataset in unavailable, due to: " - f"{ex.description}" - ) + raise ValueError(f"The dataset in unavailable, due to: " f"{ex.description}") @staticmethod - def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model:str): + def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str): try: model_manager = ModelManager() model_manager.get_model_instance( tenant_id=tenant_id, provider=embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=embedding_model + model=embedding_model, ) except LLMBadRequestError: raise ValueError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider." + "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider." ) except ProviderTokenNotInitError as ex: - raise ValueError( - f"The dataset in unavailable, due to: " - f"{ex.description}" - ) - + raise ValueError(f"The dataset in unavailable, due to: " f"{ex.description}") @staticmethod def update_dataset(dataset_id, data, user): - data.pop('partial_member_list', None) - filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'} + data.pop("partial_member_list", None) + filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"} dataset = DatasetService.get_dataset(dataset_id) DatasetService.check_dataset_permission(dataset, user) action = None - if dataset.indexing_technique != data['indexing_technique']: + if dataset.indexing_technique != data["indexing_technique"]: # if update indexing_technique - if data['indexing_technique'] == 'economy': - action = 'remove' - filtered_data['embedding_model'] = None - filtered_data['embedding_model_provider'] = None - filtered_data['collection_binding_id'] = None - elif data['indexing_technique'] == 'high_quality': - action = 'add' + if data["indexing_technique"] == "economy": + action = "remove" + filtered_data["embedding_model"] = None + filtered_data["embedding_model_provider"] = None + filtered_data["collection_binding_id"] = None + elif data["indexing_technique"] == "high_quality": + action = "add" # get embedding model setting try: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, - provider=data['embedding_model_provider'], + provider=data["embedding_model_provider"], model_type=ModelType.TEXT_EMBEDDING, - model=data['embedding_model'] + model=data["embedding_model"], ) - filtered_data['embedding_model'] = embedding_model.model - filtered_data['embedding_model_provider'] = embedding_model.provider + filtered_data["embedding_model"] = embedding_model.model + filtered_data["embedding_model_provider"] = embedding_model.provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, - embedding_model.model + embedding_model.provider, embedding_model.model ) - filtered_data['collection_binding_id'] = dataset_collection_binding.id + filtered_data["collection_binding_id"] = dataset_collection_binding.id except LLMBadRequestError: raise ValueError( "No Embedding Model available. Please configure a valid provider " @@ -259,24 +239,25 @@ def update_dataset(dataset_id, data, user): except ProviderTokenNotInitError as ex: raise ValueError(ex.description) else: - if data['embedding_model_provider'] != dataset.embedding_model_provider or \ - data['embedding_model'] != dataset.embedding_model: - action = 'update' + if ( + data["embedding_model_provider"] != dataset.embedding_model_provider + or data["embedding_model"] != dataset.embedding_model + ): + action = "update" try: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, - provider=data['embedding_model_provider'], + provider=data["embedding_model_provider"], model_type=ModelType.TEXT_EMBEDDING, - model=data['embedding_model'] + model=data["embedding_model"], ) - filtered_data['embedding_model'] = embedding_model.model - filtered_data['embedding_model_provider'] = embedding_model.provider + filtered_data["embedding_model"] = embedding_model.model + filtered_data["embedding_model_provider"] = embedding_model.provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, - embedding_model.model + embedding_model.provider, embedding_model.model ) - filtered_data['collection_binding_id'] = dataset_collection_binding.id + filtered_data["collection_binding_id"] = dataset_collection_binding.id except LLMBadRequestError: raise ValueError( "No Embedding Model available. Please configure a valid provider " @@ -285,11 +266,11 @@ def update_dataset(dataset_id, data, user): except ProviderTokenNotInitError as ex: raise ValueError(ex.description) - filtered_data['updated_by'] = user.id - filtered_data['updated_at'] = datetime.datetime.now() + filtered_data["updated_by"] = user.id + filtered_data["updated_at"] = datetime.datetime.now() # update Retrieval model - filtered_data['retrieval_model'] = data['retrieval_model'] + filtered_data["retrieval_model"] = data["retrieval_model"] dataset.query.filter_by(id=dataset_id).update(filtered_data) @@ -300,7 +281,6 @@ def update_dataset(dataset_id, data, user): @staticmethod def delete_dataset(dataset_id, user): - dataset = DatasetService.get_dataset(dataset_id) if dataset is None: @@ -324,72 +304,57 @@ def dataset_use_check(dataset_id) -> bool: @staticmethod def check_dataset_permission(dataset, user): if dataset.tenant_id != user.current_tenant_id: - logging.debug( - f'User {user.id} does not have permission to access dataset {dataset.id}' - ) - raise NoPermissionError( - 'You do not have permission to access this dataset.' - ) - if dataset.permission == 'only_me' and dataset.created_by != user.id: - logging.debug( - f'User {user.id} does not have permission to access dataset {dataset.id}' - ) - raise NoPermissionError( - 'You do not have permission to access this dataset.' - ) - if dataset.permission == 'partial_members': - user_permission = DatasetPermission.query.filter_by( - dataset_id=dataset.id, account_id=user.id - ).first() + logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") + raise NoPermissionError("You do not have permission to access this dataset.") + if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id: + logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") + raise NoPermissionError("You do not have permission to access this dataset.") + if dataset.permission == "partial_members": + user_permission = DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first() if not user_permission and dataset.tenant_id != user.current_tenant_id and dataset.created_by != user.id: - logging.debug( - f'User {user.id} does not have permission to access dataset {dataset.id}' - ) - raise NoPermissionError( - 'You do not have permission to access this dataset.' - ) + logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") + raise NoPermissionError("You do not have permission to access this dataset.") @staticmethod def check_dataset_operator_permission(user: Account = None, dataset: Dataset = None): - if dataset.permission == 'only_me': + if dataset.permission == DatasetPermissionEnum.ONLY_ME: if dataset.created_by != user.id: - raise NoPermissionError('You do not have permission to access this dataset.') + raise NoPermissionError("You do not have permission to access this dataset.") - elif dataset.permission == 'partial_members': + elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM: if not any( dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all() ): - raise NoPermissionError('You do not have permission to access this dataset.') + raise NoPermissionError("You do not have permission to access this dataset.") @staticmethod def get_dataset_queries(dataset_id: str, page: int, per_page: int): - dataset_queries = DatasetQuery.query.filter_by(dataset_id=dataset_id) \ - .order_by(db.desc(DatasetQuery.created_at)) \ - .paginate( - page=page, per_page=per_page, max_per_page=100, error_out=False + dataset_queries = ( + DatasetQuery.query.filter_by(dataset_id=dataset_id) + .order_by(db.desc(DatasetQuery.created_at)) + .paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) ) return dataset_queries.items, dataset_queries.total @staticmethod def get_related_apps(dataset_id: str): - return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \ - .order_by(db.desc(AppDatasetJoin.created_at)).all() + return ( + AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) + .order_by(db.desc(AppDatasetJoin.created_at)) + .all() + ) class DocumentService: DEFAULT_RULES = { - 'mode': 'custom', - 'rules': { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': False} + "mode": "custom", + "rules": { + "pre_processing_rules": [ + {"id": "remove_extra_spaces", "enabled": True}, + {"id": "remove_urls_emails", "enabled": False}, ], - 'segmentation': { - 'delimiter': '\n', - 'max_tokens': 500, - 'chunk_overlap': 50 - } - } + "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, + }, } DOCUMENT_METADATA_SCHEMA = { @@ -482,58 +447,55 @@ class DocumentService: "commit_date": str, "commit_author": str, }, - "others": dict + "others": dict, } @staticmethod def get_document(dataset_id: str, document_id: str) -> Optional[Document]: - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) return document @staticmethod def get_document_by_id(document_id: str) -> Optional[Document]: - document = db.session.query(Document).filter( - Document.id == document_id - ).first() + document = db.session.query(Document).filter(Document.id == document_id).first() return document @staticmethod def get_document_by_dataset_id(dataset_id: str) -> list[Document]: - documents = db.session.query(Document).filter( - Document.dataset_id == dataset_id, - Document.enabled == True - ).all() + documents = db.session.query(Document).filter(Document.dataset_id == dataset_id, Document.enabled == True).all() return documents @staticmethod def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]: - documents = db.session.query(Document).filter( - Document.dataset_id == dataset_id, - Document.indexing_status.in_(['error', 'paused']) - ).all() + documents = ( + db.session.query(Document) + .filter(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"])) + .all() + ) return documents @staticmethod def get_batch_documents(dataset_id: str, batch: str) -> list[Document]: - documents = db.session.query(Document).filter( - Document.batch == batch, - Document.dataset_id == dataset_id, - Document.tenant_id == current_user.current_tenant_id - ).all() + documents = ( + db.session.query(Document) + .filter( + Document.batch == batch, + Document.dataset_id == dataset_id, + Document.tenant_id == current_user.current_tenant_id, + ) + .all() + ) return documents @staticmethod def get_document_file_detail(file_id: str): - file_detail = db.session.query(UploadFile). \ - filter(UploadFile.id == file_id). \ - one_or_none() + file_detail = db.session.query(UploadFile).filter(UploadFile.id == file_id).one_or_none() return file_detail @staticmethod @@ -547,13 +509,14 @@ def check_archived(document): def delete_document(document): # trigger document_was_deleted signal file_id = None - if document.data_source_type == 'upload_file': + if document.data_source_type == "upload_file": if document.data_source_info: data_source_info = document.data_source_info_dict - if data_source_info and 'upload_file_id' in data_source_info: - file_id = data_source_info['upload_file_id'] - document_was_deleted.send(document.id, dataset_id=document.dataset_id, - doc_form=document.doc_form, file_id=file_id) + if data_source_info and "upload_file_id" in data_source_info: + file_id = data_source_info["upload_file_id"] + document_was_deleted.send( + document.id, dataset_id=document.dataset_id, doc_form=document.doc_form, file_id=file_id + ) db.session.delete(document) db.session.commit() @@ -562,15 +525,15 @@ def delete_document(document): def rename_document(dataset_id: str, document_id: str, name: str) -> Document: dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise ValueError('Dataset not found.') + raise ValueError("Dataset not found.") document = DocumentService.get_document(dataset_id, document_id) if not document: - raise ValueError('Document not found.') + raise ValueError("Document not found.") if document.tenant_id != current_user.current_tenant_id: - raise ValueError('No permission.') + raise ValueError("No permission.") document.name = name @@ -591,7 +554,7 @@ def pause_document(document): db.session.add(document) db.session.commit() # set document paused flag - indexing_cache_key = 'document_{}_is_paused'.format(document.id) + indexing_cache_key = "document_{}_is_paused".format(document.id) redis_client.setnx(indexing_cache_key, "True") @staticmethod @@ -606,7 +569,7 @@ def recover_document(document): db.session.add(document) db.session.commit() # delete paused flag - indexing_cache_key = 'document_{}_is_paused'.format(document.id) + indexing_cache_key = "document_{}_is_paused".format(document.id) redis_client.delete(indexing_cache_key) # trigger async task recover_document_indexing_task.delay(document.dataset_id, document.id) @@ -615,12 +578,12 @@ def recover_document(document): def retry_document(dataset_id: str, documents: list[Document]): for document in documents: # add retry flag - retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id) + retry_indexing_cache_key = "document_{}_is_retried".format(document.id) cache_result = redis_client.get(retry_indexing_cache_key) if cache_result is not None: raise ValueError("Document is being retried, please try again later") # retry document indexing - document.indexing_status = 'waiting' + document.indexing_status = "waiting" db.session.add(document) db.session.commit() @@ -632,14 +595,14 @@ def retry_document(dataset_id: str, documents: list[Document]): @staticmethod def sync_website_document(dataset_id: str, document: Document): # add sync flag - sync_indexing_cache_key = 'document_{}_is_sync'.format(document.id) + sync_indexing_cache_key = "document_{}_is_sync".format(document.id) cache_result = redis_client.get(sync_indexing_cache_key) if cache_result is not None: raise ValueError("Document is being synced, please try again later") # sync document indexing - document.indexing_status = 'waiting' + document.indexing_status = "waiting" data_source_info = document.data_source_info_dict - data_source_info['mode'] = 'scrape' + data_source_info["mode"] = "scrape" document.data_source_info = json.dumps(data_source_info, ensure_ascii=False) db.session.add(document) db.session.commit() @@ -658,27 +621,28 @@ def get_documents_position(dataset_id): @staticmethod def save_document_with_dataset_id( - dataset: Dataset, document_data: dict, - account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None, - created_from: str = 'web' + dataset: Dataset, + document_data: dict, + account: Account, + dataset_process_rule: Optional[DatasetProcessRule] = None, + created_from: str = "web", ): - # check document limit features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: - if 'original_document_id' not in document_data or not document_data['original_document_id']: + if "original_document_id" not in document_data or not document_data["original_document_id"]: count = 0 if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids'] + upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] count = len(upload_file_list) elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] + notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] for notion_info in notion_info_list: - count = count + len(notion_info['pages']) + count = count + len(notion_info["pages"]) elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]['info_list']['website_info_list'] - count = len(website_info['urls']) + website_info = document_data["data_source"]["info_list"]["website_info_list"] + count = len(website_info["urls"]) batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") @@ -690,42 +654,41 @@ def save_document_with_dataset_id( dataset.data_source_type = document_data["data_source"]["type"] if not dataset.indexing_technique: - if 'indexing_technique' not in document_data \ - or document_data['indexing_technique'] not in Dataset.INDEXING_TECHNIQUE_LIST: + if ( + "indexing_technique" not in document_data + or document_data["indexing_technique"] not in Dataset.INDEXING_TECHNIQUE_LIST + ): raise ValueError("Indexing technique is required") dataset.indexing_technique = document_data["indexing_technique"] - if document_data["indexing_technique"] == 'high_quality': + if document_data["indexing_technique"] == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.TEXT_EMBEDDING + tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING ) dataset.embedding_model = embedding_model.model dataset.embedding_model_provider = embedding_model.provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, - embedding_model.model + embedding_model.provider, embedding_model.model ) dataset.collection_binding_id = dataset_collection_binding.id if not dataset.retrieval_model: default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } - dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get( - 'retrieval_model' - ) else default_retrieval_model + dataset.retrieval_model = ( + document_data.get("retrieval_model") + if document_data.get("retrieval_model") + else default_retrieval_model + ) documents = [] - batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) + batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) if document_data.get("original_document_id"): document = DocumentService.update_document_with_dataset_id(dataset, document_data, account) documents.append(document) @@ -738,14 +701,14 @@ def save_document_with_dataset_id( dataset_id=dataset.id, mode=process_rule["mode"], rules=json.dumps(process_rule["rules"]), - created_by=account.id + created_by=account.id, ) elif process_rule["mode"] == "automatic": dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule["mode"], rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), - created_by=account.id + created_by=account.id, ) db.session.add(dataset_process_rule) db.session.commit() @@ -753,12 +716,13 @@ def save_document_with_dataset_id( document_ids = [] duplicate_document_ids = [] if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids'] + upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] for file_id in upload_file_list: - file = db.session.query(UploadFile).filter( - UploadFile.tenant_id == dataset.tenant_id, - UploadFile.id == file_id - ).first() + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .first() + ) # raise error if file not found if not file: @@ -769,34 +733,39 @@ def save_document_with_dataset_id( "upload_file_id": file_id, } # check duplicate - if document_data.get('duplicate', False): + if document_data.get("duplicate", False): document = Document.query.filter_by( dataset_id=dataset.id, tenant_id=current_user.current_tenant_id, - data_source_type='upload_file', + data_source_type="upload_file", enabled=True, - name=file_name + name=file_name, ).first() if document: document.dataset_process_rule_id = dataset_process_rule.id document.updated_at = datetime.datetime.utcnow() document.created_from = created_from - document.doc_form = document_data['doc_form'] - document.doc_language = document_data['doc_language'] + document.doc_form = document_data["doc_form"] + document.doc_language = document_data["doc_language"] document.data_source_info = json.dumps(data_source_info) document.batch = batch - document.indexing_status = 'waiting' + document.indexing_status = "waiting" db.session.add(document) documents.append(document) duplicate_document_ids.append(document.id) continue document = DocumentService.build_document( - dataset, dataset_process_rule.id, + dataset, + dataset_process_rule.id, document_data["data_source"]["type"], document_data["doc_form"], document_data["doc_language"], - data_source_info, created_from, position, - account, file_name, batch + data_source_info, + created_from, + position, + account, + file_name, + batch, ) db.session.add(document) db.session.flush() @@ -804,47 +773,52 @@ def save_document_with_dataset_id( documents.append(document) position += 1 elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] + notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] exist_page_ids = [] exist_document = {} documents = Document.query.filter_by( dataset_id=dataset.id, tenant_id=current_user.current_tenant_id, - data_source_type='notion_import', - enabled=True + data_source_type="notion_import", + enabled=True, ).all() if documents: for document in documents: data_source_info = json.loads(document.data_source_info) - exist_page_ids.append(data_source_info['notion_page_id']) - exist_document[data_source_info['notion_page_id']] = document.id + exist_page_ids.append(data_source_info["notion_page_id"]) + exist_document[data_source_info["notion_page_id"]] = document.id for notion_info in notion_info_list: - workspace_id = notion_info['workspace_id'] + workspace_id = notion_info["workspace_id"] data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', ) ).first() if not data_source_binding: - raise ValueError('Data source binding not found.') - for page in notion_info['pages']: - if page['page_id'] not in exist_page_ids: + raise ValueError("Data source binding not found.") + for page in notion_info["pages"]: + if page["page_id"] not in exist_page_ids: data_source_info = { "notion_workspace_id": workspace_id, - "notion_page_id": page['page_id'], - "notion_page_icon": page['page_icon'], - "type": page['type'] + "notion_page_id": page["page_id"], + "notion_page_icon": page["page_icon"], + "type": page["type"], } document = DocumentService.build_document( - dataset, dataset_process_rule.id, + dataset, + dataset_process_rule.id, document_data["data_source"]["type"], document_data["doc_form"], document_data["doc_language"], - data_source_info, created_from, position, - account, page['page_name'], batch + data_source_info, + created_from, + position, + account, + page["page_name"], + batch, ) db.session.add(document) db.session.flush() @@ -852,32 +826,37 @@ def save_document_with_dataset_id( documents.append(document) position += 1 else: - exist_document.pop(page['page_id']) + exist_document.pop(page["page_id"]) # delete not selected documents if len(exist_document) > 0: clean_notion_document_task.delay(list(exist_document.values()), dataset.id) elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]['info_list']['website_info_list'] - urls = website_info['urls'] + website_info = document_data["data_source"]["info_list"]["website_info_list"] + urls = website_info["urls"] for url in urls: data_source_info = { - 'url': url, - 'provider': website_info['provider'], - 'job_id': website_info['job_id'], - 'only_main_content': website_info.get('only_main_content', False), - 'mode': 'crawl', + "url": url, + "provider": website_info["provider"], + "job_id": website_info["job_id"], + "only_main_content": website_info.get("only_main_content", False), + "mode": "crawl", } if len(url) > 255: - document_name = url[:200] + '...' + document_name = url[:200] + "..." else: document_name = url document = DocumentService.build_document( - dataset, dataset_process_rule.id, + dataset, + dataset_process_rule.id, document_data["data_source"]["type"], document_data["doc_form"], document_data["doc_language"], - data_source_info, created_from, position, - account, document_name, batch + data_source_info, + created_from, + position, + account, + document_name, + batch, ) db.session.add(document) db.session.flush() @@ -899,15 +878,22 @@ def check_documents_upload_quota(count: int, features: FeatureModel): can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size if count > can_upload_size: raise ValueError( - f'You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded.' + f"You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded." ) @staticmethod def build_document( - dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str, - document_language: str, data_source_info: dict, created_from: str, position: int, + dataset: Dataset, + process_rule_id: str, + data_source_type: str, + document_form: str, + document_language: str, + data_source_info: dict, + created_from: str, + position: int, account: Account, - name: str, batch: str + name: str, + batch: str, ): document = Document( tenant_id=dataset.tenant_id, @@ -921,7 +907,7 @@ def build_document( created_from=created_from, created_by=account.id, doc_form=document_form, - doc_language=document_language + doc_language=document_language, ) return document @@ -931,54 +917,57 @@ def get_tenant_documents_count(): Document.completed_at.isnot(None), Document.enabled == True, Document.archived == False, - Document.tenant_id == current_user.current_tenant_id + Document.tenant_id == current_user.current_tenant_id, ).count() return documents_count @staticmethod def update_document_with_dataset_id( - dataset: Dataset, document_data: dict, - account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None, - created_from: str = 'web' + dataset: Dataset, + document_data: dict, + account: Account, + dataset_process_rule: Optional[DatasetProcessRule] = None, + created_from: str = "web", ): DatasetService.check_dataset_model_setting(dataset) document = DocumentService.get_document(dataset.id, document_data["original_document_id"]) - if document.display_status != 'available': + if document.display_status != "available": raise ValueError("Document is not available") # update document name - if document_data.get('name'): - document.name = document_data['name'] + if document_data.get("name"): + document.name = document_data["name"] # save process rule - if document_data.get('process_rule'): + if document_data.get("process_rule"): process_rule = document_data["process_rule"] if process_rule["mode"] == "custom": dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule["mode"], rules=json.dumps(process_rule["rules"]), - created_by=account.id + created_by=account.id, ) elif process_rule["mode"] == "automatic": dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule["mode"], rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), - created_by=account.id + created_by=account.id, ) db.session.add(dataset_process_rule) db.session.commit() document.dataset_process_rule_id = dataset_process_rule.id # update document data source - if document_data.get('data_source'): - file_name = '' + if document_data.get("data_source"): + file_name = "" data_source_info = {} if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids'] + upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] for file_id in upload_file_list: - file = db.session.query(UploadFile).filter( - UploadFile.tenant_id == dataset.tenant_id, - UploadFile.id == file_id - ).first() + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .first() + ) # raise error if file not found if not file: @@ -989,42 +978,42 @@ def update_document_with_dataset_id( "upload_file_id": file_id, } elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] + notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] for notion_info in notion_info_list: - workspace_id = notion_info['workspace_id'] + workspace_id = notion_info["workspace_id"] data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', ) ).first() if not data_source_binding: - raise ValueError('Data source binding not found.') - for page in notion_info['pages']: + raise ValueError("Data source binding not found.") + for page in notion_info["pages"]: data_source_info = { "notion_workspace_id": workspace_id, - "notion_page_id": page['page_id'], - "notion_page_icon": page['page_icon'], - "type": page['type'] + "notion_page_id": page["page_id"], + "notion_page_icon": page["page_icon"], + "type": page["type"], } elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]['info_list']['website_info_list'] - urls = website_info['urls'] + website_info = document_data["data_source"]["info_list"]["website_info_list"] + urls = website_info["urls"] for url in urls: data_source_info = { - 'url': url, - 'provider': website_info['provider'], - 'job_id': website_info['job_id'], - 'only_main_content': website_info.get('only_main_content', False), - 'mode': 'crawl', + "url": url, + "provider": website_info["provider"], + "job_id": website_info["job_id"], + "only_main_content": website_info.get("only_main_content", False), + "mode": "crawl", } document.data_source_type = document_data["data_source"]["type"] document.data_source_info = json.dumps(data_source_info) document.name = file_name # update document to be waiting - document.indexing_status = 'waiting' + document.indexing_status = "waiting" document.completed_at = None document.processing_started_at = None document.parsing_completed_at = None @@ -1032,13 +1021,11 @@ def update_document_with_dataset_id( document.splitting_completed_at = None document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) document.created_from = created_from - document.doc_form = document_data['doc_form'] + document.doc_form = document_data["doc_form"] db.session.add(document) db.session.commit() # update document segment - update_params = { - DocumentSegment.status: 're_segment' - } + update_params = {DocumentSegment.status: "re_segment"} DocumentSegment.query.filter_by(document_id=document.id).update(update_params) db.session.commit() # trigger async task @@ -1052,15 +1039,15 @@ def save_document_without_dataset_id(tenant_id: str, document_data: dict, accoun if features.billing.enabled: count = 0 if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids'] + upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] count = len(upload_file_list) elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] + notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] for notion_info in notion_info_list: - count = count + len(notion_info['pages']) + count = count + len(notion_info["pages"]) elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]['info_list']['website_info_list'] - count = len(website_info['urls']) + website_info = document_data["data_source"]["info_list"]["website_info_list"] + count = len(website_info["urls"]) batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") @@ -1070,42 +1057,37 @@ def save_document_without_dataset_id(tenant_id: str, document_data: dict, accoun embedding_model = None dataset_collection_binding_id = None retrieval_model = None - if document_data['indexing_technique'] == 'high_quality': + if document_data["indexing_technique"] == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.TEXT_EMBEDDING + tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING ) dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, - embedding_model.model + embedding_model.provider, embedding_model.model ) dataset_collection_binding_id = dataset_collection_binding.id - if document_data.get('retrieval_model'): - retrieval_model = document_data['retrieval_model'] + if document_data.get("retrieval_model"): + retrieval_model = document_data["retrieval_model"] else: default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } retrieval_model = default_retrieval_model # save dataset dataset = Dataset( tenant_id=tenant_id, - name='', + name="", data_source_type=document_data["data_source"]["type"], indexing_technique=document_data["indexing_technique"], created_by=account.id, embedding_model=embedding_model.model if embedding_model else None, embedding_model_provider=embedding_model.provider if embedding_model else None, collection_binding_id=dataset_collection_binding_id, - retrieval_model=retrieval_model + retrieval_model=retrieval_model, ) db.session.add(dataset) @@ -1115,236 +1097,259 @@ def save_document_without_dataset_id(tenant_id: str, document_data: dict, accoun cut_length = 18 cut_name = documents[0].name[:cut_length] - dataset.name = cut_name + '...' - dataset.description = 'useful for when you want to answer queries about the ' + documents[0].name + dataset.name = cut_name + "..." + dataset.description = "useful for when you want to answer queries about the " + documents[0].name db.session.commit() return dataset, documents, batch @classmethod def document_create_args_validate(cls, args: dict): - if 'original_document_id' not in args or not args['original_document_id']: + if "original_document_id" not in args or not args["original_document_id"]: DocumentService.data_source_args_validate(args) DocumentService.process_rule_args_validate(args) else: - if ('data_source' not in args and not args['data_source']) \ - and ('process_rule' not in args and not args['process_rule']): + if ("data_source" not in args and not args["data_source"]) and ( + "process_rule" not in args and not args["process_rule"] + ): raise ValueError("Data source or Process rule is required") else: - if args.get('data_source'): + if args.get("data_source"): DocumentService.data_source_args_validate(args) - if args.get('process_rule'): + if args.get("process_rule"): DocumentService.process_rule_args_validate(args) @classmethod def data_source_args_validate(cls, args: dict): - if 'data_source' not in args or not args['data_source']: + if "data_source" not in args or not args["data_source"]: raise ValueError("Data source is required") - if not isinstance(args['data_source'], dict): + if not isinstance(args["data_source"], dict): raise ValueError("Data source is invalid") - if 'type' not in args['data_source'] or not args['data_source']['type']: + if "type" not in args["data_source"] or not args["data_source"]["type"]: raise ValueError("Data source type is required") - if args['data_source']['type'] not in Document.DATA_SOURCES: + if args["data_source"]["type"] not in Document.DATA_SOURCES: raise ValueError("Data source type is invalid") - if 'info_list' not in args['data_source'] or not args['data_source']['info_list']: + if "info_list" not in args["data_source"] or not args["data_source"]["info_list"]: raise ValueError("Data source info is required") - if args['data_source']['type'] == 'upload_file': - if 'file_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][ - 'file_info_list']: + if args["data_source"]["type"] == "upload_file": + if ( + "file_info_list" not in args["data_source"]["info_list"] + or not args["data_source"]["info_list"]["file_info_list"] + ): raise ValueError("File source info is required") - if args['data_source']['type'] == 'notion_import': - if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][ - 'notion_info_list']: + if args["data_source"]["type"] == "notion_import": + if ( + "notion_info_list" not in args["data_source"]["info_list"] + or not args["data_source"]["info_list"]["notion_info_list"] + ): raise ValueError("Notion source info is required") - if args['data_source']['type'] == 'website_crawl': - if 'website_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][ - 'website_info_list']: + if args["data_source"]["type"] == "website_crawl": + if ( + "website_info_list" not in args["data_source"]["info_list"] + or not args["data_source"]["info_list"]["website_info_list"] + ): raise ValueError("Website source info is required") @classmethod def process_rule_args_validate(cls, args: dict): - if 'process_rule' not in args or not args['process_rule']: + if "process_rule" not in args or not args["process_rule"]: raise ValueError("Process rule is required") - if not isinstance(args['process_rule'], dict): + if not isinstance(args["process_rule"], dict): raise ValueError("Process rule is invalid") - if 'mode' not in args['process_rule'] or not args['process_rule']['mode']: + if "mode" not in args["process_rule"] or not args["process_rule"]["mode"]: raise ValueError("Process rule mode is required") - if args['process_rule']['mode'] not in DatasetProcessRule.MODES: + if args["process_rule"]["mode"] not in DatasetProcessRule.MODES: raise ValueError("Process rule mode is invalid") - if args['process_rule']['mode'] == 'automatic': - args['process_rule']['rules'] = {} + if args["process_rule"]["mode"] == "automatic": + args["process_rule"]["rules"] = {} else: - if 'rules' not in args['process_rule'] or not args['process_rule']['rules']: + if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]: raise ValueError("Process rule rules is required") - if not isinstance(args['process_rule']['rules'], dict): + if not isinstance(args["process_rule"]["rules"], dict): raise ValueError("Process rule rules is invalid") - if 'pre_processing_rules' not in args['process_rule']['rules'] \ - or args['process_rule']['rules']['pre_processing_rules'] is None: + if ( + "pre_processing_rules" not in args["process_rule"]["rules"] + or args["process_rule"]["rules"]["pre_processing_rules"] is None + ): raise ValueError("Process rule pre_processing_rules is required") - if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list): + if not isinstance(args["process_rule"]["rules"]["pre_processing_rules"], list): raise ValueError("Process rule pre_processing_rules is invalid") unique_pre_processing_rule_dicts = {} - for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']: - if 'id' not in pre_processing_rule or not pre_processing_rule['id']: + for pre_processing_rule in args["process_rule"]["rules"]["pre_processing_rules"]: + if "id" not in pre_processing_rule or not pre_processing_rule["id"]: raise ValueError("Process rule pre_processing_rules id is required") - if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES: + if pre_processing_rule["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES: raise ValueError("Process rule pre_processing_rules id is invalid") - if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None: + if "enabled" not in pre_processing_rule or pre_processing_rule["enabled"] is None: raise ValueError("Process rule pre_processing_rules enabled is required") - if not isinstance(pre_processing_rule['enabled'], bool): + if not isinstance(pre_processing_rule["enabled"], bool): raise ValueError("Process rule pre_processing_rules enabled is invalid") - unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule + unique_pre_processing_rule_dicts[pre_processing_rule["id"]] = pre_processing_rule - args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values()) + args["process_rule"]["rules"]["pre_processing_rules"] = list(unique_pre_processing_rule_dicts.values()) - if 'segmentation' not in args['process_rule']['rules'] \ - or args['process_rule']['rules']['segmentation'] is None: + if ( + "segmentation" not in args["process_rule"]["rules"] + or args["process_rule"]["rules"]["segmentation"] is None + ): raise ValueError("Process rule segmentation is required") - if not isinstance(args['process_rule']['rules']['segmentation'], dict): + if not isinstance(args["process_rule"]["rules"]["segmentation"], dict): raise ValueError("Process rule segmentation is invalid") - if 'separator' not in args['process_rule']['rules']['segmentation'] \ - or not args['process_rule']['rules']['segmentation']['separator']: + if ( + "separator" not in args["process_rule"]["rules"]["segmentation"] + or not args["process_rule"]["rules"]["segmentation"]["separator"] + ): raise ValueError("Process rule segmentation separator is required") - if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str): + if not isinstance(args["process_rule"]["rules"]["segmentation"]["separator"], str): raise ValueError("Process rule segmentation separator is invalid") - if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \ - or not args['process_rule']['rules']['segmentation']['max_tokens']: + if ( + "max_tokens" not in args["process_rule"]["rules"]["segmentation"] + or not args["process_rule"]["rules"]["segmentation"]["max_tokens"] + ): raise ValueError("Process rule segmentation max_tokens is required") - if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int): + if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): raise ValueError("Process rule segmentation max_tokens is invalid") @classmethod def estimate_args_validate(cls, args: dict): - if 'info_list' not in args or not args['info_list']: + if "info_list" not in args or not args["info_list"]: raise ValueError("Data source info is required") - if not isinstance(args['info_list'], dict): + if not isinstance(args["info_list"], dict): raise ValueError("Data info is invalid") - if 'process_rule' not in args or not args['process_rule']: + if "process_rule" not in args or not args["process_rule"]: raise ValueError("Process rule is required") - if not isinstance(args['process_rule'], dict): + if not isinstance(args["process_rule"], dict): raise ValueError("Process rule is invalid") - if 'mode' not in args['process_rule'] or not args['process_rule']['mode']: + if "mode" not in args["process_rule"] or not args["process_rule"]["mode"]: raise ValueError("Process rule mode is required") - if args['process_rule']['mode'] not in DatasetProcessRule.MODES: + if args["process_rule"]["mode"] not in DatasetProcessRule.MODES: raise ValueError("Process rule mode is invalid") - if args['process_rule']['mode'] == 'automatic': - args['process_rule']['rules'] = {} + if args["process_rule"]["mode"] == "automatic": + args["process_rule"]["rules"] = {} else: - if 'rules' not in args['process_rule'] or not args['process_rule']['rules']: + if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]: raise ValueError("Process rule rules is required") - if not isinstance(args['process_rule']['rules'], dict): + if not isinstance(args["process_rule"]["rules"], dict): raise ValueError("Process rule rules is invalid") - if 'pre_processing_rules' not in args['process_rule']['rules'] \ - or args['process_rule']['rules']['pre_processing_rules'] is None: + if ( + "pre_processing_rules" not in args["process_rule"]["rules"] + or args["process_rule"]["rules"]["pre_processing_rules"] is None + ): raise ValueError("Process rule pre_processing_rules is required") - if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list): + if not isinstance(args["process_rule"]["rules"]["pre_processing_rules"], list): raise ValueError("Process rule pre_processing_rules is invalid") unique_pre_processing_rule_dicts = {} - for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']: - if 'id' not in pre_processing_rule or not pre_processing_rule['id']: + for pre_processing_rule in args["process_rule"]["rules"]["pre_processing_rules"]: + if "id" not in pre_processing_rule or not pre_processing_rule["id"]: raise ValueError("Process rule pre_processing_rules id is required") - if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES: + if pre_processing_rule["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES: raise ValueError("Process rule pre_processing_rules id is invalid") - if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None: + if "enabled" not in pre_processing_rule or pre_processing_rule["enabled"] is None: raise ValueError("Process rule pre_processing_rules enabled is required") - if not isinstance(pre_processing_rule['enabled'], bool): + if not isinstance(pre_processing_rule["enabled"], bool): raise ValueError("Process rule pre_processing_rules enabled is invalid") - unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule + unique_pre_processing_rule_dicts[pre_processing_rule["id"]] = pre_processing_rule - args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values()) + args["process_rule"]["rules"]["pre_processing_rules"] = list(unique_pre_processing_rule_dicts.values()) - if 'segmentation' not in args['process_rule']['rules'] \ - or args['process_rule']['rules']['segmentation'] is None: + if ( + "segmentation" not in args["process_rule"]["rules"] + or args["process_rule"]["rules"]["segmentation"] is None + ): raise ValueError("Process rule segmentation is required") - if not isinstance(args['process_rule']['rules']['segmentation'], dict): + if not isinstance(args["process_rule"]["rules"]["segmentation"], dict): raise ValueError("Process rule segmentation is invalid") - if 'separator' not in args['process_rule']['rules']['segmentation'] \ - or not args['process_rule']['rules']['segmentation']['separator']: + if ( + "separator" not in args["process_rule"]["rules"]["segmentation"] + or not args["process_rule"]["rules"]["segmentation"]["separator"] + ): raise ValueError("Process rule segmentation separator is required") - if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str): + if not isinstance(args["process_rule"]["rules"]["segmentation"]["separator"], str): raise ValueError("Process rule segmentation separator is invalid") - if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \ - or not args['process_rule']['rules']['segmentation']['max_tokens']: + if ( + "max_tokens" not in args["process_rule"]["rules"]["segmentation"] + or not args["process_rule"]["rules"]["segmentation"]["max_tokens"] + ): raise ValueError("Process rule segmentation max_tokens is required") - if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int): + if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): raise ValueError("Process rule segmentation max_tokens is invalid") class SegmentService: @classmethod def segment_create_args_validate(cls, args: dict, document: Document): - if document.doc_form == 'qa_model': - if 'answer' not in args or not args['answer']: + if document.doc_form == "qa_model": + if "answer" not in args or not args["answer"]: raise ValueError("Answer is required") - if not args['answer'].strip(): + if not args["answer"].strip(): raise ValueError("Answer is empty") - if 'content' not in args or not args['content'] or not args['content'].strip(): + if "content" not in args or not args["content"] or not args["content"].strip(): raise ValueError("Content is empty") @classmethod def create_segment(cls, args: dict, document: Document, dataset: Dataset): - content = args['content'] + content = args["content"] doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) # calc embedding use tokens - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[content] - ) - lock_name = 'add_segment_lock_document_id_{}'.format(document.id) + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) + lock_name = "add_segment_lock_document_id_{}".format(document.id) with redis_client.lock(lock_name, timeout=600): - max_position = db.session.query(func.max(DocumentSegment.position)).filter( - DocumentSegment.document_id == document.id - ).scalar() + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .filter(DocumentSegment.document_id == document.id) + .scalar() + ) segment_document = DocumentSegment( tenant_id=current_user.current_tenant_id, dataset_id=document.dataset_id, @@ -1355,25 +1360,25 @@ def create_segment(cls, args: dict, document: Document, dataset: Dataset): content=content, word_count=len(content), tokens=tokens, - status='completed', + status="completed", indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), - created_by=current_user.id + created_by=current_user.id, ) - if document.doc_form == 'qa_model': - segment_document.answer = args['answer'] + if document.doc_form == "qa_model": + segment_document.answer = args["answer"] db.session.add(segment_document) db.session.commit() # save vector index try: - VectorService.create_segments_vector([args['keywords']], [segment_document], dataset) + VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset) except Exception as e: logging.exception("create segment index failed") segment_document.enabled = False segment_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment_document.status = 'error' + segment_document.status = "error" segment_document.error = str(e) db.session.commit() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first() @@ -1381,33 +1386,33 @@ def create_segment(cls, args: dict, document: Document, dataset: Dataset): @classmethod def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset): - lock_name = 'multi_add_segment_lock_document_id_{}'.format(document.id) + lock_name = "multi_add_segment_lock_document_id_{}".format(document.id) with redis_client.lock(lock_name, timeout=600): embedding_model = None - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) - max_position = db.session.query(func.max(DocumentSegment.position)).filter( - DocumentSegment.document_id == document.id - ).scalar() + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .filter(DocumentSegment.document_id == document.id) + .scalar() + ) pre_segment_data_list = [] segment_data_list = [] keywords_list = [] for segment_item in segments: - content = segment_item['content'] + content = segment_item["content"] doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == 'high_quality' and embedding_model: + if dataset.indexing_technique == "high_quality" and embedding_model: # calc embedding use tokens - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[content] - ) + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) segment_document = DocumentSegment( tenant_id=current_user.current_tenant_id, dataset_id=document.dataset_id, @@ -1418,18 +1423,21 @@ def multi_create_segment(cls, segments: list, document: Document, dataset: Datas content=content, word_count=len(content), tokens=tokens, - status='completed', + status="completed", indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), - created_by=current_user.id + created_by=current_user.id, ) - if document.doc_form == 'qa_model': - segment_document.answer = segment_item['answer'] + if document.doc_form == "qa_model": + segment_document.answer = segment_item["answer"] db.session.add(segment_document) segment_data_list.append(segment_document) pre_segment_data_list.append(segment_document) - keywords_list.append(segment_item['keywords']) + if "keywords" in segment_item: + keywords_list.append(segment_item["keywords"]) + else: + keywords_list.append(None) try: # save vector index @@ -1439,19 +1447,19 @@ def multi_create_segment(cls, segments: list, document: Document, dataset: Datas for segment_document in segment_data_list: segment_document.enabled = False segment_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment_document.status = 'error' + segment_document.status = "error" segment_document.error = str(e) db.session.commit() return segment_data_list @classmethod def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset): - indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_indexing".format(segment.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise ValueError("Segment is indexing, please try again later") - if 'enabled' in args and args['enabled'] is not None: - action = args['enabled'] + if "enabled" in args and args["enabled"] is not None: + action = args["enabled"] if segment.enabled != action: if not action: segment.enabled = action @@ -1464,25 +1472,25 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document disable_segment_from_index_task.delay(segment.id) return segment if not segment.enabled: - if 'enabled' in args and args['enabled'] is not None: - if not args['enabled']: + if "enabled" in args and args["enabled"] is not None: + if not args["enabled"]: raise ValueError("Can't update disabled segment") else: raise ValueError("Can't update disabled segment") try: - content = args['content'] + content = args["content"] if segment.content == content: - if document.doc_form == 'qa_model': - segment.answer = args['answer'] - if args.get('keywords'): - segment.keywords = args['keywords'] + if document.doc_form == "qa_model": + segment.answer = args["answer"] + if args.get("keywords"): + segment.keywords = args["keywords"] segment.enabled = True segment.disabled_at = None segment.disabled_by = None db.session.add(segment) db.session.commit() # update segment index task - if args['keywords']: + if "keywords" in args: keyword = Keyword(dataset) keyword.delete_by_ids([segment.index_node_id]) document = RAGDocument( @@ -1492,30 +1500,28 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) - keyword.add_texts([document], keywords_list=[args['keywords']]) + keyword.add_texts([document], keywords_list=[args["keywords"]]) else: segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) # calc embedding use tokens - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[content] - ) + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) segment.content = content segment.index_node_hash = segment_hash segment.word_count = len(content) segment.tokens = tokens - segment.status = 'completed' + segment.status = "completed" segment.indexing_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) segment.completed_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) segment.updated_by = current_user.id @@ -1523,18 +1529,18 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document segment.enabled = True segment.disabled_at = None segment.disabled_by = None - if document.doc_form == 'qa_model': - segment.answer = args['answer'] + if document.doc_form == "qa_model": + segment.answer = args["answer"] db.session.add(segment) db.session.commit() # update segment vector index - VectorService.update_segment_vector(args['keywords'], segment, dataset) + VectorService.update_segment_vector(args["keywords"], segment, dataset) except Exception as e: logging.exception("update segment index failed") segment.enabled = False segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment.status = 'error' + segment.status = "error" segment.error = str(e) db.session.commit() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first() @@ -1542,7 +1548,7 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document @classmethod def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset): - indexing_cache_key = 'segment_{}_delete_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_delete_indexing".format(segment.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise ValueError("Segment is deleting.") @@ -1559,24 +1565,25 @@ def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: D class DatasetCollectionBindingService: @classmethod def get_dataset_collection_binding( - cls, provider_name: str, model_name: str, - collection_type: str = 'dataset' + cls, provider_name: str, model_name: str, collection_type: str = "dataset" ) -> DatasetCollectionBinding: - dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ - filter( - DatasetCollectionBinding.provider_name == provider_name, - DatasetCollectionBinding.model_name == model_name, - DatasetCollectionBinding.type == collection_type - ). \ - order_by(DatasetCollectionBinding.created_at). \ - first() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter( + DatasetCollectionBinding.provider_name == provider_name, + DatasetCollectionBinding.model_name == model_name, + DatasetCollectionBinding.type == collection_type, + ) + .order_by(DatasetCollectionBinding.created_at) + .first() + ) if not dataset_collection_binding: dataset_collection_binding = DatasetCollectionBinding( provider_name=provider_name, model_name=model_name, collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())), - type=collection_type + type=collection_type, ) db.session.add(dataset_collection_binding) db.session.commit() @@ -1584,16 +1591,16 @@ def get_dataset_collection_binding( @classmethod def get_dataset_collection_binding_by_id_and_type( - cls, collection_binding_id: str, - collection_type: str = 'dataset' + cls, collection_binding_id: str, collection_type: str = "dataset" ) -> DatasetCollectionBinding: - dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ - filter( - DatasetCollectionBinding.id == collection_binding_id, - DatasetCollectionBinding.type == collection_type - ). \ - order_by(DatasetCollectionBinding.created_at). \ - first() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter( + DatasetCollectionBinding.id == collection_binding_id, DatasetCollectionBinding.type == collection_type + ) + .order_by(DatasetCollectionBinding.created_at) + .first() + ) return dataset_collection_binding @@ -1601,11 +1608,13 @@ def get_dataset_collection_binding_by_id_and_type( class DatasetPermissionService: @classmethod def get_dataset_partial_member_list(cls, dataset_id): - user_list_query = db.session.query( - DatasetPermission.account_id, - ).filter( - DatasetPermission.dataset_id == dataset_id - ).all() + user_list_query = ( + db.session.query( + DatasetPermission.account_id, + ) + .filter(DatasetPermission.dataset_id == dataset_id) + .all() + ) user_list = [] for user in user_list_query: @@ -1622,7 +1631,7 @@ def update_partial_member_list(cls, tenant_id, dataset_id, user_list): permission = DatasetPermission( tenant_id=tenant_id, dataset_id=dataset_id, - account_id=user['user_id'], + account_id=user["user_id"], ) permissions.append(permission) @@ -1635,19 +1644,19 @@ def update_partial_member_list(cls, tenant_id, dataset_id, user_list): @classmethod def check_permission(cls, user, dataset, requested_permission, requested_partial_member_list): if not user.is_dataset_editor: - raise NoPermissionError('User does not have permission to edit this dataset.') + raise NoPermissionError("User does not have permission to edit this dataset.") if user.is_dataset_operator and dataset.permission != requested_permission: - raise NoPermissionError('Dataset operators cannot change the dataset permissions.') + raise NoPermissionError("Dataset operators cannot change the dataset permissions.") - if user.is_dataset_operator and requested_permission == 'partial_members': + if user.is_dataset_operator and requested_permission == "partial_members": if not requested_partial_member_list: - raise ValueError('Partial member list is required when setting to partial members.') + raise ValueError("Partial member list is required when setting to partial members.") local_member_list = cls.get_dataset_partial_member_list(dataset.id) - request_member_list = [user['user_id'] for user in requested_partial_member_list] + request_member_list = [user["user_id"] for user in requested_partial_member_list] if set(local_member_list) != set(request_member_list): - raise ValueError('Dataset operators cannot change the dataset permissions.') + raise ValueError("Dataset operators cannot change the dataset permissions.") @classmethod def clear_partial_member_list(cls, dataset_id): diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index c483d28152c570..ddee52164b746a 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -4,15 +4,12 @@ class EnterpriseRequest: - base_url = os.environ.get('ENTERPRISE_API_URL', 'ENTERPRISE_API_URL') - secret_key = os.environ.get('ENTERPRISE_API_SECRET_KEY', 'ENTERPRISE_API_SECRET_KEY') + base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL") + secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY") @classmethod def send_request(cls, method, endpoint, json=None, params=None): - headers = { - "Content-Type": "application/json", - "Enterprise-Api-Secret-Key": cls.secret_key - } + headers = {"Content-Type": "application/json", "Enterprise-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" response = requests.request(method, url, json=json, params=params, headers=headers) diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 115d0d55232c29..abc01ddf8f58b0 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -2,7 +2,10 @@ class EnterpriseService: - @classmethod def get_info(cls): - return EnterpriseRequest.send_request('GET', '/info') + return EnterpriseRequest.send_request("GET", "/info") + + @classmethod + def get_app_web_sso_enabled(cls, app_code): + return EnterpriseRequest.send_request("GET", f"/app-sso-setting?appCode={app_code}") diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index e5e4d7e23586d5..c519f0b0e51b68 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -22,14 +22,16 @@ class CustomConfigurationStatus(Enum): """ Enum class for custom configuration status. """ - ACTIVE = 'active' - NO_CONFIGURE = 'no-configure' + + ACTIVE = "active" + NO_CONFIGURE = "no-configure" class CustomConfigurationResponse(BaseModel): """ Model class for provider custom configuration response. """ + status: CustomConfigurationStatus @@ -37,6 +39,7 @@ class SystemConfigurationResponse(BaseModel): """ Model class for provider system configuration response. """ + enabled: bool current_quota_type: Optional[ProviderQuotaType] = None quota_configurations: list[QuotaConfiguration] = [] @@ -46,6 +49,7 @@ class ProviderResponse(BaseModel): """ Model class for provider response. """ + provider: str label: I18nObject description: Optional[I18nObject] = None @@ -67,18 +71,15 @@ class ProviderResponse(BaseModel): def __init__(self, **data) -> None: super().__init__(**data) - url_prefix = (dify_config.CONSOLE_API_URL - + f"/console/api/workspaces/current/model-providers/{self.provider}") + url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}" if self.icon_small is not None: self.icon_small = I18nObject( - en_US=f"{url_prefix}/icon_small/en_US", - zh_Hans=f"{url_prefix}/icon_small/zh_Hans" + en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" ) if self.icon_large is not None: self.icon_large = I18nObject( - en_US=f"{url_prefix}/icon_large/en_US", - zh_Hans=f"{url_prefix}/icon_large/zh_Hans" + en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" ) @@ -86,6 +87,7 @@ class ProviderWithModelsResponse(BaseModel): """ Model class for provider with models response. """ + provider: str label: I18nObject icon_small: Optional[I18nObject] = None @@ -96,18 +98,15 @@ class ProviderWithModelsResponse(BaseModel): def __init__(self, **data) -> None: super().__init__(**data) - url_prefix = (dify_config.CONSOLE_API_URL - + f"/console/api/workspaces/current/model-providers/{self.provider}") + url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}" if self.icon_small is not None: self.icon_small = I18nObject( - en_US=f"{url_prefix}/icon_small/en_US", - zh_Hans=f"{url_prefix}/icon_small/zh_Hans" + en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" ) if self.icon_large is not None: self.icon_large = I18nObject( - en_US=f"{url_prefix}/icon_large/en_US", - zh_Hans=f"{url_prefix}/icon_large/zh_Hans" + en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" ) @@ -119,18 +118,15 @@ class SimpleProviderEntityResponse(SimpleProviderEntity): def __init__(self, **data) -> None: super().__init__(**data) - url_prefix = (dify_config.CONSOLE_API_URL - + f"/console/api/workspaces/current/model-providers/{self.provider}") + url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}" if self.icon_small is not None: self.icon_small = I18nObject( - en_US=f"{url_prefix}/icon_small/en_US", - zh_Hans=f"{url_prefix}/icon_small/zh_Hans" + en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" ) if self.icon_large is not None: self.icon_large = I18nObject( - en_US=f"{url_prefix}/icon_large/en_US", - zh_Hans=f"{url_prefix}/icon_large/zh_Hans" + en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" ) @@ -138,6 +134,7 @@ class DefaultModelResponse(BaseModel): """ Default model entity. """ + model: str model_type: ModelType provider: SimpleProviderEntityResponse @@ -150,6 +147,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity): """ Model with provider entity. """ + provider: SimpleProviderEntityResponse def __init__(self, model: ModelWithProviderEntity) -> None: diff --git a/api/services/errors/account.py b/api/services/errors/account.py index ddc2dbdea86265..cae31c5066655a 100644 --- a/api/services/errors/account.py +++ b/api/services/errors/account.py @@ -55,4 +55,3 @@ class RoleAlreadyAssignedError(BaseServiceError): class RateLimitExceededError(BaseServiceError): pass - diff --git a/api/services/errors/base.py b/api/services/errors/base.py index f5d41e17f1142d..1fed71cf9e380e 100644 --- a/api/services/errors/base.py +++ b/api/services/errors/base.py @@ -1,3 +1,3 @@ class BaseServiceError(Exception): def __init__(self, description: str = None): - self.description = description \ No newline at end of file + self.description = description diff --git a/api/services/errors/llm.py b/api/services/errors/llm.py new file mode 100644 index 00000000000000..e4fac6f7450040 --- /dev/null +++ b/api/services/errors/llm.py @@ -0,0 +1,19 @@ +from typing import Optional + + +class InvokeError(Exception): + """Base class for all LLM exceptions.""" + + description: Optional[str] = None + + def __init__(self, description: Optional[str] = None) -> None: + self.description = description + + def __str__(self): + return self.description or self.__class__.__name__ + + +class InvokeRateLimitError(InvokeError): + """Raised when the Invoke returns rate limit error.""" + + description = "Rate Limit Error" diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 83e675a9d2e43a..4d5812c6c69bff 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -6,8 +6,8 @@ class SubscriptionModel(BaseModel): - plan: str = 'sandbox' - interval: str = '' + plan: str = "sandbox" + interval: str = "" class BillingModel(BaseModel): @@ -27,7 +27,7 @@ class FeatureModel(BaseModel): vector_space: LimitationModel = LimitationModel(size=0, limit=5) annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10) documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50) - docs_processing: str = 'standard' + docs_processing: str = "standard" can_replace_logo: bool = False model_load_balancing_enabled: bool = False dataset_operator_enabled: bool = False @@ -38,13 +38,13 @@ class FeatureModel(BaseModel): class SystemFeatureModel(BaseModel): sso_enforced_for_signin: bool = False - sso_enforced_for_signin_protocol: str = '' + sso_enforced_for_signin_protocol: str = "" sso_enforced_for_web: bool = False - sso_enforced_for_web_protocol: str = '' + sso_enforced_for_web_protocol: str = "" + enable_web_sso_switch_component: bool = False class FeatureService: - @classmethod def get_features(cls, tenant_id: str) -> FeatureModel: features = FeatureModel() @@ -61,6 +61,7 @@ def get_system_features(cls) -> SystemFeatureModel: system_features = SystemFeatureModel() if dify_config.ENTERPRISE_ENABLED: + system_features.enable_web_sso_switch_component = True cls._fulfill_params_from_enterprise(system_features) return system_features @@ -75,44 +76,44 @@ def _fulfill_params_from_env(cls, features: FeatureModel): def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): billing_info = BillingService.get_info(tenant_id) - features.billing.enabled = billing_info['enabled'] - features.billing.subscription.plan = billing_info['subscription']['plan'] - features.billing.subscription.interval = billing_info['subscription']['interval'] + features.billing.enabled = billing_info["enabled"] + features.billing.subscription.plan = billing_info["subscription"]["plan"] + features.billing.subscription.interval = billing_info["subscription"]["interval"] - if 'members' in billing_info: - features.members.size = billing_info['members']['size'] - features.members.limit = billing_info['members']['limit'] + if "members" in billing_info: + features.members.size = billing_info["members"]["size"] + features.members.limit = billing_info["members"]["limit"] - if 'apps' in billing_info: - features.apps.size = billing_info['apps']['size'] - features.apps.limit = billing_info['apps']['limit'] + if "apps" in billing_info: + features.apps.size = billing_info["apps"]["size"] + features.apps.limit = billing_info["apps"]["limit"] - if 'vector_space' in billing_info: - features.vector_space.size = billing_info['vector_space']['size'] - features.vector_space.limit = billing_info['vector_space']['limit'] + if "vector_space" in billing_info: + features.vector_space.size = billing_info["vector_space"]["size"] + features.vector_space.limit = billing_info["vector_space"]["limit"] - if 'documents_upload_quota' in billing_info: - features.documents_upload_quota.size = billing_info['documents_upload_quota']['size'] - features.documents_upload_quota.limit = billing_info['documents_upload_quota']['limit'] + if "documents_upload_quota" in billing_info: + features.documents_upload_quota.size = billing_info["documents_upload_quota"]["size"] + features.documents_upload_quota.limit = billing_info["documents_upload_quota"]["limit"] - if 'annotation_quota_limit' in billing_info: - features.annotation_quota_limit.size = billing_info['annotation_quota_limit']['size'] - features.annotation_quota_limit.limit = billing_info['annotation_quota_limit']['limit'] + if "annotation_quota_limit" in billing_info: + features.annotation_quota_limit.size = billing_info["annotation_quota_limit"]["size"] + features.annotation_quota_limit.limit = billing_info["annotation_quota_limit"]["limit"] - if 'docs_processing' in billing_info: - features.docs_processing = billing_info['docs_processing'] + if "docs_processing" in billing_info: + features.docs_processing = billing_info["docs_processing"] - if 'can_replace_logo' in billing_info: - features.can_replace_logo = billing_info['can_replace_logo'] + if "can_replace_logo" in billing_info: + features.can_replace_logo = billing_info["can_replace_logo"] - if 'model_load_balancing_enabled' in billing_info: - features.model_load_balancing_enabled = billing_info['model_load_balancing_enabled'] + if "model_load_balancing_enabled" in billing_info: + features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"] @classmethod def _fulfill_params_from_enterprise(cls, features): enterprise_info = EnterpriseService.get_info() - features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin'] - features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol'] - features.sso_enforced_for_web = enterprise_info['sso_enforced_for_web'] - features.sso_enforced_for_web_protocol = enterprise_info['sso_enforced_for_web_protocol'] + features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"] + features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"] + features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"] + features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"] diff --git a/api/services/file_service.py b/api/services/file_service.py index c686b190fe90d7..5780abb2beee5d 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -17,27 +17,45 @@ from models.model import EndUser, UploadFile from services.errors.file import FileTooLargeError, UnsupportedFileTypeError -IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg'] +IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) -ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'xls', 'docx', 'csv'] -UNSTRUCTURED_ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'xls', - 'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml', 'epub'] +ALLOWED_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"] +UNSTRUCTURED_ALLOWED_EXTENSIONS = [ + "txt", + "markdown", + "md", + "pdf", + "html", + "htm", + "xlsx", + "xls", + "docx", + "csv", + "eml", + "msg", + "pptx", + "ppt", + "xml", + "epub", +] PREVIEW_WORDS_LIMIT = 3000 class FileService: - @staticmethod def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile: filename = file.filename - extension = file.filename.split('.')[-1] + extension = file.filename.split(".")[-1] if len(filename) > 200: - filename = filename.split('.')[0][:200] + '.' + extension + filename = filename.split(".")[0][:200] + "." + extension etl_type = dify_config.ETL_TYPE - allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if etl_type == 'Unstructured' \ + allowed_extensions = ( + UNSTRUCTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS + if etl_type == "Unstructured" else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS + ) if extension.lower() not in allowed_extensions: raise UnsupportedFileTypeError() elif only_image and extension.lower() not in IMAGE_EXTENSIONS: @@ -55,7 +73,7 @@ def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bo file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024 if file_size > file_size_limit: - message = f'File size exceeded. {file_size} > {file_size_limit}' + message = f"File size exceeded. {file_size} > {file_size_limit}" raise FileTooLargeError(message) # user uuid as file name @@ -67,7 +85,7 @@ def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bo # end_user current_tenant_id = user.tenant_id - file_key = 'upload_files/' + current_tenant_id + '/' + file_uuid + '.' + extension + file_key = "upload_files/" + current_tenant_id + "/" + file_uuid + "." + extension # save file to storage storage.save(file_key, file_content) @@ -81,11 +99,11 @@ def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bo size=file_size, extension=extension, mime_type=file.mimetype, - created_by_role=('account' if isinstance(user, Account) else 'end_user'), + created_by_role=("account" if isinstance(user, Account) else "end_user"), created_by=user.id, created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), used=False, - hash=hashlib.sha3_256(file_content).hexdigest() + hash=hashlib.sha3_256(file_content).hexdigest(), ) db.session.add(upload_file) @@ -99,25 +117,25 @@ def upload_text(text: str, text_name: str) -> UploadFile: text_name = text_name[:200] # user uuid as file name file_uuid = str(uuid.uuid4()) - file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.txt' + file_key = "upload_files/" + current_user.current_tenant_id + "/" + file_uuid + ".txt" # save file to storage - storage.save(file_key, text.encode('utf-8')) + storage.save(file_key, text.encode("utf-8")) # save file to db upload_file = UploadFile( tenant_id=current_user.current_tenant_id, storage_type=dify_config.STORAGE_TYPE, key=file_key, - name=text_name + '.txt', + name=text_name, size=len(text), - extension='txt', - mime_type='text/plain', + extension="txt", + mime_type="text/plain", created_by=current_user.id, created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), used=True, used_by=current_user.id, - used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), ) db.session.add(upload_file) @@ -127,9 +145,7 @@ def upload_text(text: str, text_name: str) -> UploadFile: @staticmethod def get_file_preview(file_id: str) -> str: - upload_file = db.session.query(UploadFile) \ - .filter(UploadFile.id == file_id) \ - .first() + upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found") @@ -137,12 +153,12 @@ def get_file_preview(file_id: str) -> str: # extract text from file extension = upload_file.extension etl_type = dify_config.ETL_TYPE - allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS + allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS if extension.lower() not in allowed_extensions: raise UnsupportedFileTypeError() text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True) - text = text[0:PREVIEW_WORDS_LIMIT] if text else '' + text = text[0:PREVIEW_WORDS_LIMIT] if text else "" return text @@ -152,9 +168,7 @@ def get_image_preview(file_id: str, timestamp: str, nonce: str, sign: str) -> tu if not result: raise NotFound("File not found or signature is invalid") - upload_file = db.session.query(UploadFile) \ - .filter(UploadFile.id == file_id) \ - .first() + upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found or signature is invalid") @@ -170,9 +184,7 @@ def get_image_preview(file_id: str, timestamp: str, nonce: str, sign: str) -> tu @staticmethod def get_public_image_preview(file_id: str) -> tuple[Generator, str]: - upload_file = db.session.query(UploadFile) \ - .filter(UploadFile.id == file_id) \ - .first() + upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found or signature is invalid") diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 69274dff09a9a0..db990648140c54 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -9,14 +9,11 @@ from models.dataset import Dataset, DatasetQuery, DocumentSegment default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } @@ -27,9 +24,9 @@ def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_mode return { "query": { "content": query, - "tsne_position": {'x': 0, 'y': 0}, + "tsne_position": {"x": 0, "y": 0}, }, - "records": [] + "records": [], } start = time.perf_counter() @@ -38,27 +35,28 @@ def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_mode if not retrieval_model: retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model - all_documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'), - dataset_id=dataset.id, - query=cls.escape_query_for_search(query), - top_k=retrieval_model.get('top_k', 2), - score_threshold=retrieval_model['score_threshold'] - if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model['reranking_model'] - if retrieval_model['reranking_enable'] else None, - reranking_mode=retrieval_model.get('reranking_mode', None), - weights=retrieval_model.get('weights', None), - ) + all_documents = RetrievalService.retrieve( + retrival_method=retrieval_model.get("search_method", "semantic_search"), + dataset_id=dataset.id, + query=cls.escape_query_for_search(query), + top_k=retrieval_model.get("top_k", 2), + score_threshold=retrieval_model.get("score_threshold", 0.0) + if retrieval_model["score_threshold_enabled"] + else None, + reranking_model=retrieval_model.get("reranking_model", None) + if retrieval_model["reranking_enable"] + else None, + reranking_mode=retrieval_model.get("reranking_mode") + if retrieval_model.get("reranking_mode") + else "reranking_model", + weights=retrieval_model.get("weights", None), + ) end = time.perf_counter() logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") dataset_query = DatasetQuery( - dataset_id=dataset.id, - content=query, - source='hit_testing', - created_by_role='account', - created_by=account.id + dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id ) db.session.add(dataset_query) @@ -71,14 +69,18 @@ def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list i = 0 records = [] for document in documents: - index_node_id = document.metadata['doc_id'] - - segment = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.enabled == True, - DocumentSegment.status == 'completed', - DocumentSegment.index_node_id == index_node_id - ).first() + index_node_id = document.metadata["doc_id"] + + segment = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.index_node_id == index_node_id, + ) + .first() + ) if not segment: i += 1 @@ -86,7 +88,7 @@ def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list record = { "segment": segment, - "score": document.metadata.get('score', None), + "score": document.metadata.get("score", None), } records.append(record) @@ -97,15 +99,15 @@ def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list "query": { "content": query, }, - "records": records + "records": records, } @classmethod def hit_testing_args_check(cls, args): - query = args['query'] + query = args["query"] if not query or len(query) > 250: - raise ValueError('Query is required and cannot exceed 250 characters') + raise ValueError("Query is required and cannot exceed 250 characters") @staticmethod def escape_query_for_search(query: str) -> str: diff --git a/api/services/message_service.py b/api/services/message_service.py index e310d70d5314e7..ecb121c36e4cc4 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -7,7 +7,8 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName +from core.ops.entities.trace_entity import TraceTaskName +from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination @@ -26,8 +27,14 @@ class MessageService: @classmethod - def pagination_by_first_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - conversation_id: str, first_id: Optional[str], limit: int) -> InfiniteScrollPagination: + def pagination_by_first_id( + cls, + app_model: App, + user: Optional[Union[Account, EndUser]], + conversation_id: str, + first_id: Optional[str], + limit: int, + ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) @@ -35,52 +42,69 @@ def pagination_by_first_id(cls, app_model: App, user: Optional[Union[Account, En return InfiniteScrollPagination(data=[], limit=limit, has_more=False) conversation = ConversationService.get_conversation( - app_model=app_model, - user=user, - conversation_id=conversation_id + app_model=app_model, user=user, conversation_id=conversation_id ) if first_id: - first_message = db.session.query(Message) \ - .filter(Message.conversation_id == conversation.id, Message.id == first_id).first() + first_message = ( + db.session.query(Message) + .filter(Message.conversation_id == conversation.id, Message.id == first_id) + .first() + ) if not first_message: raise FirstMessageNotExistsError() - history_messages = db.session.query(Message).filter( - Message.conversation_id == conversation.id, - Message.created_at < first_message.created_at, - Message.id != first_message.id - ) \ - .order_by(Message.created_at.desc()).limit(limit).all() + history_messages = ( + db.session.query(Message) + .filter( + Message.conversation_id == conversation.id, + Message.created_at < first_message.created_at, + Message.id != first_message.id, + ) + .order_by(Message.created_at.desc()) + .limit(limit) + .all() + ) else: - history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \ - .order_by(Message.created_at.desc()).limit(limit).all() + history_messages = ( + db.session.query(Message) + .filter(Message.conversation_id == conversation.id) + .order_by(Message.created_at.desc()) + .limit(limit) + .all() + ) has_more = False if len(history_messages) == limit: current_page_first_message = history_messages[-1] - rest_count = db.session.query(Message).filter( - Message.conversation_id == conversation.id, - Message.created_at < current_page_first_message.created_at, - Message.id != current_page_first_message.id - ).count() + rest_count = ( + db.session.query(Message) + .filter( + Message.conversation_id == conversation.id, + Message.created_at < current_page_first_message.created_at, + Message.id != current_page_first_message.id, + ) + .count() + ) if rest_count > 0: has_more = True history_messages = list(reversed(history_messages)) - return InfiniteScrollPagination( - data=history_messages, - limit=limit, - has_more=has_more - ) + return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more) @classmethod - def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - last_id: Optional[str], limit: int, conversation_id: Optional[str] = None, - include_ids: Optional[list] = None) -> InfiniteScrollPagination: + def pagination_by_last_id( + cls, + app_model: App, + user: Optional[Union[Account, EndUser]], + last_id: Optional[str], + limit: int, + conversation_id: Optional[str] = None, + include_ids: Optional[list] = None, + ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) @@ -88,9 +112,7 @@ def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, End if conversation_id is not None: conversation = ConversationService.get_conversation( - app_model=app_model, - user=user, - conversation_id=conversation_id + app_model=app_model, user=user, conversation_id=conversation_id ) base_query = base_query.filter(Message.conversation_id == conversation.id) @@ -104,10 +126,12 @@ def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, End if not last_message: raise LastMessageNotExistsError() - history_messages = base_query.filter( - Message.created_at < last_message.created_at, - Message.id != last_message.id - ).order_by(Message.created_at.desc()).limit(limit).all() + history_messages = ( + base_query.filter(Message.created_at < last_message.created_at, Message.id != last_message.id) + .order_by(Message.created_at.desc()) + .limit(limit) + .all() + ) else: history_messages = base_query.order_by(Message.created_at.desc()).limit(limit).all() @@ -115,30 +139,22 @@ def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, End if len(history_messages) == limit: current_page_first_message = history_messages[-1] rest_count = base_query.filter( - Message.created_at < current_page_first_message.created_at, - Message.id != current_page_first_message.id + Message.created_at < current_page_first_message.created_at, Message.id != current_page_first_message.id ).count() if rest_count > 0: has_more = True - return InfiniteScrollPagination( - data=history_messages, - limit=limit, - has_more=has_more - ) + return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more) @classmethod - def create_feedback(cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], - rating: Optional[str]) -> MessageFeedback: + def create_feedback( + cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], rating: Optional[str] + ) -> MessageFeedback: if not user: - raise ValueError('user cannot be None') + raise ValueError("user cannot be None") - message = cls.get_message( - app_model=app_model, - user=user, - message_id=message_id - ) + message = cls.get_message(app_model=app_model, user=user, message_id=message_id) feedback = message.user_feedback if isinstance(user, EndUser) else message.admin_feedback @@ -147,14 +163,14 @@ def create_feedback(cls, app_model: App, message_id: str, user: Optional[Union[A elif rating and feedback: feedback.rating = rating elif not rating and not feedback: - raise ValueError('rating cannot be None when feedback not exists') + raise ValueError("rating cannot be None when feedback not exists") else: feedback = MessageFeedback( app_id=app_model.id, conversation_id=message.conversation_id, message_id=message.id, rating=rating, - from_source=('user' if isinstance(user, EndUser) else 'admin'), + from_source=("user" if isinstance(user, EndUser) else "admin"), from_end_user_id=(user.id if isinstance(user, EndUser) else None), from_account_id=(user.id if isinstance(user, Account) else None), ) @@ -166,13 +182,17 @@ def create_feedback(cls, app_model: App, message_id: str, user: Optional[Union[A @classmethod def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id, - Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), - Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), - Message.from_account_id == (user.id if isinstance(user, Account) else None), - ).first() + message = ( + db.session.query(Message) + .filter( + Message.id == message_id, + Message.app_id == app_model.id, + Message.from_source == ("api" if isinstance(user, EndUser) else "console"), + Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Message.from_account_id == (user.id if isinstance(user, Account) else None), + ) + .first() + ) if not message: raise MessageNotExistsError() @@ -180,27 +200,22 @@ def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], me return message @classmethod - def get_suggested_questions_after_answer(cls, app_model: App, user: Optional[Union[Account, EndUser]], - message_id: str, invoke_from: InvokeFrom) -> list[Message]: + def get_suggested_questions_after_answer( + cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str, invoke_from: InvokeFrom + ) -> list[Message]: if not user: - raise ValueError('user cannot be None') + raise ValueError("user cannot be None") - message = cls.get_message( - app_model=app_model, - user=user, - message_id=message_id - ) + message = cls.get_message(app_model=app_model, user=user, message_id=message_id) conversation = ConversationService.get_conversation( - app_model=app_model, - conversation_id=message.conversation_id, - user=user + app_model=app_model, conversation_id=message.conversation_id, user=user ) if not conversation: raise ConversationNotExistsError() - if conversation.status != 'normal': + if conversation.status != "normal": raise ConversationCompletedError() model_manager = ModelManager() @@ -215,24 +230,23 @@ def get_suggested_questions_after_answer(cls, app_model: App, user: Optional[Uni if workflow is None: return [] - app_config = AdvancedChatAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) + app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) if not app_config.additional_features.suggested_questions_after_answer: raise SuggestedQuestionsAfterAnswerDisabledError() model_instance = model_manager.get_default_model_instance( - tenant_id=app_model.tenant_id, - model_type=ModelType.LLM + tenant_id=app_model.tenant_id, model_type=ModelType.LLM ) else: if not conversation.override_model_configs: - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == conversation.app_model_config_id, - AppModelConfig.app_id == app_model.id - ).first() + app_model_config = ( + db.session.query(AppModelConfig) + .filter( + AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id + ) + .first() + ) else: conversation_override_model_configs = json.loads(conversation.override_model_configs) app_model_config = AppModelConfig( @@ -248,16 +262,13 @@ def get_suggested_questions_after_answer(cls, app_model: App, user: Optional[Uni model_instance = model_manager.get_model_instance( tenant_id=app_model.tenant_id, - provider=app_model_config.model_dict['provider'], + provider=app_model_config.model_dict["provider"], model_type=ModelType.LLM, - model=app_model_config.model_dict['name'] + model=app_model_config.model_dict["name"], ) # get memory of conversation (read-only) - memory = TokenBufferMemory( - conversation=conversation, - model_instance=model_instance - ) + memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) histories = memory.get_history_prompt_text( max_token_limit=3000, @@ -266,18 +277,14 @@ def get_suggested_questions_after_answer(cls, app_model: App, user: Optional[Uni with measure_time() as timer: questions = LLMGenerator.generate_suggested_questions_after_answer( - tenant_id=app_model.tenant_id, - histories=histories + tenant_id=app_model.tenant_id, histories=histories ) # get tracing instance trace_manager = TraceQueueManager(app_id=app_model.id) trace_manager.add_trace_task( TraceTask( - TraceTaskName.SUGGESTED_QUESTION_TRACE, - message_id=message_id, - suggested_question=questions, - timer=timer + TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id=message_id, suggested_question=questions, timer=timer ) ) diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 09838399961c7f..e7b9422cfe1e08 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -4,6 +4,7 @@ from json import JSONDecodeError from typing import Optional +from constants import HIDDEN_VALUE from core.entities.provider_configuration import ProviderConfiguration from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType @@ -22,7 +23,6 @@ class ModelLoadBalancingService: - def __init__(self) -> None: self.provider_manager = ProviderManager() @@ -45,10 +45,7 @@ def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, raise ValueError(f"Provider {provider} does not exist.") # Enable model load balancing - provider_configuration.enable_model_load_balancing( - model=model, - model_type=ModelType.value_of(model_type) - ) + provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type)) def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: """ @@ -69,13 +66,11 @@ def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str raise ValueError(f"Provider {provider} does not exist.") # disable model load balancing - provider_configuration.disable_model_load_balancing( - model=model, - model_type=ModelType.value_of(model_type) - ) + provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type)) - def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, model_type: str) \ - -> tuple[bool, list[dict]]: + def get_load_balancing_configs( + self, tenant_id: str, provider: str, model: str, model_type: str + ) -> tuple[bool, list[dict]]: """ Get load balancing configurations. :param tenant_id: workspace id @@ -106,20 +101,24 @@ def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, is_load_balancing_enabled = True # Get load balancing configurations - load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ + load_balancing_configs = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == tenant_id, - LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model - ).order_by(LoadBalancingModelConfig.created_at).all() + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + ) + .order_by(LoadBalancingModelConfig.created_at) + .all() + ) if provider_configuration.custom_configuration.provider: # check if the inherit configuration exists, # inherit is represented for the provider or model custom credentials inherit_config_exists = False for load_balancing_config in load_balancing_configs: - if load_balancing_config.name == '__inherit__': + if load_balancing_config.name == "__inherit__": inherit_config_exists = True break @@ -132,7 +131,7 @@ def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, else: # move the inherit configuration to the first for i, load_balancing_config in enumerate(load_balancing_configs[:]): - if load_balancing_config.name == '__inherit__': + if load_balancing_config.name == "__inherit__": inherit_config = load_balancing_configs.pop(i) load_balancing_configs.insert(0, inherit_config) @@ -150,7 +149,7 @@ def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, provider=provider, model=model, model_type=model_type, - config_id=load_balancing_config.id + config_id=load_balancing_config.id, ) try: @@ -171,32 +170,32 @@ def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, if variable in credentials: try: credentials[variable] = encrypter.decrypt_token_with_decoding( - credentials.get(variable), - decoding_rsa_key, - decoding_cipher_rsa + credentials.get(variable), decoding_rsa_key, decoding_cipher_rsa ) except ValueError: pass # Obfuscate credentials credentials = provider_configuration.obfuscated_credentials( - credentials=credentials, - credential_form_schemas=credential_schemas.credential_form_schemas + credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas ) - datas.append({ - 'id': load_balancing_config.id, - 'name': load_balancing_config.name, - 'credentials': credentials, - 'enabled': load_balancing_config.enabled, - 'in_cooldown': in_cooldown, - 'ttl': ttl - }) + datas.append( + { + "id": load_balancing_config.id, + "name": load_balancing_config.name, + "credentials": credentials, + "enabled": load_balancing_config.enabled, + "in_cooldown": in_cooldown, + "ttl": ttl, + } + ) return is_load_balancing_enabled, datas - def get_load_balancing_config(self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str) \ - -> Optional[dict]: + def get_load_balancing_config( + self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str + ) -> Optional[dict]: """ Get load balancing configuration. :param tenant_id: workspace id @@ -218,14 +217,17 @@ def get_load_balancing_config(self, tenant_id: str, provider: str, model: str, m model_type = ModelType.value_of(model_type) # Get load balancing configurations - load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \ + load_balancing_model_config = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == tenant_id, - LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model, - LoadBalancingModelConfig.id == config_id - ).first() + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + LoadBalancingModelConfig.id == config_id, + ) + .first() + ) if not load_balancing_model_config: return None @@ -243,19 +245,19 @@ def get_load_balancing_config(self, tenant_id: str, provider: str, model: str, m # Obfuscate credentials credentials = provider_configuration.obfuscated_credentials( - credentials=credentials, - credential_form_schemas=credential_schemas.credential_form_schemas + credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas ) return { - 'id': load_balancing_model_config.id, - 'name': load_balancing_model_config.name, - 'credentials': credentials, - 'enabled': load_balancing_model_config.enabled + "id": load_balancing_model_config.id, + "name": load_balancing_model_config.name, + "credentials": credentials, + "enabled": load_balancing_model_config.enabled, } - def _init_inherit_config(self, tenant_id: str, provider: str, model: str, model_type: ModelType) \ - -> LoadBalancingModelConfig: + def _init_inherit_config( + self, tenant_id: str, provider: str, model: str, model_type: ModelType + ) -> LoadBalancingModelConfig: """ Initialize the inherit configuration. :param tenant_id: workspace id @@ -270,18 +272,16 @@ def _init_inherit_config(self, tenant_id: str, provider: str, model: str, model_ provider_name=provider, model_type=model_type.to_origin_model_type(), model_name=model, - name='__inherit__' + name="__inherit__", ) db.session.add(inherit_config) db.session.commit() return inherit_config - def update_load_balancing_configs(self, tenant_id: str, - provider: str, - model: str, - model_type: str, - configs: list[dict]) -> None: + def update_load_balancing_configs( + self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict] + ) -> None: """ Update load balancing configurations. :param tenant_id: workspace id @@ -303,15 +303,18 @@ def update_load_balancing_configs(self, tenant_id: str, model_type = ModelType.value_of(model_type) if not isinstance(configs, list): - raise ValueError('Invalid load balancing configs') + raise ValueError("Invalid load balancing configs") - current_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ + current_load_balancing_configs = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == tenant_id, - LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model - ).all() + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + ) + .all() + ) # id as key, config as value current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs} @@ -319,25 +322,25 @@ def update_load_balancing_configs(self, tenant_id: str, for config in configs: if not isinstance(config, dict): - raise ValueError('Invalid load balancing config') + raise ValueError("Invalid load balancing config") - config_id = config.get('id') - name = config.get('name') - credentials = config.get('credentials') - enabled = config.get('enabled') + config_id = config.get("id") + name = config.get("name") + credentials = config.get("credentials") + enabled = config.get("enabled") if not name: - raise ValueError('Invalid load balancing config name') + raise ValueError("Invalid load balancing config name") if enabled is None: - raise ValueError('Invalid load balancing config enabled') + raise ValueError("Invalid load balancing config enabled") # is config exists if config_id: config_id = str(config_id) if config_id not in current_load_balancing_configs_dict: - raise ValueError('Invalid load balancing config id: {}'.format(config_id)) + raise ValueError("Invalid load balancing config id: {}".format(config_id)) updated_config_ids.add(config_id) @@ -346,11 +349,11 @@ def update_load_balancing_configs(self, tenant_id: str, # check duplicate name for current_load_balancing_config in current_load_balancing_configs: if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name: - raise ValueError('Load balancing config name {} already exists'.format(name)) + raise ValueError("Load balancing config name {} already exists".format(name)) if credentials: if not isinstance(credentials, dict): - raise ValueError('Invalid load balancing config credentials') + raise ValueError("Invalid load balancing config credentials") # validate custom provider config credentials = self._custom_credentials_validate( @@ -360,7 +363,7 @@ def update_load_balancing_configs(self, tenant_id: str, model=model, credentials=credentials, load_balancing_model_config=load_balancing_config, - validate=False + validate=False, ) # update load balancing config @@ -374,19 +377,19 @@ def update_load_balancing_configs(self, tenant_id: str, self._clear_credentials_cache(tenant_id, config_id) else: # create load balancing config - if name == '__inherit__': - raise ValueError('Invalid load balancing config name') + if name == "__inherit__": + raise ValueError("Invalid load balancing config name") # check duplicate name for current_load_balancing_config in current_load_balancing_configs: if current_load_balancing_config.name == name: - raise ValueError('Load balancing config name {} already exists'.format(name)) + raise ValueError("Load balancing config name {} already exists".format(name)) if not credentials: - raise ValueError('Invalid load balancing config credentials') + raise ValueError("Invalid load balancing config credentials") if not isinstance(credentials, dict): - raise ValueError('Invalid load balancing config credentials') + raise ValueError("Invalid load balancing config credentials") # validate custom provider config credentials = self._custom_credentials_validate( @@ -395,7 +398,7 @@ def update_load_balancing_configs(self, tenant_id: str, model_type=model_type, model=model, credentials=credentials, - validate=False + validate=False, ) # create load balancing config @@ -405,7 +408,7 @@ def update_load_balancing_configs(self, tenant_id: str, model_type=model_type.to_origin_model_type(), model_name=model, name=name, - encrypted_config=json.dumps(credentials) + encrypted_config=json.dumps(credentials), ) db.session.add(load_balancing_model_config) @@ -419,12 +422,15 @@ def update_load_balancing_configs(self, tenant_id: str, self._clear_credentials_cache(tenant_id, config_id) - def validate_load_balancing_credentials(self, tenant_id: str, - provider: str, - model: str, - model_type: str, - credentials: dict, - config_id: Optional[str] = None) -> None: + def validate_load_balancing_credentials( + self, + tenant_id: str, + provider: str, + model: str, + model_type: str, + credentials: dict, + config_id: Optional[str] = None, + ) -> None: """ Validate load balancing credentials. :param tenant_id: workspace id @@ -449,14 +455,17 @@ def validate_load_balancing_credentials(self, tenant_id: str, load_balancing_model_config = None if config_id: # Get load balancing config - load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \ + load_balancing_model_config = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == tenant_id, - LoadBalancingModelConfig.provider_name == provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model, - LoadBalancingModelConfig.id == config_id - ).first() + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + LoadBalancingModelConfig.id == config_id, + ) + .first() + ) if not load_balancing_model_config: raise ValueError(f"Load balancing config {config_id} does not exist.") @@ -468,16 +477,19 @@ def validate_load_balancing_credentials(self, tenant_id: str, model_type=model_type, model=model, credentials=credentials, - load_balancing_model_config=load_balancing_model_config + load_balancing_model_config=load_balancing_model_config, ) - def _custom_credentials_validate(self, tenant_id: str, - provider_configuration: ProviderConfiguration, - model_type: ModelType, - model: str, - credentials: dict, - load_balancing_model_config: Optional[LoadBalancingModelConfig] = None, - validate: bool = True) -> dict: + def _custom_credentials_validate( + self, + tenant_id: str, + provider_configuration: ProviderConfiguration, + model_type: ModelType, + model: str, + credentials: dict, + load_balancing_model_config: Optional[LoadBalancingModelConfig] = None, + validate: bool = True, + ) -> dict: """ Validate custom credentials. :param tenant_id: workspace id @@ -511,7 +523,7 @@ def _custom_credentials_validate(self, tenant_id: str, for key, value in credentials.items(): if key in provider_credential_secret_variables: # if send [__HIDDEN__] in secret input, it will be same as original value - if value == '[__HIDDEN__]' and key in original_credentials: + if value == HIDDEN_VALUE and key in original_credentials: credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key]) if validate: @@ -520,12 +532,11 @@ def _custom_credentials_validate(self, tenant_id: str, provider=provider_configuration.provider.provider, model_type=model_type, model=model, - credentials=credentials + credentials=credentials, ) else: credentials = model_provider_factory.provider_credentials_validate( - provider=provider_configuration.provider.provider, - credentials=credentials + provider=provider_configuration.provider.provider, credentials=credentials ) for key, value in credentials.items(): @@ -534,8 +545,9 @@ def _custom_credentials_validate(self, tenant_id: str, return credentials - def _get_credential_schema(self, provider_configuration: ProviderConfiguration) \ - -> ModelCredentialSchema | ProviderCredentialSchema: + def _get_credential_schema( + self, provider_configuration: ProviderConfiguration + ) -> ModelCredentialSchema | ProviderCredentialSchema: """ Get form schemas. :param provider_configuration: provider configuration @@ -557,9 +569,7 @@ def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None: :return: """ provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=tenant_id, - identity_id=config_id, - cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL + tenant_id=tenant_id, identity_id=config_id, cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL ) provider_model_credentials_cache.delete() diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 385af685f98f44..c0f3c407625415 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -30,6 +30,7 @@ class ModelProviderService: """ Model Provider Service """ + def __init__(self) -> None: self.provider_manager = ProviderManager() @@ -72,8 +73,8 @@ def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> system_configuration=SystemConfigurationResponse( enabled=provider_configuration.system_configuration.enabled, current_quota_type=provider_configuration.system_configuration.current_quota_type, - quota_configurations=provider_configuration.system_configuration.quota_configurations - ) + quota_configurations=provider_configuration.system_configuration.quota_configurations, + ), ) provider_responses.append(provider_response) @@ -94,9 +95,9 @@ def get_models_by_provider(self, tenant_id: str, provider: str) -> list[ModelWit provider_configurations = self.provider_manager.get_configurations(tenant_id) # Get provider available models - return [ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models( - provider=provider - )] + return [ + ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(provider=provider) + ] def get_provider_credentials(self, tenant_id: str, provider: str) -> dict: """ @@ -194,13 +195,12 @@ def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, # Get model custom credentials from ProviderModel if exists return provider_configuration.get_custom_model_credentials( - model_type=ModelType.value_of(model_type), - model=model, - obfuscated=True + model_type=ModelType.value_of(model_type), model=model, obfuscated=True ) - def model_credentials_validate(self, tenant_id: str, provider: str, model_type: str, model: str, - credentials: dict) -> None: + def model_credentials_validate( + self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict + ) -> None: """ validate model credentials. @@ -221,13 +221,12 @@ def model_credentials_validate(self, tenant_id: str, provider: str, model_type: # Validate model credentials provider_configuration.custom_model_credentials_validate( - model_type=ModelType.value_of(model_type), - model=model, - credentials=credentials + model_type=ModelType.value_of(model_type), model=model, credentials=credentials ) - def save_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, - credentials: dict) -> None: + def save_model_credentials( + self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict + ) -> None: """ save model credentials. @@ -248,9 +247,7 @@ def save_model_credentials(self, tenant_id: str, provider: str, model_type: str, # Add or update custom model credentials provider_configuration.add_or_update_custom_model_credentials( - model_type=ModelType.value_of(model_type), - model=model, - credentials=credentials + model_type=ModelType.value_of(model_type), model=model, credentials=credentials ) def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None: @@ -272,10 +269,7 @@ def remove_model_credentials(self, tenant_id: str, provider: str, model_type: st raise ValueError(f"Provider {provider} does not exist.") # Remove custom model credentials - provider_configuration.delete_custom_model_credentials( - model_type=ModelType.value_of(model_type), - model=model - ) + provider_configuration.delete_custom_model_credentials(model_type=ModelType.value_of(model_type), model=model) def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]: """ @@ -289,9 +283,7 @@ def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[Prov provider_configurations = self.provider_manager.get_configurations(tenant_id) # Get provider available models - models = provider_configurations.get_models( - model_type=ModelType.value_of(model_type) - ) + models = provider_configurations.get_models(model_type=ModelType.value_of(model_type)) # Group models by provider provider_models = {} @@ -322,16 +314,19 @@ def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[Prov icon_small=first_model.provider.icon_small, icon_large=first_model.provider.icon_large, status=CustomConfigurationStatus.ACTIVE, - models=[ProviderModelWithStatusEntity( - model=model.model, - label=model.label, - model_type=model.model_type, - features=model.features, - fetch_from=model.fetch_from, - model_properties=model.model_properties, - status=model.status, - load_balancing_enabled=model.load_balancing_enabled - ) for model in models] + models=[ + ProviderModelWithStatusEntity( + model=model.model, + label=model.label, + model_type=model.model_type, + features=model.features, + fetch_from=model.fetch_from, + model_properties=model.model_properties, + status=model.status, + load_balancing_enabled=model.load_balancing_enabled, + ) + for model in models + ], ) ) @@ -360,19 +355,13 @@ def get_model_parameter_rules(self, tenant_id: str, provider: str, model: str) - model_type_instance = cast(LargeLanguageModel, model_type_instance) # fetch credentials - credentials = provider_configuration.get_current_credentials( - model_type=ModelType.LLM, - model=model - ) + credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model) if not credentials: return [] # Call get_parameter_rules method of model instance to get model parameter rules - return model_type_instance.get_parameter_rules( - model=model, - credentials=credentials - ) + return model_type_instance.get_parameter_rules(model=model, credentials=credentials) def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]: """ @@ -383,22 +372,26 @@ def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Op :return: """ model_type_enum = ModelType.value_of(model_type) - result = self.provider_manager.get_default_model( - tenant_id=tenant_id, - model_type=model_type_enum - ) - - return DefaultModelResponse( - model=result.model, - model_type=result.model_type, - provider=SimpleProviderEntityResponse( - provider=result.provider.provider, - label=result.provider.label, - icon_small=result.provider.icon_small, - icon_large=result.provider.icon_large, - supported_model_types=result.provider.supported_model_types + result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum) + try: + return ( + DefaultModelResponse( + model=result.model, + model_type=result.model_type, + provider=SimpleProviderEntityResponse( + provider=result.provider.provider, + label=result.provider.label, + icon_small=result.provider.icon_small, + icon_large=result.provider.icon_large, + supported_model_types=result.provider.supported_model_types, + ), + ) + if result + else None ) - ) if result else None + except Exception as e: + logger.info(f"get_default_model_of_model_type error: {e}") + return None def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None: """ @@ -412,13 +405,12 @@ def update_default_model_of_model_type(self, tenant_id: str, model_type: str, pr """ model_type_enum = ModelType.value_of(model_type) self.provider_manager.update_default_model_record( - tenant_id=tenant_id, - model_type=model_type_enum, - provider=provider, - model=model + tenant_id=tenant_id, model_type=model_type_enum, provider=provider, model=model ) - def get_model_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[Optional[bytes], Optional[str]]: + def get_model_provider_icon( + self, provider: str, icon_type: str, lang: str + ) -> tuple[Optional[bytes], Optional[str]]: """ get model provider icon. @@ -430,11 +422,11 @@ def get_model_provider_icon(self, provider: str, icon_type: str, lang: str) -> t provider_instance = model_provider_factory.get_provider_instance(provider) provider_schema = provider_instance.get_provider_schema() - if icon_type.lower() == 'icon_small': + if icon_type.lower() == "icon_small": if not provider_schema.icon_small: raise ValueError(f"Provider {provider} does not have small icon.") - if lang.lower() == 'zh_hans': + if lang.lower() == "zh_hans": file_name = provider_schema.icon_small.zh_Hans else: file_name = provider_schema.icon_small.en_US @@ -442,13 +434,15 @@ def get_model_provider_icon(self, provider: str, icon_type: str, lang: str) -> t if not provider_schema.icon_large: raise ValueError(f"Provider {provider} does not have large icon.") - if lang.lower() == 'zh_hans': + if lang.lower() == "zh_hans": file_name = provider_schema.icon_large.zh_Hans else: file_name = provider_schema.icon_large.en_US root_path = current_app.root_path - provider_instance_path = os.path.dirname(os.path.join(root_path, provider_instance.__class__.__module__.replace('.', '/'))) + provider_instance_path = os.path.dirname( + os.path.join(root_path, provider_instance.__class__.__module__.replace(".", "/")) + ) file_path = os.path.join(provider_instance_path, "_assets") file_path = os.path.join(file_path, file_name) @@ -456,10 +450,10 @@ def get_model_provider_icon(self, provider: str, icon_type: str, lang: str) -> t return None, None mimetype, _ = mimetypes.guess_type(file_path) - mimetype = mimetype or 'application/octet-stream' + mimetype = mimetype or "application/octet-stream" # read binary from file - with open(file_path, 'rb') as f: + with open(file_path, "rb") as f: byte_data = f.read() return byte_data, mimetype @@ -505,10 +499,7 @@ def enable_model(self, tenant_id: str, provider: str, model: str, model_type: st raise ValueError(f"Provider {provider} does not exist.") # Enable model - provider_configuration.enable_model( - model=model, - model_type=ModelType.value_of(model_type) - ) + provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type)) def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: """ @@ -529,78 +520,49 @@ def disable_model(self, tenant_id: str, provider: str, model: str, model_type: s raise ValueError(f"Provider {provider} does not exist.") # Enable model - provider_configuration.disable_model( - model=model, - model_type=ModelType.value_of(model_type) - ) + provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type)) def free_quota_submit(self, tenant_id: str, provider: str): api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") - api_url = api_base_url + '/api/v1/providers/apply' + api_url = api_base_url + "/api/v1/providers/apply" - headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {api_key}" - } - response = requests.post(api_url, headers=headers, json={'workspace_id': tenant_id, 'provider_name': provider}) + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + response = requests.post(api_url, headers=headers, json={"workspace_id": tenant_id, "provider_name": provider}) if not response.ok: logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ") raise ValueError(f"Error: {response.status_code} ") - if response.json()["code"] != 'success': - raise ValueError( - f"error: {response.json()['message']}" - ) + if response.json()["code"] != "success": + raise ValueError(f"error: {response.json()['message']}") rst = response.json() - if rst['type'] == 'redirect': - return { - 'type': rst['type'], - 'redirect_url': rst['redirect_url'] - } + if rst["type"] == "redirect": + return {"type": rst["type"], "redirect_url": rst["redirect_url"]} else: - return { - 'type': rst['type'], - 'result': 'success' - } + return {"type": rst["type"], "result": "success"} def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]): api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") - api_url = api_base_url + '/api/v1/providers/qualification-verify' + api_url = api_base_url + "/api/v1/providers/qualification-verify" - headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {api_key}" - } - json_data = {'workspace_id': tenant_id, 'provider_name': provider} + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + json_data = {"workspace_id": tenant_id, "provider_name": provider} if token: - json_data['token'] = token - response = requests.post(api_url, headers=headers, - json=json_data) + json_data["token"] = token + response = requests.post(api_url, headers=headers, json=json_data) if not response.ok: logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ") raise ValueError(f"Error: {response.status_code} ") rst = response.json() - if rst["code"] != 'success': - raise ValueError( - f"error: {rst['message']}" - ) + if rst["code"] != "success": + raise ValueError(f"error: {rst['message']}") - data = rst['data'] - if data['qualified'] is True: - return { - 'result': 'success', - 'provider_name': provider, - 'flag': True - } + data = rst["data"] + if data["qualified"] is True: + return {"result": "success", "provider_name": provider, "flag": True} else: - return { - 'result': 'success', - 'provider_name': provider, - 'flag': False, - 'reason': data['reason'] - } + return {"result": "success", "provider_name": provider, "flag": False, "reason": data["reason"]} diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py index d472f8cfbca435..dfb21e767fc9b9 100644 --- a/api/services/moderation_service.py +++ b/api/services/moderation_service.py @@ -4,17 +4,18 @@ class ModerationService: - def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult: app_model_config: AppModelConfig = None - app_model_config = db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() + app_model_config = ( + db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() + ) if not app_model_config: raise ValueError("app model config not found") - name = app_model_config.sensitive_word_avoidance_dict['type'] - config = app_model_config.sensitive_word_avoidance_dict['config'] + name = app_model_config.sensitive_word_avoidance_dict["type"] + config = app_model_config.sensitive_word_avoidance_dict["config"] moderation = ModerationFactory(name, app_id, app_model.tenant_id, config) return moderation.moderation_for_outputs(text) diff --git a/api/services/operation_service.py b/api/services/operation_service.py index 39f249dc24eb09..8c8b64bcd5d344 100644 --- a/api/services/operation_service.py +++ b/api/services/operation_service.py @@ -4,15 +4,12 @@ class OperationService: - base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL') - secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY') + base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL") + secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY") @classmethod def _send_request(cls, method, endpoint, json=None, params=None): - headers = { - "Content-Type": "application/json", - "Billing-Api-Secret-Key": cls.secret_key - } + headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" response = requests.request(method, url, json=json, params=params, headers=headers) @@ -22,11 +19,11 @@ def _send_request(cls, method, endpoint, json=None, params=None): @classmethod def record_utm(cls, tenant_id: str, utm_info: dict): params = { - 'tenant_id': tenant_id, - 'utm_source': utm_info.get('utm_source', ''), - 'utm_medium': utm_info.get('utm_medium', ''), - 'utm_campaign': utm_info.get('utm_campaign', ''), - 'utm_content': utm_info.get('utm_content', ''), - 'utm_term': utm_info.get('utm_term', '') + "tenant_id": tenant_id, + "utm_source": utm_info.get("utm_source", ""), + "utm_medium": utm_info.get("utm_medium", ""), + "utm_campaign": utm_info.get("utm_campaign", ""), + "utm_content": utm_info.get("utm_content", ""), + "utm_term": utm_info.get("utm_term", ""), } - return cls._send_request('POST', '/tenant_utms', params=params) + return cls._send_request("POST", "/tenant_utms", params=params) diff --git a/api/services/ops_service.py b/api/services/ops_service.py index ffc12a9acdb42c..35aa6817e11cc8 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -12,20 +12,29 @@ def get_tracing_app_config(cls, app_id: str, tracing_provider: str): :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + trace_config_data: TraceAppConfig = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if not trace_config_data: return None # decrypt_token and obfuscated_token tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id - decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(tenant_id, tracing_provider, trace_config_data.tracing_config) - decrypt_tracing_config = OpsTraceManager.obfuscated_decrypt_token(tracing_provider, decrypt_tracing_config) + decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config( + tenant_id, tracing_provider, trace_config_data.tracing_config + ) + new_decrypt_tracing_config = OpsTraceManager.obfuscated_decrypt_token(tracing_provider, decrypt_tracing_config) - trace_config_data.tracing_config = decrypt_tracing_config + if tracing_provider == "langfuse" and ( + "project_key" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_key") + ): + project_key = OpsTraceManager.get_trace_config_project_key(decrypt_tracing_config, tracing_provider) + new_decrypt_tracing_config.update({"project_key": project_key}) + trace_config_data.tracing_config = new_decrypt_tracing_config return trace_config_data.to_dict() @classmethod @@ -37,11 +46,13 @@ def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c :param tracing_config: tracing config :return: """ - if tracing_provider not in provider_config_map.keys() and tracing_provider != None: + if tracing_provider not in provider_config_map.keys() and tracing_provider: return {"error": f"Invalid tracing provider: {tracing_provider}"} - config_class, other_keys = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['other_keys'] + config_class, other_keys = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["other_keys"], + ) default_config_instance = config_class(**tracing_config) for key in other_keys: if key in tracing_config and tracing_config[key] == "": @@ -51,10 +62,15 @@ def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c if not OpsTraceManager.check_trace_config_is_effective(tracing_config, tracing_provider): return {"error": "Invalid Credentials"} + # get project key + project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider) + # check if trace config already exists - trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + trace_config_data: TraceAppConfig = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if trace_config_data: return None @@ -62,6 +78,8 @@ def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c # get tenant id tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config) + if tracing_provider == "langfuse" and project_key: + tracing_config["project_key"] = project_key trace_config_data = TraceAppConfig( app_id=app_id, tracing_provider=tracing_provider, @@ -85,9 +103,11 @@ def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c raise ValueError(f"Invalid tracing provider: {tracing_provider}") # check if trace config already exists - current_trace_config = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + current_trace_config = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if not current_trace_config: return None @@ -117,9 +137,11 @@ def delete_tracing_app_config(cls, app_id: str, tracing_provider: str): :param tracing_provider: tracing provider :return: """ - trace_config = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + trace_config = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if not trace_config: return None diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index 1c1c5be17c64b3..10abf0a764c75c 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -16,7 +16,6 @@ class RecommendedAppService: - builtin_data: Optional[dict] = None @classmethod @@ -27,21 +26,21 @@ def get_recommended_apps_and_categories(cls, language: str) -> dict: :return: """ mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE - if mode == 'remote': + if mode == "remote": try: result = cls._fetch_recommended_apps_from_dify_official(language) except Exception as e: - logger.warning(f'fetch recommended apps from dify official failed: {e}, switch to built-in.') + logger.warning(f"fetch recommended apps from dify official failed: {e}, switch to built-in.") result = cls._fetch_recommended_apps_from_builtin(language) - elif mode == 'db': + elif mode == "db": result = cls._fetch_recommended_apps_from_db(language) - elif mode == 'builtin': + elif mode == "builtin": result = cls._fetch_recommended_apps_from_builtin(language) else: - raise ValueError(f'invalid fetch recommended apps mode: {mode}') + raise ValueError(f"invalid fetch recommended apps mode: {mode}") - if not result.get('recommended_apps') and language != 'en-US': - result = cls._fetch_recommended_apps_from_builtin('en-US') + if not result.get("recommended_apps") and language != "en-US": + result = cls._fetch_recommended_apps_from_builtin("en-US") return result @@ -52,16 +51,18 @@ def _fetch_recommended_apps_from_db(cls, language: str) -> dict: :param language: language :return: """ - recommended_apps = db.session.query(RecommendedApp).filter( - RecommendedApp.is_listed == True, - RecommendedApp.language == language - ).all() + recommended_apps = ( + db.session.query(RecommendedApp) + .filter(RecommendedApp.is_listed == True, RecommendedApp.language == language) + .all() + ) if len(recommended_apps) == 0: - recommended_apps = db.session.query(RecommendedApp).filter( - RecommendedApp.is_listed == True, - RecommendedApp.language == languages[0] - ).all() + recommended_apps = ( + db.session.query(RecommendedApp) + .filter(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) + .all() + ) categories = set() recommended_apps_result = [] @@ -75,28 +76,28 @@ def _fetch_recommended_apps_from_db(cls, language: str) -> dict: continue recommended_app_result = { - 'id': recommended_app.id, - 'app': { - 'id': app.id, - 'name': app.name, - 'mode': app.mode, - 'icon': app.icon, - 'icon_background': app.icon_background + "id": recommended_app.id, + "app": { + "id": app.id, + "name": app.name, + "mode": app.mode, + "icon": app.icon, + "icon_background": app.icon_background, }, - 'app_id': recommended_app.app_id, - 'description': site.description, - 'copyright': site.copyright, - 'privacy_policy': site.privacy_policy, - 'custom_disclaimer': site.custom_disclaimer, - 'category': recommended_app.category, - 'position': recommended_app.position, - 'is_listed': recommended_app.is_listed + "app_id": recommended_app.app_id, + "description": site.description, + "copyright": site.copyright, + "privacy_policy": site.privacy_policy, + "custom_disclaimer": site.custom_disclaimer, + "category": recommended_app.category, + "position": recommended_app.position, + "is_listed": recommended_app.is_listed, } recommended_apps_result.append(recommended_app_result) categories.add(recommended_app.category) # add category to categories - return {'recommended_apps': recommended_apps_result, 'categories': sorted(categories)} + return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)} @classmethod def _fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: @@ -106,16 +107,16 @@ def _fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: :return: """ domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN - url = f'{domain}/apps?language={language}' + url = f"{domain}/apps?language={language}" response = requests.get(url, timeout=(3, 10)) if response.status_code != 200: - raise ValueError(f'fetch recommended apps failed, status code: {response.status_code}') + raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}") result = response.json() if "categories" in result: result["categories"] = sorted(result["categories"]) - + return result @classmethod @@ -126,7 +127,7 @@ def _fetch_recommended_apps_from_builtin(cls, language: str) -> dict: :return: """ builtin_data = cls._get_builtin_data() - return builtin_data.get('recommended_apps', {}).get(language) + return builtin_data.get("recommended_apps", {}).get(language) @classmethod def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]: @@ -136,18 +137,18 @@ def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]: :return: """ mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE - if mode == 'remote': + if mode == "remote": try: result = cls._fetch_recommended_app_detail_from_dify_official(app_id) except Exception as e: - logger.warning(f'fetch recommended app detail from dify official failed: {e}, switch to built-in.') + logger.warning(f"fetch recommended app detail from dify official failed: {e}, switch to built-in.") result = cls._fetch_recommended_app_detail_from_builtin(app_id) - elif mode == 'db': + elif mode == "db": result = cls._fetch_recommended_app_detail_from_db(app_id) - elif mode == 'builtin': + elif mode == "builtin": result = cls._fetch_recommended_app_detail_from_builtin(app_id) else: - raise ValueError(f'invalid fetch recommended app detail mode: {mode}') + raise ValueError(f"invalid fetch recommended app detail mode: {mode}") return result @@ -159,7 +160,7 @@ def _fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> Option :return: """ domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN - url = f'{domain}/apps/{app_id}' + url = f"{domain}/apps/{app_id}" response = requests.get(url, timeout=(3, 10)) if response.status_code != 200: return None @@ -174,10 +175,11 @@ def _fetch_recommended_app_detail_from_db(cls, app_id: str) -> Optional[dict]: :return: """ # is in public recommended list - recommended_app = db.session.query(RecommendedApp).filter( - RecommendedApp.is_listed == True, - RecommendedApp.app_id == app_id - ).first() + recommended_app = ( + db.session.query(RecommendedApp) + .filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id) + .first() + ) if not recommended_app: return None @@ -188,12 +190,12 @@ def _fetch_recommended_app_detail_from_db(cls, app_id: str) -> Optional[dict]: return None return { - 'id': app_model.id, - 'name': app_model.name, - 'icon': app_model.icon, - 'icon_background': app_model.icon_background, - 'mode': app_model.mode, - 'export_data': AppDslService.export_dsl(app_model=app_model) + "id": app_model.id, + "name": app_model.name, + "icon": app_model.icon, + "icon_background": app_model.icon_background, + "mode": app_model.mode, + "export_data": AppDslService.export_dsl(app_model=app_model), } @classmethod @@ -204,7 +206,7 @@ def _fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dic :return: """ builtin_data = cls._get_builtin_data() - return builtin_data.get('app_details', {}).get(app_id) + return builtin_data.get("app_details", {}).get(app_id) @classmethod def _get_builtin_data(cls) -> dict: @@ -216,7 +218,7 @@ def _get_builtin_data(cls) -> dict: return cls.builtin_data root_path = current_app.root_path - with open(path.join(root_path, 'constants', 'recommended_apps.json'), encoding='utf-8') as f: + with open(path.join(root_path, "constants", "recommended_apps.json"), encoding="utf-8") as f: json_data = f.read() data = json.loads(json_data) cls.builtin_data = data @@ -229,27 +231,24 @@ def fetch_all_recommended_apps_and_export_datas(cls): Fetch all recommended apps and export datas :return: """ - templates = { - "recommended_apps": {}, - "app_details": {} - } + templates = {"recommended_apps": {}, "app_details": {}} for language in languages: try: result = cls._fetch_recommended_apps_from_dify_official(language) except Exception as e: - logger.warning(f'fetch recommended apps from dify official failed: {e}, skip.') + logger.warning(f"fetch recommended apps from dify official failed: {e}, skip.") continue - templates['recommended_apps'][language] = result + templates["recommended_apps"][language] = result - for recommended_app in result.get('recommended_apps'): - app_id = recommended_app.get('app_id') + for recommended_app in result.get("recommended_apps"): + app_id = recommended_app.get("app_id") # get app detail app_detail = cls._fetch_recommended_app_detail_from_dify_official(app_id) if not app_detail: continue - templates['app_details'][app_id] = app_detail + templates["app_details"][app_id] = app_detail return templates diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index f1113c1505e197..9fe3cecce7546d 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -10,46 +10,48 @@ class SavedMessageService: @classmethod - def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - last_id: Optional[str], limit: int) -> InfiniteScrollPagination: - saved_messages = db.session.query(SavedMessage).filter( - SavedMessage.app_id == app_model.id, - SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - SavedMessage.created_by == user.id - ).order_by(SavedMessage.created_at.desc()).all() + def pagination_by_last_id( + cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int + ) -> InfiniteScrollPagination: + saved_messages = ( + db.session.query(SavedMessage) + .filter( + SavedMessage.app_id == app_model.id, + SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + SavedMessage.created_by == user.id, + ) + .order_by(SavedMessage.created_at.desc()) + .all() + ) message_ids = [sm.message_id for sm in saved_messages] return MessageService.pagination_by_last_id( - app_model=app_model, - user=user, - last_id=last_id, - limit=limit, - include_ids=message_ids + app_model=app_model, user=user, last_id=last_id, limit=limit, include_ids=message_ids ) @classmethod def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): - saved_message = db.session.query(SavedMessage).filter( - SavedMessage.app_id == app_model.id, - SavedMessage.message_id == message_id, - SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - SavedMessage.created_by == user.id - ).first() + saved_message = ( + db.session.query(SavedMessage) + .filter( + SavedMessage.app_id == app_model.id, + SavedMessage.message_id == message_id, + SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + SavedMessage.created_by == user.id, + ) + .first() + ) if saved_message: return - message = MessageService.get_message( - app_model=app_model, - user=user, - message_id=message_id - ) + message = MessageService.get_message(app_model=app_model, user=user, message_id=message_id) saved_message = SavedMessage( app_id=app_model.id, message_id=message.id, - created_by_role='account' if isinstance(user, Account) else 'end_user', - created_by=user.id + created_by_role="account" if isinstance(user, Account) else "end_user", + created_by=user.id, ) db.session.add(saved_message) @@ -57,12 +59,16 @@ def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_i @classmethod def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): - saved_message = db.session.query(SavedMessage).filter( - SavedMessage.app_id == app_model.id, - SavedMessage.message_id == message_id, - SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - SavedMessage.created_by == user.id - ).first() + saved_message = ( + db.session.query(SavedMessage) + .filter( + SavedMessage.app_id == app_model.id, + SavedMessage.message_id == message_id, + SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + SavedMessage.created_by == user.id, + ) + .first() + ) if not saved_message: return diff --git a/api/services/tag_service.py b/api/services/tag_service.py index d6eba38fbdef3e..0c17485a9f613a 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -12,38 +12,32 @@ class TagService: @staticmethod def get_tags(tag_type: str, current_tenant_id: str, keyword: str = None) -> list: - query = db.session.query( - Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label('binding_count') - ).outerjoin( - TagBinding, Tag.id == TagBinding.tag_id - ).filter( - Tag.type == tag_type, - Tag.tenant_id == current_tenant_id + query = ( + db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count")) + .outerjoin(TagBinding, Tag.id == TagBinding.tag_id) + .filter(Tag.type == tag_type, Tag.tenant_id == current_tenant_id) ) if keyword: - query = query.filter(db.and_(Tag.name.ilike(f'%{keyword}%'))) - query = query.group_by( - Tag.id - ) + query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%"))) + query = query.group_by(Tag.id) results = query.order_by(Tag.created_at.desc()).all() return results @staticmethod def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list: - tags = db.session.query(Tag).filter( - Tag.id.in_(tag_ids), - Tag.tenant_id == current_tenant_id, - Tag.type == tag_type - ).all() + tags = ( + db.session.query(Tag) + .filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type) + .all() + ) if not tags: return [] tag_ids = [tag.id for tag in tags] - tag_bindings = db.session.query( - TagBinding.target_id - ).filter( - TagBinding.tag_id.in_(tag_ids), - TagBinding.tenant_id == current_tenant_id - ).all() + tag_bindings = ( + db.session.query(TagBinding.target_id) + .filter(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id) + .all() + ) if not tag_bindings: return [] results = [tag_binding.target_id for tag_binding in tag_bindings] @@ -51,27 +45,28 @@ def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: li @staticmethod def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list: - tags = db.session.query(Tag).join( - TagBinding, - Tag.id == TagBinding.tag_id - ).filter( - TagBinding.target_id == target_id, - TagBinding.tenant_id == current_tenant_id, - Tag.tenant_id == current_tenant_id, - Tag.type == tag_type - ).all() + tags = ( + db.session.query(Tag) + .join(TagBinding, Tag.id == TagBinding.tag_id) + .filter( + TagBinding.target_id == target_id, + TagBinding.tenant_id == current_tenant_id, + Tag.tenant_id == current_tenant_id, + Tag.type == tag_type, + ) + .all() + ) return tags if tags else [] - @staticmethod def save_tags(args: dict) -> Tag: tag = Tag( id=str(uuid.uuid4()), - name=args['name'], - type=args['type'], + name=args["name"], + type=args["type"], created_by=current_user.id, - tenant_id=current_user.current_tenant_id + tenant_id=current_user.current_tenant_id, ) db.session.add(tag) db.session.commit() @@ -82,7 +77,7 @@ def update_tags(args: dict, tag_id: str) -> Tag: tag = db.session.query(Tag).filter(Tag.id == tag_id).first() if not tag: raise NotFound("Tag not found") - tag.name = args['name'] + tag.name = args["name"] db.session.commit() return tag @@ -107,20 +102,21 @@ def delete_tag(tag_id: str): @staticmethod def save_tag_binding(args): # check if target exists - TagService.check_target_exists(args['type'], args['target_id']) + TagService.check_target_exists(args["type"], args["target_id"]) # save tag binding - for tag_id in args['tag_ids']: - tag_binding = db.session.query(TagBinding).filter( - TagBinding.tag_id == tag_id, - TagBinding.target_id == args['target_id'] - ).first() + for tag_id in args["tag_ids"]: + tag_binding = ( + db.session.query(TagBinding) + .filter(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"]) + .first() + ) if tag_binding: continue new_tag_binding = TagBinding( tag_id=tag_id, - target_id=args['target_id'], + target_id=args["target_id"], tenant_id=current_user.current_tenant_id, - created_by=current_user.id + created_by=current_user.id, ) db.session.add(new_tag_binding) db.session.commit() @@ -128,34 +124,34 @@ def save_tag_binding(args): @staticmethod def delete_tag_binding(args): # check if target exists - TagService.check_target_exists(args['type'], args['target_id']) + TagService.check_target_exists(args["type"], args["target_id"]) # delete tag binding - tag_bindings = db.session.query(TagBinding).filter( - TagBinding.target_id == args['target_id'], - TagBinding.tag_id == (args['tag_id']) - ).first() + tag_bindings = ( + db.session.query(TagBinding) + .filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"])) + .first() + ) if tag_bindings: db.session.delete(tag_bindings) db.session.commit() - - @staticmethod def check_target_exists(type: str, target_id: str): - if type == 'knowledge': - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == current_user.current_tenant_id, - Dataset.id == target_id - ).first() + if type == "knowledge": + dataset = ( + db.session.query(Dataset) + .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id) + .first() + ) if not dataset: raise NotFound("Dataset not found") - elif type == 'app': - app = db.session.query(App).filter( - App.tenant_id == current_user.current_tenant_id, - App.id == target_id - ).first() + elif type == "app": + app = ( + db.session.query(App) + .filter(App.tenant_id == current_user.current_tenant_id, App.id == target_id) + .first() + ) if not app: raise NotFound("App not found") else: raise NotFound("Invalid binding type") - diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index ecc065d521c703..3ded9c0989b53c 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -29,111 +29,107 @@ class ApiToolManageService: @staticmethod def parser_api_schema(schema: str) -> list[ApiToolBundle]: """ - parse api schema to tool bundle + parse api schema to tool bundle """ try: warnings = {} try: tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings) except Exception as e: - raise ValueError(f'invalid schema: {str(e)}') - + raise ValueError(f"invalid schema: {str(e)}") + credentials_schema = [ ToolProviderCredentials( - name='auth_type', + name="auth_type", type=ToolProviderCredentials.CredentialsType.SELECT, required=True, - default='none', + default="none", options=[ - ToolCredentialsOption(value='none', label=I18nObject( - en_US='None', - zh_Hans='无' - )), - ToolCredentialsOption(value='api_key', label=I18nObject( - en_US='Api Key', - zh_Hans='Api Key' - )), + ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")), + ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")), ], - placeholder=I18nObject( - en_US='Select auth type', - zh_Hans='选择认证方式' - ) + placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"), ), ToolProviderCredentials( - name='api_key_header', + name="api_key_header", type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, required=False, - placeholder=I18nObject( - en_US='Enter api key header', - zh_Hans='输入 api key header,如:X-API-KEY' - ), - default='api_key', - help=I18nObject( - en_US='HTTP header name for api key', - zh_Hans='HTTP 头部字段名,用于传递 api key' - ) + placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"), + default="api_key", + help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"), ), ToolProviderCredentials( - name='api_key_value', + name="api_key_value", type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, required=False, - placeholder=I18nObject( - en_US='Enter api key', - zh_Hans='输入 api key' - ), - default='' + placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"), + default="", ), ] - return jsonable_encoder({ - 'schema_type': schema_type, - 'parameters_schema': tool_bundles, - 'credentials_schema': credentials_schema, - 'warning': warnings - }) + return jsonable_encoder( + { + "schema_type": schema_type, + "parameters_schema": tool_bundles, + "credentials_schema": credentials_schema, + "warning": warnings, + } + ) except Exception as e: - raise ValueError(f'invalid schema: {str(e)}') + raise ValueError(f"invalid schema: {str(e)}") @staticmethod def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiToolBundle]: """ - convert schema to tool bundles + convert schema to tool bundles - :return: the list of tool bundles, description + :return: the list of tool bundles, description """ try: tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info) return tool_bundles except Exception as e: - raise ValueError(f'invalid schema: {str(e)}') + raise ValueError(f"invalid schema: {str(e)}") @staticmethod def create_api_tool_provider( - user_id: str, tenant_id: str, provider_name: str, icon: dict, credentials: dict, - schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str] + user_id: str, + tenant_id: str, + provider_name: str, + icon: dict, + credentials: dict, + schema_type: str, + schema: str, + privacy_policy: str, + custom_disclaimer: str, + labels: list[str], ): """ - create api tool provider + create api tool provider """ if schema_type not in [member.value for member in ApiProviderSchemaType]: - raise ValueError(f'invalid schema type {schema}') - + raise ValueError(f"invalid schema type {schema}") + # check if the provider exists - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .first() + ) if provider is not None: - raise ValueError(f'provider {provider_name} already exists') + raise ValueError(f"provider {provider_name} already exists") # parse openapi to tool bundle extra_info = {} # extra info like description will be set here tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) - + if len(tool_bundles) > 100: - raise ValueError('the number of apis should be less than 100') + raise ValueError("the number of apis should be less than 100") # create db provider db_provider = ApiToolProvider( @@ -142,19 +138,19 @@ def create_api_tool_provider( name=provider_name, icon=json.dumps(icon), schema=schema, - description=extra_info.get('description', ''), + description=extra_info.get("description", ""), schema_type_str=schema_type, tools_str=json.dumps(jsonable_encoder(tool_bundles)), credentials_str={}, privacy_policy=privacy_policy, - custom_disclaimer=custom_disclaimer + custom_disclaimer=custom_disclaimer, ) - if 'auth_type' not in credentials: - raise ValueError('auth_type is required') + if "auth_type" not in credentials: + raise ValueError("auth_type is required") # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) # create provider entity provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) @@ -172,14 +168,12 @@ def create_api_tool_provider( # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) - return { 'result': 'success' } - + return {"result": "success"} + @staticmethod - def get_api_tool_provider_remote_schema( - user_id: str, tenant_id: str, url: str - ): + def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str): """ - get api tool provider remote schema + get api tool provider remote schema """ headers = { "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0", @@ -189,84 +183,98 @@ def get_api_tool_provider_remote_schema( try: response = get(url, headers=headers, timeout=10) if response.status_code != 200: - raise ValueError(f'Got status code {response.status_code}') + raise ValueError(f"Got status code {response.status_code}") schema = response.text # try to parse schema, avoid SSRF attack ApiToolManageService.parser_api_schema(schema) except Exception as e: logger.error(f"parse api schema error: {str(e)}") - raise ValueError('invalid schema, please check the url you provided') - - return { - 'schema': schema - } + raise ValueError("invalid schema, please check the url you provided") + + return {"schema": schema} @staticmethod - def list_api_tool_provider_tools( - user_id: str, tenant_id: str, provider: str - ) -> list[UserTool]: + def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]: """ - list api tool provider tools + list api tool provider tools """ - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider, + ) + .first() + ) if provider is None: - raise ValueError(f'you have not added provider {provider}') - + raise ValueError(f"you have not added provider {provider}") + controller = ToolTransformService.api_provider_to_controller(db_provider=provider) labels = ToolLabelManager.get_tool_labels(controller) - + return [ ToolTransformService.tool_to_user_tool( tool_bundle, labels=labels, - ) for tool_bundle in provider.tools + ) + for tool_bundle in provider.tools ] @staticmethod def update_api_tool_provider( - user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict, - schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str] + user_id: str, + tenant_id: str, + provider_name: str, + original_provider: str, + icon: dict, + credentials: dict, + schema_type: str, + schema: str, + privacy_policy: str, + custom_disclaimer: str, + labels: list[str], ): """ - update api tool provider + update api tool provider """ if schema_type not in [member.value for member in ApiProviderSchemaType]: - raise ValueError(f'invalid schema type {schema}') - + raise ValueError(f"invalid schema type {schema}") + # check if the provider exists - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == original_provider, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == original_provider, + ) + .first() + ) if provider is None: - raise ValueError(f'api provider {provider_name} does not exists') + raise ValueError(f"api provider {provider_name} does not exists") # parse openapi to tool bundle extra_info = {} # extra info like description will be set here tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) - + # update db provider provider.name = provider_name provider.icon = json.dumps(icon) provider.schema = schema - provider.description = extra_info.get('description', '') + provider.description = extra_info.get("description", "") provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) provider.privacy_policy = privacy_policy provider.custom_disclaimer = custom_disclaimer - if 'auth_type' not in credentials: - raise ValueError('auth_type is required') + if "auth_type" not in credentials: + raise ValueError("auth_type is required") # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) # create provider entity provider_controller = ApiToolProviderController.from_db(provider, auth_type) @@ -295,84 +303,91 @@ def update_api_tool_provider( # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) - return { 'result': 'success' } - + return {"result": "success"} + @staticmethod - def delete_api_tool_provider( - user_id: str, tenant_id: str, provider_name: str - ): + def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str): """ - delete tool provider + delete tool provider """ - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .first() + ) if provider is None: - raise ValueError(f'you have not added provider {provider_name}') - + raise ValueError(f"you have not added provider {provider_name}") + db.session.delete(provider) db.session.commit() - return { 'result': 'success' } - + return {"result": "success"} + @staticmethod - def get_api_tool_provider( - user_id: str, tenant_id: str, provider: str - ): + def get_api_tool_provider(user_id: str, tenant_id: str, provider: str): """ - get api tool provider + get api tool provider """ return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id) - + @staticmethod def test_api_tool_preview( - tenant_id: str, + tenant_id: str, provider_name: str, - tool_name: str, - credentials: dict, - parameters: dict, - schema_type: str, - schema: str + tool_name: str, + credentials: dict, + parameters: dict, + schema_type: str, + schema: str, ): """ - test api tool before adding api tool provider + test api tool before adding api tool provider """ if schema_type not in [member.value for member in ApiProviderSchemaType]: - raise ValueError(f'invalid schema type {schema_type}') - + raise ValueError(f"invalid schema type {schema_type}") + try: tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema) except Exception as e: - raise ValueError('invalid schema') - + raise ValueError("invalid schema") + # get tool bundle tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None) if tool_bundle is None: - raise ValueError(f'invalid tool name {tool_name}') - - db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, - ).first() + raise ValueError(f"invalid tool name {tool_name}") + + db_provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .first() + ) if not db_provider: # create a fake db provider db_provider = ApiToolProvider( - tenant_id='', user_id='', name='', icon='', + tenant_id="", + user_id="", + name="", + icon="", schema=schema, - description='', + description="", schema_type_str=ApiProviderSchemaType.OPENAPI.value, tools_str=json.dumps(jsonable_encoder(tool_bundles)), credentials_str=json.dumps(credentials), ) - if 'auth_type' not in credentials: - raise ValueError('auth_type is required') + if "auth_type" not in credentials: + raise ValueError("auth_type is required") # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) # create provider entity provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) @@ -381,10 +396,7 @@ def test_api_tool_preview( # decrypt credentials if db_provider.id: - tool_configuration = ToolConfigurationManager( - tenant_id=tenant_id, - provider_controller=provider_controller - ) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) # check if the credential has changed, save the original credential masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) @@ -396,27 +408,27 @@ def test_api_tool_preview( provider_controller.validate_credentials_format(credentials) # get tool tool = provider_controller.get_tool(tool_name) - tool = tool.fork_tool_runtime(runtime={ - 'credentials': credentials, - 'tenant_id': tenant_id, - }) + tool = tool.fork_tool_runtime( + runtime={ + "credentials": credentials, + "tenant_id": tenant_id, + } + ) result = tool.validate_credentials(credentials, parameters) except Exception as e: - return { 'error': str(e) } - - return { 'result': result or 'empty response' } - + return {"error": str(e)} + + return {"result": result or "empty response"} + @staticmethod - def list_api_tools( - user_id: str, tenant_id: str - ) -> list[UserToolProvider]: + def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: """ - list api tools + list api tools """ # get all api providers - db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id - ).all() or [] + db_providers: list[ApiToolProvider] = ( + db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or [] + ) result: list[UserToolProvider] = [] @@ -425,26 +437,21 @@ def list_api_tools( provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider) labels = ToolLabelManager.get_tool_labels(provider_controller) user_provider = ToolTransformService.api_provider_to_user_provider( - provider_controller, - db_provider=provider, - decrypt_credentials=True + provider_controller, db_provider=provider, decrypt_credentials=True ) user_provider.labels = labels # add icon ToolTransformService.repack_provider(user_provider) - tools = provider_controller.get_tools( - user_id=user_id, tenant_id=tenant_id - ) + tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id) for tool in tools: - user_provider.tools.append(ToolTransformService.tool_to_user_tool( - tenant_id=tenant_id, - tool=tool, - credentials=user_provider.original_credentials, - labels=labels - )) + user_provider.tools.append( + ToolTransformService.tool_to_user_tool( + tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels + ) + ) result.append(user_provider) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index ea6ecf0c6985f5..dc8cebb587a2d8 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -1,6 +1,8 @@ import json import logging +from configs import dify_config +from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.entities.api_entities import UserTool, UserToolProvider from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError @@ -18,21 +20,25 @@ class BuiltinToolManageService: @staticmethod - def list_builtin_tool_provider_tools( - user_id: str, tenant_id: str, provider: str - ) -> list[UserTool]: + def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]: """ - list builtin tool provider tools + list builtin tool provider tools """ provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider) tools = provider_controller.get_tools() - tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) + tool_provider_configurations = ToolConfigurationManager( + tenant_id=tenant_id, provider_controller=provider_controller + ) # check if user has added the provider - builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider, - ).first() + builtin_provider: BuiltinToolProvider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ) + .first() + ) credentials = {} if builtin_provider is not None: @@ -42,47 +48,47 @@ def list_builtin_tool_provider_tools( result = [] for tool in tools: - result.append(ToolTransformService.tool_to_user_tool( - tool=tool, - credentials=credentials, - tenant_id=tenant_id, - labels=ToolLabelManager.get_tool_labels(provider_controller) - )) + result.append( + ToolTransformService.tool_to_user_tool( + tool=tool, + credentials=credentials, + tenant_id=tenant_id, + labels=ToolLabelManager.get_tool_labels(provider_controller), + ) + ) return result - + @staticmethod - def list_builtin_provider_credentials_schema( - provider_name - ): + def list_builtin_provider_credentials_schema(provider_name): """ - list builtin provider credentials schema + list builtin provider credentials schema - :return: the list of tool providers + :return: the list of tool providers """ provider = ToolManager.get_builtin_provider(provider_name) - return jsonable_encoder([ - v for _, v in (provider.credentials_schema or {}).items() - ]) + return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()]) @staticmethod - def update_builtin_tool_provider( - user_id: str, tenant_id: str, provider_name: str, credentials: dict - ): + def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict): """ - update builtin tool provider + update builtin tool provider """ # get if the provider exists - provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_name, - ).first() + provider: BuiltinToolProvider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_name, + ) + .first() + ) - try: + try: # get provider provider_controller = ToolManager.get_builtin_provider(provider_name) if not provider_controller.need_credentials: - raise ValueError(f'provider {provider_name} does not need credentials') + raise ValueError(f"provider {provider_name} does not need credentials") tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) # get original credentials if exists if provider is not None: @@ -119,23 +125,25 @@ def update_builtin_tool_provider( # delete cache tool_configuration.delete_tool_credentials_cache() - return { 'result': 'success' } - + return {"result": "success"} + @staticmethod - def get_builtin_tool_provider_credentials( - user_id: str, tenant_id: str, provider: str - ): + def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str): """ - get builtin tool provider credentials + get builtin tool provider credentials """ - provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider, - ).first() + provider: BuiltinToolProvider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ) + .first() + ) if provider is None: return {} - + provider_controller = ToolManager.get_builtin_provider(provider.provider) tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) @@ -143,20 +151,22 @@ def get_builtin_tool_provider_credentials( return credentials @staticmethod - def delete_builtin_tool_provider( - user_id: str, tenant_id: str, provider_name: str - ): + def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str): """ - delete tool provider + delete tool provider """ - provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_name, - ).first() + provider: BuiltinToolProvider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_name, + ) + .first() + ) if provider is None: - raise ValueError(f'you have not added provider {provider_name}') - + raise ValueError(f"you have not added provider {provider_name}") + db.session.delete(provider) db.session.commit() @@ -165,48 +175,55 @@ def delete_builtin_tool_provider( tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) tool_configuration.delete_tool_credentials_cache() - return { 'result': 'success' } - + return {"result": "success"} + @staticmethod - def get_builtin_tool_provider_icon( - provider: str - ): + def get_builtin_tool_provider_icon(provider: str): """ - get tool provider icon and it's mimetype + get tool provider icon and it's mimetype """ icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider) - with open(icon_path, 'rb') as f: + with open(icon_path, "rb") as f: icon_bytes = f.read() return icon_bytes, mime_type - + @staticmethod - def list_builtin_tools( - user_id: str, tenant_id: str - ) -> list[UserToolProvider]: + def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: """ - list builtin tools + list builtin tools """ # get all builtin providers provider_controllers = ToolManager.list_builtin_providers() # get all user added providers - db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id - ).all() or [] + db_providers: list[BuiltinToolProvider] = ( + db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or [] + ) # find provider - find_provider = lambda provider: next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None) + find_provider = lambda provider: next( + filter(lambda db_provider: db_provider.provider == provider, db_providers), None + ) result: list[UserToolProvider] = [] for provider_controller in provider_controllers: try: + # handle include, exclude + if is_filtered( + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, + data=provider_controller, + name_func=lambda x: x.identity.name, + ): + continue + # convert provider controller to user provider user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider_controller, db_provider=find_provider(provider_controller.identity.name), - decrypt_credentials=True + decrypt_credentials=True, ) # add icon @@ -214,16 +231,17 @@ def list_builtin_tools( tools = provider_controller.get_tools() for tool in tools: - user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool( - tenant_id=tenant_id, - tool=tool, - credentials=user_builtin_provider.original_credentials, - labels=ToolLabelManager.get_tool_labels(provider_controller) - )) + user_builtin_provider.tools.append( + ToolTransformService.tool_to_user_tool( + tenant_id=tenant_id, + tool=tool, + credentials=user_builtin_provider.original_credentials, + labels=ToolLabelManager.get_tool_labels(provider_controller), + ) + ) result.append(user_builtin_provider) except Exception as e: raise e return BuiltinToolProviderSort.sort(result) - \ No newline at end of file diff --git a/api/services/tools/tool_labels_service.py b/api/services/tools/tool_labels_service.py index 8a6aa025f27c6b..35e58b5adec58f 100644 --- a/api/services/tools/tool_labels_service.py +++ b/api/services/tools/tool_labels_service.py @@ -5,4 +5,4 @@ class ToolLabelsService: @classmethod def list_tool_labels(cls) -> list[ToolLabel]: - return default_tool_labels \ No newline at end of file + return default_tool_labels diff --git a/api/services/tools/tools_manage_service.py b/api/services/tools/tools_manage_service.py index 76d2f53ae86535..1c67f7648ca99f 100644 --- a/api/services/tools/tools_manage_service.py +++ b/api/services/tools/tools_manage_service.py @@ -11,13 +11,11 @@ class ToolCommonService: @staticmethod def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral = None): """ - list tool providers + list tool providers - :return: the list of tool providers + :return: the list of tool providers """ - providers = ToolManager.user_list_providers( - user_id, tenant_id, typ - ) + providers = ToolManager.user_list_providers(user_id, tenant_id, typ) # add icon for provider in providers: @@ -26,4 +24,3 @@ def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeL result = [provider.to_dict() for provider in providers] return result - \ No newline at end of file diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index cfce3fbd019a2c..6fb0f2f517376d 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -22,46 +22,39 @@ logger = logging.getLogger(__name__) + class ToolTransformService: @staticmethod def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str) -> Union[str, dict]: """ - get tool provider icon url + get tool provider icon url """ - url_prefix = (dify_config.CONSOLE_API_URL - + "/console/api/workspaces/current/tool-provider/") - + url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/" + if provider_type == ToolProviderType.BUILT_IN.value: - return url_prefix + 'builtin/' + provider_name + '/icon' + return url_prefix + "builtin/" + provider_name + "/icon" elif provider_type in [ToolProviderType.API.value, ToolProviderType.WORKFLOW.value]: try: return json.loads(icon) except: - return { - "background": "#252525", - "content": "\ud83d\ude01" - } - - return '' - + return {"background": "#252525", "content": "\ud83d\ude01"} + + return "" + @staticmethod def repack_provider(provider: Union[dict, UserToolProvider]): """ - repack provider + repack provider - :param provider: the provider dict + :param provider: the provider dict """ - if isinstance(provider, dict) and 'icon' in provider: - provider['icon'] = ToolTransformService.get_tool_provider_icon_url( - provider_type=provider['type'], - provider_name=provider['name'], - icon=provider['icon'] + if isinstance(provider, dict) and "icon" in provider: + provider["icon"] = ToolTransformService.get_tool_provider_icon_url( + provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"] ) elif isinstance(provider, UserToolProvider): provider.icon = ToolTransformService.get_tool_provider_icon_url( - provider_type=provider.type.value, - provider_name=provider.name, - icon=provider.icon + provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon ) @staticmethod @@ -92,14 +85,13 @@ def builtin_provider_to_user_provider( masked_credentials={}, is_team_authorization=False, tools=[], - labels=provider_controller.tool_labels + labels=provider_controller.tool_labels, ) # get credentials schema schema = provider_controller.get_credentials_schema() for name, value in schema.items(): - result.masked_credentials[name] = \ - ToolProviderCredentials.CredentialsType.default(value.type) + result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type) # check if the provider need credentials if not provider_controller.need_credentials: @@ -113,8 +105,7 @@ def builtin_provider_to_user_provider( # init tool configuration tool_configuration = ToolConfigurationManager( - tenant_id=db_provider.tenant_id, - provider_controller=provider_controller + tenant_id=db_provider.tenant_id, provider_controller=provider_controller ) # decrypt the credentials and mask the credentials decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) @@ -124,7 +115,7 @@ def builtin_provider_to_user_provider( result.original_credentials = decrypted_credentials return result - + @staticmethod def api_provider_to_controller( db_provider: ApiToolProvider, @@ -135,25 +126,23 @@ def api_provider_to_controller( # package tool provider controller controller = ApiToolProviderController.from_db( db_provider=db_provider, - auth_type=ApiProviderAuthType.API_KEY if db_provider.credentials['auth_type'] == 'api_key' else - ApiProviderAuthType.NONE + auth_type=ApiProviderAuthType.API_KEY + if db_provider.credentials["auth_type"] == "api_key" + else ApiProviderAuthType.NONE, ) return controller - + @staticmethod - def workflow_provider_to_controller( - db_provider: WorkflowToolProvider - ) -> WorkflowToolProviderController: + def workflow_provider_to_controller(db_provider: WorkflowToolProvider) -> WorkflowToolProviderController: """ convert provider controller to provider """ return WorkflowToolProviderController.from_db(db_provider) - + @staticmethod def workflow_provider_to_user_provider( - provider_controller: WorkflowToolProviderController, - labels: list[str] = None + provider_controller: WorkflowToolProviderController, labels: list[str] = None ): """ convert provider controller to user provider @@ -175,7 +164,7 @@ def workflow_provider_to_user_provider( masked_credentials={}, is_team_authorization=True, tools=[], - labels=labels or [] + labels=labels or [], ) @staticmethod @@ -183,16 +172,16 @@ def api_provider_to_user_provider( provider_controller: ApiToolProviderController, db_provider: ApiToolProvider, decrypt_credentials: bool = True, - labels: list[str] = None + labels: list[str] = None, ) -> UserToolProvider: """ convert provider controller to user provider """ - username = 'Anonymous' + username = "Anonymous" try: username = db_provider.user.name except Exception as e: - logger.error(f'failed to get user name for api provider {db_provider.id}: {str(e)}') + logger.error(f"failed to get user name for api provider {db_provider.id}: {str(e)}") # add provider into providers credentials = db_provider.credentials result = UserToolProvider( @@ -212,14 +201,13 @@ def api_provider_to_user_provider( masked_credentials={}, is_team_authorization=True, tools=[], - labels=labels or [] + labels=labels or [], ) if decrypt_credentials: # init tool configuration tool_configuration = ToolConfigurationManager( - tenant_id=db_provider.tenant_id, - provider_controller=provider_controller + tenant_id=db_provider.tenant_id, provider_controller=provider_controller ) # decrypt the credentials and mask the credentials @@ -229,23 +217,25 @@ def api_provider_to_user_provider( result.masked_credentials = masked_credentials return result - + @staticmethod def tool_to_user_tool( - tool: Union[ApiToolBundle, WorkflowTool, Tool], - credentials: dict = None, + tool: Union[ApiToolBundle, WorkflowTool, Tool], + credentials: dict = None, tenant_id: str = None, - labels: list[str] = None + labels: list[str] = None, ) -> UserTool: """ convert tool to user tool """ if isinstance(tool, Tool): # fork tool runtime - tool = tool.fork_tool_runtime(runtime={ - 'credentials': credentials, - 'tenant_id': tenant_id, - }) + tool = tool.fork_tool_runtime( + runtime={ + "credentials": credentials, + "tenant_id": tenant_id, + } + ) # get tool parameters parameters = tool.parameters or [] @@ -270,20 +260,14 @@ def tool_to_user_tool( label=tool.identity.label, description=tool.description.human, parameters=current_parameters, - labels=labels + labels=labels, ) if isinstance(tool, ApiToolBundle): return UserTool( author=tool.author, name=tool.operation_id, - label=I18nObject( - en_US=tool.operation_id, - zh_Hans=tool.operation_id - ), - description=I18nObject( - en_US=tool.summary or '', - zh_Hans=tool.summary or '' - ), + label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id), + description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""), parameters=tool.parameters, - labels=labels - ) \ No newline at end of file + labels=labels, + ) diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index e89d94160c7c04..3830e75339eebc 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -19,10 +19,21 @@ class WorkflowToolManageService: """ Service class for managing workflow tools. """ + @classmethod - def create_workflow_tool(cls, user_id: str, tenant_id: str, workflow_app_id: str, name: str, - label: str, icon: dict, description: str, - parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict: + def create_workflow_tool( + cls, + user_id: str, + tenant_id: str, + workflow_app_id: str, + name: str, + label: str, + icon: dict, + description: str, + parameters: list[dict], + privacy_policy: str = "", + labels: list[str] = None, + ) -> dict: """ Create a workflow tool. :param user_id: the user id @@ -32,32 +43,34 @@ def create_workflow_tool(cls, user_id: str, tenant_id: str, workflow_app_id: str :param description: the description :param parameters: the parameters :param privacy_policy: the privacy policy + :param labels: labels :return: the created tool """ WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) # check if the name is unique - existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - # name or app_id - or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id) - ).first() + existing_workflow_tool_provider = ( + db.session.query(WorkflowToolProvider) + .filter( + WorkflowToolProvider.tenant_id == tenant_id, + # name or app_id + or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id), + ) + .first() + ) if existing_workflow_tool_provider is not None: - raise ValueError(f'Tool with name {name} or app_id {workflow_app_id} already exists') - - app: App = db.session.query(App).filter( - App.id == workflow_app_id, - App.tenant_id == tenant_id - ).first() + raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists") + + app: App = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first() if app is None: - raise ValueError(f'App {workflow_app_id} not found') - + raise ValueError(f"App {workflow_app_id} not found") + workflow: Workflow = app.workflow if workflow is None: - raise ValueError(f'Workflow not found for app {workflow_app_id}') - + raise ValueError(f"Workflow not found for app {workflow_app_id}") + workflow_tool_provider = WorkflowToolProvider( tenant_id=tenant_id, user_id=user_id, @@ -75,58 +88,76 @@ def create_workflow_tool(cls, user_id: str, tenant_id: str, workflow_app_id: str WorkflowToolProviderController.from_db(workflow_tool_provider) except Exception as e: raise ValueError(str(e)) - + db.session.add(workflow_tool_provider) db.session.commit() - return { - 'result': 'success' - } - + return {"result": "success"} @classmethod - def update_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str, - name: str, label: str, icon: dict, description: str, - parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict: + def update_workflow_tool( + cls, + user_id: str, + tenant_id: str, + workflow_tool_id: str, + name: str, + label: str, + icon: dict, + description: str, + parameters: list[dict], + privacy_policy: str = "", + labels: list[str] = None, + ) -> dict: """ Update a workflow tool. :param user_id: the user id :param tenant_id: the tenant id - :param tool: the tool + :param workflow_tool_id: workflow tool id + :param name: name + :param label: label + :param icon: icon + :param description: description + :param parameters: parameters + :param privacy_policy: privacy policy + :param labels: labels :return: the updated tool """ WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) # check if the name is unique - existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.name == name, - WorkflowToolProvider.id != workflow_tool_id - ).first() + existing_workflow_tool_provider = ( + db.session.query(WorkflowToolProvider) + .filter( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.name == name, + WorkflowToolProvider.id != workflow_tool_id, + ) + .first() + ) if existing_workflow_tool_provider is not None: - raise ValueError(f'Tool with name {name} already exists') - - workflow_tool_provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == workflow_tool_id - ).first() + raise ValueError(f"Tool with name {name} already exists") + + workflow_tool_provider: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .first() + ) if workflow_tool_provider is None: - raise ValueError(f'Tool {workflow_tool_id} not found') - - app: App = db.session.query(App).filter( - App.id == workflow_tool_provider.app_id, - App.tenant_id == tenant_id - ).first() + raise ValueError(f"Tool {workflow_tool_id} not found") + + app: App = ( + db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first() + ) if app is None: - raise ValueError(f'App {workflow_tool_provider.app_id} not found') - + raise ValueError(f"App {workflow_tool_provider.app_id} not found") + workflow: Workflow = app.workflow if workflow is None: - raise ValueError(f'Workflow not found for app {workflow_tool_provider.app_id}') - + raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}") + workflow_tool_provider.name = name workflow_tool_provider.label = label workflow_tool_provider.icon = json.dumps(icon) @@ -146,13 +177,10 @@ def update_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: st if labels is not None: ToolLabelManager.update_tool_labels( - ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), - labels + ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels ) - return { - 'result': 'success' - } + return {"result": "success"} @classmethod def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]: @@ -162,9 +190,7 @@ def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserTo :param tenant_id: the tenant id :return: the list of tools """ - db_tools = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id - ).all() + db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() tools = [] for provider in db_tools: @@ -180,14 +206,12 @@ def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserTo for tool in tools: user_tool_provider = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=tool, - labels=labels.get(tool.provider_id, []) + provider_controller=tool, labels=labels.get(tool.provider_id, []) ) ToolTransformService.repack_provider(user_tool_provider) user_tool_provider.tools = [ ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], - labels=labels.get(tool.provider_id, []) + tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, []) ) ] result.append(user_tool_provider) @@ -203,15 +227,12 @@ def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: st :param workflow_app_id: the workflow app id """ db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == workflow_tool_id + WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id ).delete() db.session.commit() - return { - 'result': 'success' - } + return {"result": "success"} @classmethod def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict: @@ -222,40 +243,37 @@ def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_too :param workflow_app_id: the workflow app id :return: the tool """ - db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == workflow_tool_id - ).first() + db_tool: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .first() + ) if db_tool is None: - raise ValueError(f'Tool {workflow_tool_id} not found') - - workflow_app: App = db.session.query(App).filter( - App.id == db_tool.app_id, - App.tenant_id == tenant_id - ).first() + raise ValueError(f"Tool {workflow_tool_id} not found") + + workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() if workflow_app is None: - raise ValueError(f'App {db_tool.app_id} not found') + raise ValueError(f"App {db_tool.app_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) return { - 'name': db_tool.name, - 'label': db_tool.label, - 'workflow_tool_id': db_tool.id, - 'workflow_app_id': db_tool.app_id, - 'icon': json.loads(db_tool.icon), - 'description': db_tool.description, - 'parameters': jsonable_encoder(db_tool.parameter_configurations), - 'tool': ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], - labels=ToolLabelManager.get_tool_labels(tool) + "name": db_tool.name, + "label": db_tool.label, + "workflow_tool_id": db_tool.id, + "workflow_app_id": db_tool.app_id, + "icon": json.loads(db_tool.icon), + "description": db_tool.description, + "parameters": jsonable_encoder(db_tool.parameter_configurations), + "tool": ToolTransformService.tool_to_user_tool( + tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) ), - 'synced': workflow_app.workflow.version == db_tool.version, - 'privacy_policy': db_tool.privacy_policy, + "synced": workflow_app.workflow.version == db_tool.version, + "privacy_policy": db_tool.privacy_policy, } - + @classmethod def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict: """ @@ -265,40 +283,37 @@ def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_ :param workflow_app_id: the workflow app id :return: the tool """ - db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.app_id == workflow_app_id - ).first() + db_tool: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) + .first() + ) if db_tool is None: - raise ValueError(f'Tool {workflow_app_id} not found') - - workflow_app: App = db.session.query(App).filter( - App.id == db_tool.app_id, - App.tenant_id == tenant_id - ).first() + raise ValueError(f"Tool {workflow_app_id} not found") + + workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() if workflow_app is None: - raise ValueError(f'App {db_tool.app_id} not found') + raise ValueError(f"App {db_tool.app_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) return { - 'name': db_tool.name, - 'label': db_tool.label, - 'workflow_tool_id': db_tool.id, - 'workflow_app_id': db_tool.app_id, - 'icon': json.loads(db_tool.icon), - 'description': db_tool.description, - 'parameters': jsonable_encoder(db_tool.parameter_configurations), - 'tool': ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], - labels=ToolLabelManager.get_tool_labels(tool) + "name": db_tool.name, + "label": db_tool.label, + "workflow_tool_id": db_tool.id, + "workflow_app_id": db_tool.app_id, + "icon": json.loads(db_tool.icon), + "description": db_tool.description, + "parameters": jsonable_encoder(db_tool.parameter_configurations), + "tool": ToolTransformService.tool_to_user_tool( + tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) ), - 'synced': workflow_app.workflow.version == db_tool.version, - 'privacy_policy': db_tool.privacy_policy + "synced": workflow_app.workflow.version == db_tool.version, + "privacy_policy": db_tool.privacy_policy, } - + @classmethod def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]: """ @@ -308,19 +323,19 @@ def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_ :param workflow_app_id: the workflow app id :return: the list of tools """ - db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == workflow_tool_id - ).first() + db_tool: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .first() + ) if db_tool is None: - raise ValueError(f'Tool {workflow_tool_id} not found') + raise ValueError(f"Tool {workflow_tool_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) return [ ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], - labels=ToolLabelManager.get_tool_labels(tool) + tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) ) - ] \ No newline at end of file + ] diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 232d2943256cad..3c67351335359d 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -7,10 +7,10 @@ class VectorService: - @classmethod - def create_segments_vector(cls, keywords_list: Optional[list[list[str]]], - segments: list[DocumentSegment], dataset: Dataset): + def create_segments_vector( + cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset + ): documents = [] for segment in segments: document = Document( @@ -20,14 +20,12 @@ def create_segments_vector(cls, keywords_list: Optional[list[list[str]]], "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": # save vector index - vector = Vector( - dataset=dataset - ) + vector = Vector(dataset=dataset) vector.add_texts(documents, duplicate_check=True) # save keyword index @@ -50,13 +48,11 @@ def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentS "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": # update vector index - vector = Vector( - dataset=dataset - ) + vector = Vector(dataset=dataset) vector.delete_by_ids([segment.index_node_id]) vector.add_texts([document], duplicate_check=True) diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index cba048ccdbbbf1..d7ccc964cb70f8 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -11,17 +11,29 @@ class WebConversationService: @classmethod - def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - last_id: Optional[str], limit: int, invoke_from: InvokeFrom, - pinned: Optional[bool] = None) -> InfiniteScrollPagination: + def pagination_by_last_id( + cls, + app_model: App, + user: Optional[Union[Account, EndUser]], + last_id: Optional[str], + limit: int, + invoke_from: InvokeFrom, + pinned: Optional[bool] = None, + sort_by="-updated_at", + ) -> InfiniteScrollPagination: include_ids = None exclude_ids = None if pinned is not None: - pinned_conversations = db.session.query(PinnedConversation).filter( - PinnedConversation.app_id == app_model.id, - PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - PinnedConversation.created_by == user.id - ).order_by(PinnedConversation.created_at.desc()).all() + pinned_conversations = ( + db.session.query(PinnedConversation) + .filter( + PinnedConversation.app_id == app_model.id, + PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + PinnedConversation.created_by == user.id, + ) + .order_by(PinnedConversation.created_at.desc()) + .all() + ) pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations] if pinned: include_ids = pinned_conversation_ids @@ -36,31 +48,34 @@ def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, End invoke_from=invoke_from, include_ids=include_ids, exclude_ids=exclude_ids, + sort_by=sort_by, ) @classmethod def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): - pinned_conversation = db.session.query(PinnedConversation).filter( - PinnedConversation.app_id == app_model.id, - PinnedConversation.conversation_id == conversation_id, - PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - PinnedConversation.created_by == user.id - ).first() + pinned_conversation = ( + db.session.query(PinnedConversation) + .filter( + PinnedConversation.app_id == app_model.id, + PinnedConversation.conversation_id == conversation_id, + PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + PinnedConversation.created_by == user.id, + ) + .first() + ) if pinned_conversation: return conversation = ConversationService.get_conversation( - app_model=app_model, - conversation_id=conversation_id, - user=user + app_model=app_model, conversation_id=conversation_id, user=user ) pinned_conversation = PinnedConversation( app_id=app_model.id, conversation_id=conversation.id, - created_by_role='account' if isinstance(user, Account) else 'end_user', - created_by=user.id + created_by_role="account" if isinstance(user, Account) else "end_user", + created_by=user.id, ) db.session.add(pinned_conversation) @@ -68,12 +83,16 @@ def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, @classmethod def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): - pinned_conversation = db.session.query(PinnedConversation).filter( - PinnedConversation.app_id == app_model.id, - PinnedConversation.conversation_id == conversation_id, - PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - PinnedConversation.created_by == user.id - ).first() + pinned_conversation = ( + db.session.query(PinnedConversation) + .filter( + PinnedConversation.app_id == app_model.id, + PinnedConversation.conversation_id == conversation_id, + PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + PinnedConversation.created_by == user.id, + ) + .first() + ) if not pinned_conversation: return diff --git a/api/services/website_service.py b/api/services/website_service.py index c166b01237b6c4..6dff35d63f1925 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -11,161 +11,126 @@ class WebsiteService: - @classmethod def document_create_args_validate(cls, args: dict): - if 'url' not in args or not args['url']: - raise ValueError('url is required') - if 'options' not in args or not args['options']: - raise ValueError('options is required') - if 'limit' not in args['options'] or not args['options']['limit']: - raise ValueError('limit is required') + if "url" not in args or not args["url"]: + raise ValueError("url is required") + if "options" not in args or not args["options"]: + raise ValueError("options is required") + if "limit" not in args["options"] or not args["options"]["limit"]: + raise ValueError("limit is required") @classmethod def crawl_url(cls, args: dict) -> dict: - provider = args.get('provider') - url = args.get('url') - options = args.get('options') - credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, - 'website', - provider) - if provider == 'firecrawl': + provider = args.get("provider") + url = args.get("url") + options = args.get("options") + credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider) + if provider == "firecrawl": # decrypt api_key api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, - token=credentials.get('config').get('api_key') + tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") ) - firecrawl_app = FirecrawlApp(api_key=api_key, - base_url=credentials.get('config').get('base_url', None)) - crawl_sub_pages = options.get('crawl_sub_pages', False) - only_main_content = options.get('only_main_content', False) + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) + crawl_sub_pages = options.get("crawl_sub_pages", False) + only_main_content = options.get("only_main_content", False) if not crawl_sub_pages: params = { - 'crawlerOptions': { + "crawlerOptions": { "includes": [], "excludes": [], "generateImgAltText": True, "limit": 1, - 'returnOnlyUrls': False, - 'pageOptions': { - 'onlyMainContent': only_main_content, - "includeHtml": False - } + "returnOnlyUrls": False, + "pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False}, } } else: - includes = options.get('includes').split(',') if options.get('includes') else [] - excludes = options.get('excludes').split(',') if options.get('excludes') else [] + includes = options.get("includes").split(",") if options.get("includes") else [] + excludes = options.get("excludes").split(",") if options.get("excludes") else [] params = { - 'crawlerOptions': { + "crawlerOptions": { "includes": includes if includes else [], "excludes": excludes if excludes else [], "generateImgAltText": True, - "limit": options.get('limit', 1), - 'returnOnlyUrls': False, - 'pageOptions': { - 'onlyMainContent': only_main_content, - "includeHtml": False - } + "limit": options.get("limit", 1), + "returnOnlyUrls": False, + "pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False}, } } - if options.get('max_depth'): - params['crawlerOptions']['maxDepth'] = options.get('max_depth') + if options.get("max_depth"): + params["crawlerOptions"]["maxDepth"] = options.get("max_depth") job_id = firecrawl_app.crawl_url(url, params) - website_crawl_time_cache_key = f'website_crawl_{job_id}' + website_crawl_time_cache_key = f"website_crawl_{job_id}" time = str(datetime.datetime.now().timestamp()) redis_client.setex(website_crawl_time_cache_key, 3600, time) - return { - 'status': 'active', - 'job_id': job_id - } + return {"status": "active", "job_id": job_id} else: - raise ValueError('Invalid provider') + raise ValueError("Invalid provider") @classmethod def get_crawl_status(cls, job_id: str, provider: str) -> dict: - credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, - 'website', - provider) - if provider == 'firecrawl': + credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider) + if provider == "firecrawl": # decrypt api_key api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, - token=credentials.get('config').get('api_key') + tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") ) - firecrawl_app = FirecrawlApp(api_key=api_key, - base_url=credentials.get('config').get('base_url', None)) + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) result = firecrawl_app.check_crawl_status(job_id) crawl_status_data = { - 'status': result.get('status', 'active'), - 'job_id': job_id, - 'total': result.get('total', 0), - 'current': result.get('current', 0), - 'data': result.get('data', []) + "status": result.get("status", "active"), + "job_id": job_id, + "total": result.get("total", 0), + "current": result.get("current", 0), + "data": result.get("data", []), } - if crawl_status_data['status'] == 'completed': - website_crawl_time_cache_key = f'website_crawl_{job_id}' + if crawl_status_data["status"] == "completed": + website_crawl_time_cache_key = f"website_crawl_{job_id}" start_time = redis_client.get(website_crawl_time_cache_key) if start_time: end_time = datetime.datetime.now().timestamp() time_consuming = abs(end_time - float(start_time)) - crawl_status_data['time_consuming'] = f"{time_consuming:.2f}" + crawl_status_data["time_consuming"] = f"{time_consuming:.2f}" redis_client.delete(website_crawl_time_cache_key) else: - raise ValueError('Invalid provider') + raise ValueError("Invalid provider") return crawl_status_data @classmethod def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None: - credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, - 'website', - provider) - if provider == 'firecrawl': - file_key = 'website_files/' + job_id + '.txt' + credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) + if provider == "firecrawl": + file_key = "website_files/" + job_id + ".txt" if storage.exists(file_key): data = storage.load_once(file_key) if data: - data = json.loads(data.decode('utf-8')) + data = json.loads(data.decode("utf-8")) else: # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=tenant_id, - token=credentials.get('config').get('api_key') - ) - firecrawl_app = FirecrawlApp(api_key=api_key, - base_url=credentials.get('config').get('base_url', None)) + api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) result = firecrawl_app.check_crawl_status(job_id) - if result.get('status') != 'completed': - raise ValueError('Crawl job is not completed') - data = result.get('data') + if result.get("status") != "completed": + raise ValueError("Crawl job is not completed") + data = result.get("data") if data: for item in data: - if item.get('source_url') == url: + if item.get("source_url") == url: return item return None else: - raise ValueError('Invalid provider') + raise ValueError("Invalid provider") @classmethod def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None: - credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, - 'website', - provider) - if provider == 'firecrawl': + credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) + if provider == "firecrawl": # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=tenant_id, - token=credentials.get('config').get('api_key') - ) - firecrawl_app = FirecrawlApp(api_key=api_key, - base_url=credentials.get('config').get('base_url', None)) - params = { - 'pageOptions': { - 'onlyMainContent': only_main_content, - "includeHtml": False - } - } + api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) + params = {"pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False}} result = firecrawl_app.scrape_url(url, params) return result else: - raise ValueError('Invalid provider') + raise ValueError("Invalid provider") diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 06b129be691010..4b845be2f4d01f 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -6,7 +6,6 @@ DatasetRetrieveConfigEntity, EasyUIBasedAppConfig, ExternalDataVariableEntity, - FileExtraConfig, ModelConfigEntity, PromptTemplateEntity, VariableEntity, @@ -14,6 +13,7 @@ from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.completion.app_config_manager import CompletionAppConfigManager +from core.file.file_obj import FileExtraConfig from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.utils.encoders import jsonable_encoder @@ -32,11 +32,9 @@ class WorkflowConverter: App Convert to Workflow Mode """ - def convert_to_workflow(self, app_model: App, - account: Account, - name: str, - icon: str, - icon_background: str) -> App: + def convert_to_workflow( + self, app_model: App, account: Account, name: str, icon_type: str, icon: str, icon_background: str + ): """ Convert app to workflow @@ -50,22 +48,24 @@ def convert_to_workflow(self, app_model: App, :param account: Account :param name: new app name :param icon: new app icon + :param icon_type: new app icon type :param icon_background: new app icon background :return: new App instance """ # convert app model config + if not app_model.app_model_config: + raise ValueError("App model config is required") + workflow = self.convert_app_model_config_to_workflow( - app_model=app_model, - app_model_config=app_model.app_model_config, - account_id=account.id + app_model=app_model, app_model_config=app_model.app_model_config, account_id=account.id ) # create new app new_app = App() new_app.tenant_id = app_model.tenant_id - new_app.name = name if name else app_model.name + '(workflow)' - new_app.mode = AppMode.ADVANCED_CHAT.value \ - if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value + new_app.name = name if name else app_model.name + "(workflow)" + new_app.mode = AppMode.ADVANCED_CHAT.value if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value + new_app.icon_type = icon_type if icon_type else app_model.icon_type new_app.icon = icon if icon else app_model.icon new_app.icon_background = icon_background if icon_background else app_model.icon_background new_app.enable_site = app_model.enable_site @@ -74,6 +74,8 @@ def convert_to_workflow(self, app_model: App, new_app.api_rph = app_model.api_rph new_app.is_demo = False new_app.is_public = app_model.is_public + new_app.created_by = account.id + new_app.updated_by = account.id db.session.add(new_app) db.session.flush() db.session.commit() @@ -85,30 +87,21 @@ def convert_to_workflow(self, app_model: App, return new_app - def convert_app_model_config_to_workflow(self, app_model: App, - app_model_config: AppModelConfig, - account_id: str) -> Workflow: + def convert_app_model_config_to_workflow(self, app_model: App, app_model_config: AppModelConfig, account_id: str): """ Convert app model config to workflow mode :param app_model: App instance :param app_model_config: AppModelConfig instance :param account_id: Account ID - :return: """ # get new app mode new_app_mode = self._get_new_app_mode(app_model) # convert app model config - app_config = self._convert_to_app_config( - app_model=app_model, - app_model_config=app_model_config - ) + app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config) # init workflow graph - graph = { - "nodes": [], - "edges": [] - } + graph = {"nodes": [], "edges": []} # Convert list: # - variables -> start @@ -120,11 +113,9 @@ def convert_app_model_config_to_workflow(self, app_model: App, # - show_retrieve_source -> knowledge-retrieval # convert to start node - start_node = self._convert_to_start_node( - variables=app_config.variables - ) + start_node = self._convert_to_start_node(variables=app_config.variables) - graph['nodes'].append(start_node) + graph["nodes"].append(start_node) # convert to http request node external_data_variable_node_mapping = {} @@ -132,7 +123,7 @@ def convert_app_model_config_to_workflow(self, app_model: App, http_request_nodes, external_data_variable_node_mapping = self._convert_to_http_request_node( app_model=app_model, variables=app_config.variables, - external_data_variables=app_config.external_data_variables + external_data_variables=app_config.external_data_variables, ) for http_request_node in http_request_nodes: @@ -141,9 +132,7 @@ def convert_app_model_config_to_workflow(self, app_model: App, # convert to knowledge retrieval node if app_config.dataset: knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node( - new_app_mode=new_app_mode, - dataset_config=app_config.dataset, - model_config=app_config.model + new_app_mode=new_app_mode, dataset_config=app_config.dataset, model_config=app_config.model ) if knowledge_retrieval_node: @@ -157,7 +146,7 @@ def convert_app_model_config_to_workflow(self, app_model: App, model_config=app_config.model, prompt_template=app_config.prompt_template, file_upload=app_config.additional_features.file_upload, - external_data_variable_node_mapping=external_data_variable_node_mapping + external_data_variable_node_mapping=external_data_variable_node_mapping, ) graph = self._append_node(graph, llm_node) @@ -196,11 +185,12 @@ def convert_app_model_config_to_workflow(self, app_model: App, tenant_id=app_model.tenant_id, app_id=app_model.id, type=WorkflowType.from_app_mode(new_app_mode).value, - version='draft', + version="draft", graph=json.dumps(graph), features=json.dumps(features), created_by=account_id, environment_variables=[], + conversation_variables=[], ) db.session.add(workflow) @@ -208,24 +198,18 @@ def convert_app_model_config_to_workflow(self, app_model: App, return workflow - def _convert_to_app_config(self, app_model: App, - app_model_config: AppModelConfig) -> EasyUIBasedAppConfig: + def _convert_to_app_config(self, app_model: App, app_model_config: AppModelConfig) -> EasyUIBasedAppConfig: app_mode = AppMode.value_of(app_model.mode) if app_mode == AppMode.AGENT_CHAT or app_model.is_agent: app_model.mode = AppMode.AGENT_CHAT.value app_config = AgentChatAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config + app_model=app_model, app_model_config=app_model_config ) elif app_mode == AppMode.CHAT: - app_config = ChatAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config - ) + app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config) elif app_mode == AppMode.COMPLETION: app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config + app_model=app_model, app_model_config=app_model_config ) else: raise ValueError("Invalid app mode") @@ -244,14 +228,13 @@ def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict: "data": { "title": "START", "type": NodeType.START.value, - "variables": [jsonable_encoder(v) for v in variables] - } + "variables": [jsonable_encoder(v) for v in variables], + }, } - def _convert_to_http_request_node(self, app_model: App, - variables: list[VariableEntity], - external_data_variables: list[ExternalDataVariableEntity]) \ - -> tuple[list[dict], dict[str, str]]: + def _convert_to_http_request_node( + self, app_model: App, variables: list[VariableEntity], external_data_variables: list[ExternalDataVariableEntity] + ) -> tuple[list[dict], dict[str, str]]: """ Convert API Based Extension to HTTP Request Node :param app_model: App instance @@ -273,40 +256,33 @@ def _convert_to_http_request_node(self, app_model: App, # get params from config api_based_extension_id = tool_config.get("api_based_extension_id") + if not api_based_extension_id: + continue # get api_based_extension api_based_extension = self._get_api_based_extension( - tenant_id=tenant_id, - api_based_extension_id=api_based_extension_id + tenant_id=tenant_id, api_based_extension_id=api_based_extension_id ) - if not api_based_extension: - raise ValueError("[External data tool] API query failed, variable: {}, " - "error: api_based_extension_id is invalid" - .format(tool_variable)) - # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=tenant_id, - token=api_based_extension.api_key - ) + api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=api_based_extension.api_key) inputs = {} for v in variables: - inputs[v.variable] = '{{#start.' + v.variable + '#}}' + inputs[v.variable] = "{{#start." + v.variable + "#}}" request_body = { - 'point': APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value, - 'params': { - 'app_id': app_model.id, - 'tool_variable': tool_variable, - 'inputs': inputs, - 'query': '{{#sys.query#}}' if app_model.mode == AppMode.CHAT.value else '' - } + "point": APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value, + "params": { + "app_id": app_model.id, + "tool_variable": tool_variable, + "inputs": inputs, + "query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT.value else "", + }, } request_body_json = json.dumps(request_body) - request_body_json = request_body_json.replace(r'\{\{', '{{').replace(r'\}\}', '}}') + request_body_json = request_body_json.replace(r"\{\{", "{{").replace(r"\}\}", "}}") http_request_node = { "id": f"http_request_{index}", @@ -316,20 +292,11 @@ def _convert_to_http_request_node(self, app_model: App, "type": NodeType.HTTP_REQUEST.value, "method": "post", "url": api_based_extension.api_endpoint, - "authorization": { - "type": "api-key", - "config": { - "type": "bearer", - "api_key": api_key - } - }, + "authorization": {"type": "api-key", "config": {"type": "bearer", "api_key": api_key}}, "headers": "", "params": "", - "body": { - "type": "json", - "data": request_body_json - } - } + "body": {"type": "json", "data": request_body_json}, + }, } nodes.append(http_request_node) @@ -341,32 +308,24 @@ def _convert_to_http_request_node(self, app_model: App, "data": { "title": f"Parse {api_based_extension.name} Response", "type": NodeType.CODE.value, - "variables": [{ - "variable": "response_json", - "value_selector": [http_request_node['id'], "body"] - }], + "variables": [{"variable": "response_json", "value_selector": [http_request_node["id"], "body"]}], "code_language": "python3", "code": "import json\n\ndef main(response_json: str) -> str:\n response_body = json.loads(" - "response_json)\n return {\n \"result\": response_body[\"result\"]\n }", - "outputs": { - "result": { - "type": "string" - } - } - } + 'response_json)\n return {\n "result": response_body["result"]\n }', + "outputs": {"result": {"type": "string"}}, + }, } nodes.append(code_node) - external_data_variable_node_mapping[external_data_variable.variable] = code_node['id'] + external_data_variable_node_mapping[external_data_variable.variable] = code_node["id"] index += 1 return nodes, external_data_variable_node_mapping - def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode, - dataset_config: DatasetEntity, - model_config: ModelConfigEntity) \ - -> Optional[dict]: + def _convert_to_knowledge_retrieval_node( + self, new_app_mode: AppMode, dataset_config: DatasetEntity, model_config: ModelConfigEntity + ) -> Optional[dict]: """ Convert datasets to Knowledge Retrieval Node :param new_app_mode: new app mode @@ -400,7 +359,7 @@ def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode, "completion_params": { **model_config.parameters, "stop": model_config.stop, - } + }, } } if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE @@ -408,20 +367,23 @@ def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode, "multiple_retrieval_config": { "top_k": retrieve_config.top_k, "score_threshold": retrieve_config.score_threshold, - "reranking_model": retrieve_config.reranking_model + "reranking_model": retrieve_config.reranking_model, } if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE else None, - } + }, } - def _convert_to_llm_node(self, original_app_mode: AppMode, - new_app_mode: AppMode, - graph: dict, - model_config: ModelConfigEntity, - prompt_template: PromptTemplateEntity, - file_upload: Optional[FileExtraConfig] = None, - external_data_variable_node_mapping: dict[str, str] = None) -> dict: + def _convert_to_llm_node( + self, + original_app_mode: AppMode, + new_app_mode: AppMode, + graph: dict, + model_config: ModelConfigEntity, + prompt_template: PromptTemplateEntity, + file_upload: Optional[FileExtraConfig] = None, + external_data_variable_node_mapping: dict[str, str] | None = None, + ) -> dict: """ Convert to LLM Node :param original_app_mode: original app mode @@ -433,17 +395,18 @@ def _convert_to_llm_node(self, original_app_mode: AppMode, :param external_data_variable_node_mapping: external data variable node mapping """ # fetch start and knowledge retrieval node - start_node = next(filter(lambda n: n['data']['type'] == NodeType.START.value, graph['nodes'])) - knowledge_retrieval_node = next(filter( - lambda n: n['data']['type'] == NodeType.KNOWLEDGE_RETRIEVAL.value, - graph['nodes'] - ), None) + start_node = next(filter(lambda n: n["data"]["type"] == NodeType.START.value, graph["nodes"])) + knowledge_retrieval_node = next( + filter(lambda n: n["data"]["type"] == NodeType.KNOWLEDGE_RETRIEVAL.value, graph["nodes"]), None + ) role_prefix = None # Chat Model if model_config.mode == LLMMode.CHAT.value: if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + if not prompt_template.simple_prompt_template: + raise ValueError("Simple prompt template is required") # get prompt template prompt_transform = SimplePromptTransform() prompt_template_config = prompt_transform.get_prompt_template( @@ -452,45 +415,35 @@ def _convert_to_llm_node(self, original_app_mode: AppMode, model=model_config.model, pre_prompt=prompt_template.simple_prompt_template, has_context=knowledge_retrieval_node is not None, - query_in_prompt=False + query_in_prompt=False, ) - template = prompt_template_config['prompt_template'].template + template = prompt_template_config["prompt_template"].template if not template: prompts = [] else: template = self._replace_template_variables( - template, - start_node['data']['variables'], - external_data_variable_node_mapping + template, start_node["data"]["variables"], external_data_variable_node_mapping ) - prompts = [ - { - "role": 'user', - "text": template - } - ] + prompts = [{"role": "user", "text": template}] else: advanced_chat_prompt_template = prompt_template.advanced_chat_prompt_template prompts = [] - for m in advanced_chat_prompt_template.messages: - if advanced_chat_prompt_template: + if advanced_chat_prompt_template: + for m in advanced_chat_prompt_template.messages: text = m.text text = self._replace_template_variables( - text, - start_node['data']['variables'], - external_data_variable_node_mapping + text, start_node["data"]["variables"], external_data_variable_node_mapping ) - prompts.append({ - "role": m.role.value, - "text": text - }) + prompts.append({"role": m.role.value, "text": text}) # Completion Model else: if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + if not prompt_template.simple_prompt_template: + raise ValueError("Simple prompt template is required") # get prompt template prompt_transform = SimplePromptTransform() prompt_template_config = prompt_transform.get_prompt_template( @@ -499,57 +452,50 @@ def _convert_to_llm_node(self, original_app_mode: AppMode, model=model_config.model, pre_prompt=prompt_template.simple_prompt_template, has_context=knowledge_retrieval_node is not None, - query_in_prompt=False + query_in_prompt=False, ) - template = prompt_template_config['prompt_template'].template + template = prompt_template_config["prompt_template"].template template = self._replace_template_variables( - template, - start_node['data']['variables'], - external_data_variable_node_mapping + template=template, + variables=start_node["data"]["variables"], + external_data_variable_node_mapping=external_data_variable_node_mapping, ) - prompts = { - "text": template - } + prompts = {"text": template} - prompt_rules = prompt_template_config['prompt_rules'] + prompt_rules = prompt_template_config["prompt_rules"] role_prefix = { - "user": prompt_rules.get('human_prefix', 'Human'), - "assistant": prompt_rules.get('assistant_prefix', 'Assistant') + "user": prompt_rules.get("human_prefix", "Human"), + "assistant": prompt_rules.get("assistant_prefix", "Assistant"), } else: advanced_completion_prompt_template = prompt_template.advanced_completion_prompt_template if advanced_completion_prompt_template: text = advanced_completion_prompt_template.prompt text = self._replace_template_variables( - text, - start_node['data']['variables'], - external_data_variable_node_mapping + template=text, + variables=start_node["data"]["variables"], + external_data_variable_node_mapping=external_data_variable_node_mapping, ) else: text = "" - text = text.replace('{{#query#}}', '{{#sys.query#}}') + text = text.replace("{{#query#}}", "{{#sys.query#}}") prompts = { "text": text, } - if advanced_completion_prompt_template.role_prefix: + if advanced_completion_prompt_template and advanced_completion_prompt_template.role_prefix: role_prefix = { "user": advanced_completion_prompt_template.role_prefix.user, - "assistant": advanced_completion_prompt_template.role_prefix.assistant + "assistant": advanced_completion_prompt_template.role_prefix.assistant, } memory = None if new_app_mode == AppMode.ADVANCED_CHAT: - memory = { - "role_prefix": role_prefix, - "window": { - "enabled": False - } - } + memory = {"role_prefix": role_prefix, "window": {"enabled": False}} completion_params = model_config.parameters completion_params.update({"stop": model_config.stop}) @@ -563,41 +509,42 @@ def _convert_to_llm_node(self, original_app_mode: AppMode, "provider": model_config.provider, "name": model_config.model, "mode": model_config.mode, - "completion_params": completion_params + "completion_params": completion_params, }, "prompt_template": prompts, "memory": memory, "context": { "enabled": knowledge_retrieval_node is not None, "variable_selector": ["knowledge_retrieval", "result"] - if knowledge_retrieval_node is not None else None + if knowledge_retrieval_node is not None + else None, }, "vision": { "enabled": file_upload is not None, "variable_selector": ["sys", "files"] if file_upload is not None else None, - "configs": { - "detail": file_upload.image_config['detail'] - } if file_upload is not None else None - } - } + "configs": {"detail": file_upload.image_config["detail"]} + if file_upload is not None and file_upload.image_config is not None + else None, + }, + }, } - def _replace_template_variables(self, template: str, - variables: list[dict], - external_data_variable_node_mapping: dict[str, str] = None) -> str: + def _replace_template_variables( + self, template: str, variables: list[dict], external_data_variable_node_mapping: dict[str, str] | None = None + ) -> str: """ Replace Template Variables :param template: template :param variables: list of variables + :param external_data_variable_node_mapping: external data variable node mapping :return: """ for v in variables: - template = template.replace('{{' + v['variable'] + '}}', '{{#start.' + v['variable'] + '#}}') + template = template.replace("{{" + v["variable"] + "}}", "{{#start." + v["variable"] + "#}}") if external_data_variable_node_mapping: for variable, code_node_id in external_data_variable_node_mapping.items(): - template = template.replace('{{' + variable + '}}', - '{{#' + code_node_id + '.result#}}') + template = template.replace("{{" + variable + "}}", "{{#" + code_node_id + ".result#}}") return template @@ -613,11 +560,8 @@ def _convert_to_end_node(self) -> dict: "data": { "title": "END", "type": NodeType.END.value, - "outputs": [{ - "variable": "result", - "value_selector": ["llm", "text"] - }] - } + "outputs": [{"variable": "result", "value_selector": ["llm", "text"]}], + }, } def _convert_to_answer_node(self) -> dict: @@ -629,11 +573,7 @@ def _convert_to_answer_node(self) -> dict: return { "id": "answer", "position": None, - "data": { - "title": "ANSWER", - "type": NodeType.ANSWER.value, - "answer": "{{#llm.text#}}" - } + "data": {"title": "ANSWER", "type": NodeType.ANSWER.value, "answer": "{{#llm.text#}}"}, } def _create_edge(self, source: str, target: str) -> dict: @@ -643,11 +583,7 @@ def _create_edge(self, source: str, target: str) -> dict: :param target: target node id :return: """ - return { - "id": f"{source}-{target}", - "source": source, - "target": target - } + return {"id": f"{source}-{target}", "source": source, "target": target} def _append_node(self, graph: dict, node: dict) -> dict: """ @@ -657,9 +593,9 @@ def _append_node(self, graph: dict, node: dict) -> dict: :param node: Node to append :return: """ - previous_node = graph['nodes'][-1] - graph['nodes'].append(node) - graph['edges'].append(self._create_edge(previous_node['id'], node['id'])) + previous_node = graph["nodes"][-1] + graph["nodes"].append(node) + graph["edges"].append(self._create_edge(previous_node["id"], node["id"])) return graph def _get_new_app_mode(self, app_model: App) -> AppMode: @@ -673,14 +609,20 @@ def _get_new_app_mode(self, app_model: App) -> AppMode: else: return AppMode.ADVANCED_CHAT - def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: + def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str): """ Get API Based Extension :param tenant_id: tenant id :param api_based_extension_id: api based extension id :return: """ - return db.session.query(APIBasedExtension).filter( - APIBasedExtension.tenant_id == tenant_id, - APIBasedExtension.id == api_based_extension_id - ).first() + api_based_extension = ( + db.session.query(APIBasedExtension) + .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) + .first() + ) + + if not api_based_extension: + raise ValueError(f"API Based Extension not found, id: {api_based_extension_id}") + + return api_based_extension diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index 047678837509e2..b4f0882a3a5971 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -1,3 +1,5 @@ +import uuid + from flask_sqlalchemy.pagination import Pagination from sqlalchemy import and_, or_ @@ -8,7 +10,6 @@ class WorkflowAppService: - def get_paginate_workflow_app_logs(self, app_model: App, args: dict) -> Pagination: """ Get paginate workflow app logs @@ -16,47 +17,51 @@ def get_paginate_workflow_app_logs(self, app_model: App, args: dict) -> Paginati :param args: request args :return: """ - query = ( - db.select(WorkflowAppLog) - .where( - WorkflowAppLog.tenant_id == app_model.tenant_id, - WorkflowAppLog.app_id == app_model.id - ) + query = db.select(WorkflowAppLog).where( + WorkflowAppLog.tenant_id == app_model.tenant_id, WorkflowAppLog.app_id == app_model.id ) - status = WorkflowRunStatus.value_of(args.get('status')) if args.get('status') else None - if args['keyword'] or status: - query = query.join( - WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id - ) + status = WorkflowRunStatus.value_of(args.get("status")) if args.get("status") else None + keyword = args["keyword"] + if keyword or status: + query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id) - if args['keyword']: - keyword_val = f"%{args['keyword'][:30]}%" + if keyword: + keyword_like_val = f"%{args['keyword'][:30]}%" keyword_conditions = [ - WorkflowRun.inputs.ilike(keyword_val), - WorkflowRun.outputs.ilike(keyword_val), + WorkflowRun.inputs.ilike(keyword_like_val), + WorkflowRun.outputs.ilike(keyword_like_val), # filter keyword by end user session id if created by end user role - and_(WorkflowRun.created_by_role == 'end_user', EndUser.session_id.ilike(keyword_val)) + and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)), ] + # filter keyword by workflow run id + keyword_uuid = self._safe_parse_uuid(keyword) + if keyword_uuid: + keyword_conditions.append(WorkflowRun.id == keyword_uuid) + query = query.outerjoin( EndUser, - and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value) + and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value), ).filter(or_(*keyword_conditions)) if status: # join with workflow_run and filter by status - query = query.filter( - WorkflowRun.status == status.value - ) + query = query.filter(WorkflowRun.status == status.value) query = query.order_by(WorkflowAppLog.created_at.desc()) - pagination = db.paginate( - query, - page=args['page'], - per_page=args['limit'], - error_out=False - ) + pagination = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) return pagination + + @staticmethod + def _safe_parse_uuid(value: str): + # fast check + if len(value) < 32: + return None + + try: + return uuid.UUID(value) + except ValueError: + return None diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index ccce38ada05fcd..b7b3abeaa2b8fe 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -18,6 +18,7 @@ def get_paginate_advanced_chat_workflow_runs(self, app_model: App, args: dict) - :param app_model: app model :param args: request args """ + class WorkflowWithMessage: message_id: str conversation_id: str @@ -33,9 +34,7 @@ def __getattr__(self, item): with_message_workflow_runs = [] for workflow_run in pagination.data: message = workflow_run.message - with_message_workflow_run = WorkflowWithMessage( - workflow_run=workflow_run - ) + with_message_workflow_run = WorkflowWithMessage(workflow_run=workflow_run) if message: with_message_workflow_run.message_id = message.id with_message_workflow_run.conversation_id = message.conversation_id @@ -53,26 +52,30 @@ def get_paginate_workflow_runs(self, app_model: App, args: dict) -> InfiniteScro :param app_model: app model :param args: request args """ - limit = int(args.get('limit', 20)) + limit = int(args.get("limit", 20)) base_query = db.session.query(WorkflowRun).filter( WorkflowRun.tenant_id == app_model.tenant_id, WorkflowRun.app_id == app_model.id, - WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value + WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value, ) - if args.get('last_id'): + if args.get("last_id"): last_workflow_run = base_query.filter( - WorkflowRun.id == args.get('last_id'), + WorkflowRun.id == args.get("last_id"), ).first() if not last_workflow_run: - raise ValueError('Last workflow run not exists') - - workflow_runs = base_query.filter( - WorkflowRun.created_at < last_workflow_run.created_at, - WorkflowRun.id != last_workflow_run.id - ).order_by(WorkflowRun.created_at.desc()).limit(limit).all() + raise ValueError("Last workflow run not exists") + + workflow_runs = ( + base_query.filter( + WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id + ) + .order_by(WorkflowRun.created_at.desc()) + .limit(limit) + .all() + ) else: workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all() @@ -81,17 +84,13 @@ def get_paginate_workflow_runs(self, app_model: App, args: dict) -> InfiniteScro current_page_first_workflow_run = workflow_runs[-1] rest_count = base_query.filter( WorkflowRun.created_at < current_page_first_workflow_run.created_at, - WorkflowRun.id != current_page_first_workflow_run.id + WorkflowRun.id != current_page_first_workflow_run.id, ).count() if rest_count > 0: has_more = True - return InfiniteScrollPagination( - data=workflow_runs, - limit=limit, - has_more=has_more - ) + return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun: """ @@ -100,11 +99,15 @@ def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun: :param app_model: app model :param run_id: workflow run id """ - workflow_run = db.session.query(WorkflowRun).filter( - WorkflowRun.tenant_id == app_model.tenant_id, - WorkflowRun.app_id == app_model.id, - WorkflowRun.id == run_id, - ).first() + workflow_run = ( + db.session.query(WorkflowRun) + .filter( + WorkflowRun.tenant_id == app_model.tenant_id, + WorkflowRun.app_id == app_model.id, + WorkflowRun.id == run_id, + ) + .first() + ) return workflow_run @@ -117,12 +120,17 @@ def get_workflow_run_node_executions(self, app_model: App, run_id: str) -> list[ if not workflow_run: return [] - node_executions = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.tenant_id == app_model.tenant_id, - WorkflowNodeExecution.app_id == app_model.id, - WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - WorkflowNodeExecution.workflow_run_id == run_id, - ).order_by(WorkflowNodeExecution.index.desc()).all() + node_executions = ( + db.session.query(WorkflowNodeExecution) + .filter( + WorkflowNodeExecution.tenant_id == app_model.tenant_id, + WorkflowNodeExecution.app_id == app_model.id, + WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.workflow_run_id == run_id, + ) + .order_by(WorkflowNodeExecution.index.desc()) + .all() + ) return node_executions diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index d868255f96e419..4c3ded14ade386 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -37,11 +37,13 @@ def get_draft_workflow(self, app_model: App) -> Optional[Workflow]: Get draft workflow """ # fetch draft workflow by app_model - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.version == 'draft' - ).first() + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == "draft" + ) + .first() + ) # return draft workflow return workflow @@ -55,11 +57,15 @@ def get_published_workflow(self, app_model: App) -> Optional[Workflow]: return None # fetch published workflow by workflow_id - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.id == app_model.workflow_id - ).first() + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.id == app_model.workflow_id, + ) + .first() + ) return workflow @@ -72,6 +78,7 @@ def sync_draft_workflow( unique_hash: Optional[str], account: Account, environment_variables: Sequence[Variable], + conversation_variables: Sequence[Variable], ) -> Workflow: """ Sync draft workflow @@ -84,10 +91,7 @@ def sync_draft_workflow( raise WorkflowHashNotEqualError() # validate features structure - self.validate_features_structure( - app_model=app_model, - features=features - ) + self.validate_features_structure(app_model=app_model, features=features) # create draft workflow if not found if not workflow: @@ -95,11 +99,12 @@ def sync_draft_workflow( tenant_id=app_model.tenant_id, app_id=app_model.id, type=WorkflowType.from_app_mode(app_model.mode).value, - version='draft', + version="draft", graph=json.dumps(graph), features=json.dumps(features), created_by=account.id, - environment_variables=environment_variables + environment_variables=environment_variables, + conversation_variables=conversation_variables, ) db.session.add(workflow) # update draft workflow if found @@ -109,6 +114,7 @@ def sync_draft_workflow( workflow.updated_by = account.id workflow.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow.environment_variables = environment_variables + workflow.conversation_variables = conversation_variables # commit db session changes db.session.commit() @@ -119,9 +125,7 @@ def sync_draft_workflow( # return draft workflow return workflow - def publish_workflow(self, app_model: App, - account: Account, - draft_workflow: Optional[Workflow] = None) -> Workflow: + def publish_workflow(self, app_model: App, account: Account, draft_workflow: Optional[Workflow] = None) -> Workflow: """ Publish workflow from draft @@ -134,7 +138,7 @@ def publish_workflow(self, app_model: App, draft_workflow = self.get_draft_workflow(app_model=app_model) if not draft_workflow: - raise ValueError('No valid workflow found.') + raise ValueError("No valid workflow found.") # create new workflow workflow = Workflow( @@ -145,7 +149,8 @@ def publish_workflow(self, app_model: App, graph=draft_workflow.graph, features=draft_workflow.features, created_by=account.id, - environment_variables=draft_workflow.environment_variables + environment_variables=draft_workflow.environment_variables, + conversation_variables=draft_workflow.conversation_variables, ) # commit db session changes @@ -183,17 +188,16 @@ def get_default_block_config(self, node_type: str, filters: Optional[dict] = Non workflow_engine_manager = WorkflowEngineManager() return workflow_engine_manager.get_default_config(node_type, filters) - def run_draft_workflow_node(self, app_model: App, - node_id: str, - user_inputs: dict, - account: Account) -> WorkflowNodeExecution: + def run_draft_workflow_node( + self, app_model: App, node_id: str, user_inputs: dict, account: Account + ) -> WorkflowNodeExecution: """ Run draft workflow node """ # fetch draft workflow by app_model draft_workflow = self.get_draft_workflow(app_model=app_model) if not draft_workflow: - raise ValueError('Workflow not initialized') + raise ValueError("Workflow not initialized") # run draft workflow node workflow_engine_manager = WorkflowEngineManager() @@ -222,7 +226,7 @@ def run_draft_workflow_node(self, app_model: App, created_by_role=CreatedByRole.ACCOUNT.value, created_by=account.id, created_at=datetime.now(timezone.utc).replace(tzinfo=None), - finished_at=datetime.now(timezone.utc).replace(tzinfo=None) + finished_at=datetime.now(timezone.utc).replace(tzinfo=None), ) db.session.add(workflow_node_execution) db.session.commit() @@ -243,14 +247,15 @@ def run_draft_workflow_node(self, app_model: App, inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None, process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None, outputs=json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None, - execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata)) - if node_run_result.metadata else None), + execution_metadata=( + json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None + ), status=WorkflowNodeExecutionStatus.SUCCEEDED.value, elapsed_time=time.perf_counter() - start_at, created_by_role=CreatedByRole.ACCOUNT.value, created_by=account.id, created_at=datetime.now(timezone.utc).replace(tzinfo=None), - finished_at=datetime.now(timezone.utc).replace(tzinfo=None) + finished_at=datetime.now(timezone.utc).replace(tzinfo=None), ) else: # create workflow node execution @@ -269,7 +274,7 @@ def run_draft_workflow_node(self, app_model: App, created_by_role=CreatedByRole.ACCOUNT.value, created_by=account.id, created_at=datetime.now(timezone.utc).replace(tzinfo=None), - finished_at=datetime.now(timezone.utc).replace(tzinfo=None) + finished_at=datetime.now(timezone.utc).replace(tzinfo=None), ) db.session.add(workflow_node_execution) @@ -291,15 +296,16 @@ def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> A workflow_converter = WorkflowConverter() if app_model.mode not in [AppMode.CHAT.value, AppMode.COMPLETION.value]: - raise ValueError(f'Current App mode: {app_model.mode} is not supported convert to workflow.') + raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.") # convert to workflow new_app = workflow_converter.convert_to_workflow( app_model=app_model, account=account, - name=args.get('name'), - icon=args.get('icon'), - icon_background=args.get('icon_background'), + name=args.get("name"), + icon_type=args.get("icon_type"), + icon=args.get("icon"), + icon_background=args.get("icon_background"), ) return new_app @@ -307,15 +313,33 @@ def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> A def validate_features_structure(self, app_model: App, features: dict) -> dict: if app_model.mode == AppMode.ADVANCED_CHAT.value: return AdvancedChatAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=features, - only_structure_validate=True + tenant_id=app_model.tenant_id, config=features, only_structure_validate=True ) elif app_model.mode == AppMode.WORKFLOW.value: return WorkflowAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=features, - only_structure_validate=True + tenant_id=app_model.tenant_id, config=features, only_structure_validate=True ) else: raise ValueError(f"Invalid app mode: {app_model.mode}") + + @classmethod + def get_elapsed_time(cls, workflow_run_id: str) -> float: + """ + Get elapsed time + """ + elapsed_time = 0.0 + + # fetch workflow node execution by workflow_run_id + workflow_nodes = ( + db.session.query(WorkflowNodeExecution) + .filter(WorkflowNodeExecution.workflow_run_id == workflow_run_id) + .order_by(WorkflowNodeExecution.created_at.asc()) + .all() + ) + if not workflow_nodes: + return elapsed_time + + for node in workflow_nodes: + elapsed_time += node.elapsed_time + + return elapsed_time diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 2bcbe5c6f6318a..8fcb12b1cb9664 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -1,4 +1,3 @@ - from flask_login import current_user from configs import dify_config @@ -14,34 +13,40 @@ def get_tenant_info(cls, tenant: Tenant): if not tenant: return None tenant_info = { - 'id': tenant.id, - 'name': tenant.name, - 'plan': tenant.plan, - 'status': tenant.status, - 'created_at': tenant.created_at, - 'in_trail': True, - 'trial_end_reason': None, - 'role': 'normal', + "id": tenant.id, + "name": tenant.name, + "plan": tenant.plan, + "status": tenant.status, + "created_at": tenant.created_at, + "in_trail": True, + "trial_end_reason": None, + "role": "normal", } # Get role of user - tenant_account_join = db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.tenant_id == tenant.id, - TenantAccountJoin.account_id == current_user.id - ).first() - tenant_info['role'] = tenant_account_join.role - - can_replace_logo = FeatureService.get_features(tenant_info['id']).can_replace_logo - - if can_replace_logo and TenantService.has_roles(tenant, - [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]): + tenant_account_join = ( + db.session.query(TenantAccountJoin) + .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id) + .first() + ) + tenant_info["role"] = tenant_account_join.role + + can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo + + if can_replace_logo and TenantService.has_roles( + tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN] + ): base_url = dify_config.FILES_URL - replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None - remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False) - - tenant_info['custom_config'] = { - 'remove_webapp_brand': remove_webapp_brand, - 'replace_webapp_logo': replace_webapp_logo, + replace_webapp_logo = ( + f"{base_url}/files/workspaces/{tenant.id}/webapp-logo" + if tenant.custom_config_dict.get("replace_webapp_logo") + else None + ) + remove_webapp_brand = tenant.custom_config_dict.get("remove_webapp_brand", False) + + tenant_info["custom_config"] = { + "remove_webapp_brand": remove_webapp_brand, + "replace_webapp_logo": replace_webapp_logo, } return tenant_info diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index e0a1b219095b92..b50876cc794c55 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -14,7 +14,7 @@ from models.dataset import DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def add_document_to_index_task(dataset_document_id: str): """ Async Add document to index @@ -22,24 +22,25 @@ def add_document_to_index_task(dataset_document_id: str): Usage: add_document_to_index.delay(document_id) """ - logging.info(click.style('Start add document to index: {}'.format(dataset_document_id), fg='green')) + logging.info(click.style("Start add document to index: {}".format(dataset_document_id), fg="green")) start_at = time.perf_counter() dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document_id).first() if not dataset_document: - raise NotFound('Document not found') + raise NotFound("Document not found") - if dataset_document.indexing_status != 'completed': + if dataset_document.indexing_status != "completed": return - indexing_cache_key = 'document_{}_indexing'.format(dataset_document.id) + indexing_cache_key = "document_{}_indexing".format(dataset_document.id) try: - segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.enabled == True - ) \ - .order_by(DocumentSegment.position.asc()).all() + segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) + .order_by(DocumentSegment.position.asc()) + .all() + ) documents = [] for segment in segments: @@ -50,7 +51,7 @@ def add_document_to_index_task(dataset_document_id: str): "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) @@ -58,7 +59,7 @@ def add_document_to_index_task(dataset_document_id: str): dataset = dataset_document.dataset if not dataset: - raise Exception('Document has no dataset') + raise Exception("Document has no dataset") index_type = dataset.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -66,12 +67,15 @@ def add_document_to_index_task(dataset_document_id: str): end_at = time.perf_counter() logging.info( - click.style('Document added to index: {} latency: {}'.format(dataset_document.id, end_at - start_at), fg='green')) + click.style( + "Document added to index: {} latency: {}".format(dataset_document.id, end_at - start_at), fg="green" + ) + ) except Exception as e: logging.exception("add document to index failed") dataset_document.enabled = False dataset_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - dataset_document.status = 'error' + dataset_document.status = "error" dataset_document.error = str(e) db.session.commit() finally: diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index b3aa8b596c8a3c..25c55bcfafe11c 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -10,9 +10,10 @@ from services.dataset_service import DatasetCollectionBindingService -@shared_task(queue='dataset') -def add_annotation_to_index_task(annotation_id: str, question: str, tenant_id: str, app_id: str, - collection_binding_id: str): +@shared_task(queue="dataset") +def add_annotation_to_index_task( + annotation_id: str, question: str, tenant_id: str, app_id: str, collection_binding_id: str +): """ Add annotation to index. :param annotation_id: annotation id @@ -23,38 +24,34 @@ def add_annotation_to_index_task(annotation_id: str, question: str, tenant_id: s Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct) """ - logging.info(click.style('Start build index for annotation: {}'.format(annotation_id), fg='green')) + logging.info(click.style("Start build index for annotation: {}".format(annotation_id), fg="green")) start_at = time.perf_counter() try: dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - collection_binding_id, - 'annotation' + collection_binding_id, "annotation" ) dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) document = Document( - page_content=question, - metadata={ - "annotation_id": annotation_id, - "app_id": app_id, - "doc_id": annotation_id - } + page_content=question, metadata={"annotation_id": annotation_id, "app_id": app_id, "doc_id": annotation_id} ) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) vector.create([document], duplicate_check=True) end_at = time.perf_counter() logging.info( click.style( - 'Build index successful for annotation: {} latency: {}'.format(annotation_id, end_at - start_at), - fg='green')) + "Build index successful for annotation: {} latency: {}".format(annotation_id, end_at - start_at), + fg="green", + ) + ) except Exception: logging.exception("Build index for annotation failed") diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index 063ffca6fdd997..fa7e5ac9190f3c 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -14,84 +14,77 @@ from services.dataset_service import DatasetCollectionBindingService -@shared_task(queue='dataset') -def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: str, tenant_id: str, - user_id: str): +@shared_task(queue="dataset") +def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: str, tenant_id: str, user_id: str): """ Add annotation to index. :param job_id: job_id :param content_list: content list - :param tenant_id: tenant id :param app_id: app id + :param tenant_id: tenant id :param user_id: user_id """ - logging.info(click.style('Start batch import annotation: {}'.format(job_id), fg='green')) + logging.info(click.style("Start batch import annotation: {}".format(job_id), fg="green")) start_at = time.perf_counter() - indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id)) + indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == tenant_id, - App.status == 'normal' - ).first() + app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() if app: try: documents = [] for content in content_list: annotation = MessageAnnotation( - app_id=app.id, - content=content['answer'], - question=content['question'], - account_id=user_id + app_id=app.id, content=content["answer"], question=content["question"], account_id=user_id ) db.session.add(annotation) db.session.flush() document = Document( - page_content=content['question'], - metadata={ - "annotation_id": annotation.id, - "app_id": app_id, - "doc_id": annotation.id - } + page_content=content["question"], + metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id}, ) documents.append(document) # if annotation reply is enabled , batch add annotations' index - app_annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id - ).first() + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if app_annotation_setting: - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - app_annotation_setting.collection_binding_id, - 'annotation' + dataset_collection_binding = ( + DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + app_annotation_setting.collection_binding_id, "annotation" + ) ) if not dataset_collection_binding: raise NotFound("App annotation setting not found") dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) vector.create(documents, duplicate_check=True) db.session.commit() - redis_client.setex(indexing_cache_key, 600, 'completed') + redis_client.setex(indexing_cache_key, 600, "completed") end_at = time.perf_counter() logging.info( click.style( - 'Build index successful for batch import annotation: {} latency: {}'.format(job_id, end_at - start_at), - fg='green')) + "Build index successful for batch import annotation: {} latency: {}".format( + job_id, end_at - start_at + ), + fg="green", + ) + ) except Exception as e: db.session.rollback() - redis_client.setex(indexing_cache_key, 600, 'error') - indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id)) + redis_client.setex(indexing_cache_key, 600, "error") + indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id)) redis_client.setex(indexing_error_msg_key, 600, str(e)) logging.exception("Build index for batch import annotations failed") diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index 81155a35e42b0e..5758db53de820b 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -9,36 +9,33 @@ from services.dataset_service import DatasetCollectionBindingService -@shared_task(queue='dataset') -def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str, - collection_binding_id: str): +@shared_task(queue="dataset") +def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str, collection_binding_id: str): """ Async delete annotation index task """ - logging.info(click.style('Start delete app annotation index: {}'.format(app_id), fg='green')) + logging.info(click.style("Start delete app annotation index: {}".format(app_id), fg="green")) start_at = time.perf_counter() try: dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - collection_binding_id, - 'annotation' + collection_binding_id, "annotation" ) dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique='high_quality', - collection_binding_id=dataset_collection_binding.id + indexing_technique="high_quality", + collection_binding_id=dataset_collection_binding.id, ) try: - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) - vector.delete_by_metadata_field('annotation_id', annotation_id) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) + vector.delete_by_metadata_field("annotation_id", annotation_id) except Exception: logging.exception("Delete annotation index failed when annotation deleted.") end_at = time.perf_counter() logging.info( - click.style('App annotations index deleted : {} latency: {}'.format(app_id, end_at - start_at), - fg='green')) + click.style("App annotations index deleted : {} latency: {}".format(app_id, end_at - start_at), fg="green") + ) except Exception as e: logging.exception("Annotation deleted index failed:{}".format(str(e))) - diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index e5e03c9b51758b..0f83dfdbd4a72f 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -12,49 +12,44 @@ from models.model import App, AppAnnotationSetting, MessageAnnotation -@shared_task(queue='dataset') +@shared_task(queue="dataset") def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): """ Async enable annotation reply task """ - logging.info(click.style('Start delete app annotations index: {}'.format(app_id), fg='green')) + logging.info(click.style("Start delete app annotations index: {}".format(app_id), fg="green")) start_at = time.perf_counter() # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == tenant_id, - App.status == 'normal' - ).first() + app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() annotations_count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).count() if not app: raise NotFound("App not found") - app_annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id - ).first() + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if not app_annotation_setting: raise NotFound("App annotation setting not found") - disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id)) - disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id)) + disable_app_annotation_key = "disable_app_annotation_{}".format(str(app_id)) + disable_app_annotation_job_key = "disable_app_annotation_job_{}".format(str(job_id)) try: - dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique='high_quality', - collection_binding_id=app_annotation_setting.collection_binding_id + indexing_technique="high_quality", + collection_binding_id=app_annotation_setting.collection_binding_id, ) try: if annotations_count > 0: - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) - vector.delete_by_metadata_field('app_id', app_id) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) + vector.delete_by_metadata_field("app_id", app_id) except Exception: logging.exception("Delete annotation index failed when annotation deleted.") - redis_client.setex(disable_app_annotation_job_key, 600, 'completed') + redis_client.setex(disable_app_annotation_job_key, 600, "completed") # delete annotation setting db.session.delete(app_annotation_setting) @@ -62,12 +57,12 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): end_at = time.perf_counter() logging.info( - click.style('App annotations index deleted : {} latency: {}'.format(app_id, end_at - start_at), - fg='green')) + click.style("App annotations index deleted : {} latency: {}".format(app_id, end_at - start_at), fg="green") + ) except Exception as e: logging.exception("Annotation batch deleted index failed:{}".format(str(e))) - redis_client.setex(disable_app_annotation_job_key, 600, 'error') - disable_app_annotation_error_key = 'disable_app_annotation_error_{}'.format(str(job_id)) + redis_client.setex(disable_app_annotation_job_key, 600, "error") + disable_app_annotation_error_key = "disable_app_annotation_error_{}".format(str(job_id)) redis_client.setex(disable_app_annotation_error_key, 600, str(e)) finally: redis_client.delete(disable_app_annotation_key) diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index fda8b7a250f128..82b70f6b71eddf 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -15,37 +15,39 @@ from services.dataset_service import DatasetCollectionBindingService -@shared_task(queue='dataset') -def enable_annotation_reply_task(job_id: str, app_id: str, user_id: str, tenant_id: str, score_threshold: float, - embedding_provider_name: str, embedding_model_name: str): +@shared_task(queue="dataset") +def enable_annotation_reply_task( + job_id: str, + app_id: str, + user_id: str, + tenant_id: str, + score_threshold: float, + embedding_provider_name: str, + embedding_model_name: str, +): """ Async enable annotation reply task """ - logging.info(click.style('Start add app annotation to index: {}'.format(app_id), fg='green')) + logging.info(click.style("Start add app annotation to index: {}".format(app_id), fg="green")) start_at = time.perf_counter() # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == tenant_id, - App.status == 'normal' - ).first() + app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() if not app: raise NotFound("App not found") annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).all() - enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id)) - enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id)) + enable_app_annotation_key = "enable_app_annotation_{}".format(str(app_id)) + enable_app_annotation_job_key = "enable_app_annotation_job_{}".format(str(job_id)) try: documents = [] dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_provider_name, - embedding_model_name, - 'annotation' + embedding_provider_name, embedding_model_name, "annotation" + ) + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() ) - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id).first() if annotation_setting: annotation_setting.score_threshold = score_threshold annotation_setting.collection_binding_id = dataset_collection_binding.id @@ -58,48 +60,42 @@ def enable_annotation_reply_task(job_id: str, app_id: str, user_id: str, tenant_ score_threshold=score_threshold, collection_binding_id=dataset_collection_binding.id, created_user_id=user_id, - updated_user_id=user_id + updated_user_id=user_id, ) db.session.add(new_app_annotation_setting) dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=embedding_provider_name, embedding_model=embedding_model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) if annotations: for annotation in annotations: document = Document( page_content=annotation.question, - metadata={ - "annotation_id": annotation.id, - "app_id": app_id, - "doc_id": annotation.id - } + metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id}, ) documents.append(document) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) try: - vector.delete_by_metadata_field('app_id', app_id) + vector.delete_by_metadata_field("app_id", app_id) except Exception as e: - logging.info( - click.style('Delete annotation index error: {}'.format(str(e)), - fg='red')) + logging.info(click.style("Delete annotation index error: {}".format(str(e)), fg="red")) vector.create(documents) db.session.commit() - redis_client.setex(enable_app_annotation_job_key, 600, 'completed') + redis_client.setex(enable_app_annotation_job_key, 600, "completed") end_at = time.perf_counter() logging.info( - click.style('App annotations added to index: {} latency: {}'.format(app_id, end_at - start_at), - fg='green')) + click.style("App annotations added to index: {} latency: {}".format(app_id, end_at - start_at), fg="green") + ) except Exception as e: logging.exception("Annotation batch created index failed:{}".format(str(e))) - redis_client.setex(enable_app_annotation_job_key, 600, 'error') - enable_app_annotation_error_key = 'enable_app_annotation_error_{}'.format(str(job_id)) + redis_client.setex(enable_app_annotation_job_key, 600, "error") + enable_app_annotation_error_key = "enable_app_annotation_error_{}".format(str(job_id)) redis_client.setex(enable_app_annotation_error_key, 600, str(e)) db.session.rollback() finally: diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index 7219abd3cdf6fc..b685d84d07ad28 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -10,9 +10,10 @@ from services.dataset_service import DatasetCollectionBindingService -@shared_task(queue='dataset') -def update_annotation_to_index_task(annotation_id: str, question: str, tenant_id: str, app_id: str, - collection_binding_id: str): +@shared_task(queue="dataset") +def update_annotation_to_index_task( + annotation_id: str, question: str, tenant_id: str, app_id: str, collection_binding_id: str +): """ Update annotation to index. :param annotation_id: annotation id @@ -23,39 +24,35 @@ def update_annotation_to_index_task(annotation_id: str, question: str, tenant_id Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct) """ - logging.info(click.style('Start update index for annotation: {}'.format(annotation_id), fg='green')) + logging.info(click.style("Start update index for annotation: {}".format(annotation_id), fg="green")) start_at = time.perf_counter() try: dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - collection_binding_id, - 'annotation' + collection_binding_id, "annotation" ) dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) document = Document( - page_content=question, - metadata={ - "annotation_id": annotation_id, - "app_id": app_id, - "doc_id": annotation_id - } + page_content=question, metadata={"annotation_id": annotation_id, "app_id": app_id, "doc_id": annotation_id} ) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) - vector.delete_by_metadata_field('annotation_id', annotation_id) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) + vector.delete_by_metadata_field("annotation_id", annotation_id) vector.add_texts([document]) end_at = time.perf_counter() logging.info( click.style( - 'Build index successful for annotation: {} latency: {}'.format(annotation_id, end_at - start_at), - fg='green')) + "Build index successful for annotation: {} latency: {}".format(annotation_id, end_at - start_at), + fg="green", + ) + ) except Exception: logging.exception("Build index for annotation failed") diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 67cc03bdebee54..de7f0ddec1f3b6 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -16,9 +16,10 @@ from models.dataset import Dataset, Document, DocumentSegment -@shared_task(queue='dataset') -def batch_create_segment_to_index_task(job_id: str, content: list, dataset_id: str, document_id: str, - tenant_id: str, user_id: str): +@shared_task(queue="dataset") +def batch_create_segment_to_index_task( + job_id: str, content: list, dataset_id: str, document_id: str, tenant_id: str, user_id: str +): """ Async batch create segment to index :param job_id: @@ -30,44 +31,44 @@ def batch_create_segment_to_index_task(job_id: str, content: list, dataset_id: s Usage: batch_create_segment_to_index_task.delay(segment_id) """ - logging.info(click.style('Start batch create segment jobId: {}'.format(job_id), fg='green')) + logging.info(click.style("Start batch create segment jobId: {}".format(job_id), fg="green")) start_at = time.perf_counter() - indexing_cache_key = 'segment_batch_import_{}'.format(job_id) + indexing_cache_key = "segment_batch_import_{}".format(job_id) try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset not exist.') + raise ValueError("Dataset not exist.") dataset_document = db.session.query(Document).filter(Document.id == document_id).first() if not dataset_document: - raise ValueError('Document not exist.') + raise ValueError("Document not exist.") - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': - raise ValueError('Document is not available.') + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + raise ValueError("Document is not available.") document_segments = [] embedding_model = None - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) for segment in content: - content = segment['content'] + content = segment["content"] doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) # calc embedding use tokens - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[content] - ) if embedding_model else 0 - max_position = db.session.query(func.max(DocumentSegment.position)).filter( - DocumentSegment.document_id == dataset_document.id - ).scalar() + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) if embedding_model else 0 + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .filter(DocumentSegment.document_id == dataset_document.id) + .scalar() + ) segment_document = DocumentSegment( tenant_id=tenant_id, dataset_id=dataset_id, @@ -80,20 +81,22 @@ def batch_create_segment_to_index_task(job_id: str, content: list, dataset_id: s tokens=tokens, created_by=user_id, indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), - status='completed', - completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + status="completed", + completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), ) - if dataset_document.doc_form == 'qa_model': - segment_document.answer = segment['answer'] + if dataset_document.doc_form == "qa_model": + segment_document.answer = segment["answer"] db.session.add(segment_document) document_segments.append(segment_document) # add index to db indexing_runner = IndexingRunner() indexing_runner.batch_add_segments(document_segments, dataset) db.session.commit() - redis_client.setex(indexing_cache_key, 600, 'completed') + redis_client.setex(indexing_cache_key, 600, "completed") end_at = time.perf_counter() - logging.info(click.style('Segment batch created job: {} latency: {}'.format(job_id, end_at - start_at), fg='green')) + logging.info( + click.style("Segment batch created job: {} latency: {}".format(job_id, end_at - start_at), fg="green") + ) except Exception as e: logging.exception("Segments batch created index failed:{}".format(str(e))) - redis_client.setex(indexing_cache_key, 600, 'error') + redis_client.setex(indexing_cache_key, 600, "error") diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 1f26c966c46af5..36249038011747 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -19,9 +19,15 @@ # Add import statement for ValueError -@shared_task(queue='dataset') -def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, - index_struct: str, collection_binding_id: str, doc_form: str): +@shared_task(queue="dataset") +def clean_dataset_task( + dataset_id: str, + tenant_id: str, + indexing_technique: str, + index_struct: str, + collection_binding_id: str, + doc_form: str, +): """ Clean dataset when dataset deleted. :param dataset_id: dataset id @@ -33,7 +39,7 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct) """ - logging.info(click.style('Start clean dataset when dataset deleted: {}'.format(dataset_id), fg='green')) + logging.info(click.style("Start clean dataset when dataset deleted: {}".format(dataset_id), fg="green")) start_at = time.perf_counter() try: @@ -48,9 +54,9 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all() if documents is None or len(documents) == 0: - logging.info(click.style('No documents found for dataset: {}'.format(dataset_id), fg='green')) + logging.info(click.style("No documents found for dataset: {}".format(dataset_id), fg="green")) else: - logging.info(click.style('Cleaning documents for dataset: {}'.format(dataset_id), fg='green')) + logging.info(click.style("Cleaning documents for dataset: {}".format(dataset_id), fg="green")) # Specify the index type before initializing the index processor if doc_form is None: raise ValueError("Index type must be specified.") @@ -71,15 +77,16 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, if documents: for document in documents: try: - if document.data_source_type == 'upload_file': + if document.data_source_type == "upload_file": if document.data_source_info: data_source_info = document.data_source_info_dict - if data_source_info and 'upload_file_id' in data_source_info: - file_id = data_source_info['upload_file_id'] - file = db.session.query(UploadFile).filter( - UploadFile.tenant_id == document.tenant_id, - UploadFile.id == file_id - ).first() + if data_source_info and "upload_file_id" in data_source_info: + file_id = data_source_info["upload_file_id"] + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) + .first() + ) if not file: continue storage.delete(file.key) @@ -90,6 +97,9 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, db.session.commit() end_at = time.perf_counter() logging.info( - click.style('Cleaned dataset when dataset deleted: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) + click.style( + "Cleaned dataset when dataset deleted: {} latency: {}".format(dataset_id, end_at - start_at), fg="green" + ) + ) except Exception: logging.exception("Cleaned dataset when dataset deleted failed") diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 0fd05615b65430..ae2855aa2ebc4d 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -12,7 +12,7 @@ from models.model import UploadFile -@shared_task(queue='dataset') +@shared_task(queue="dataset") def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_id: Optional[str]): """ Clean document when document deleted. @@ -23,14 +23,14 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i Usage: clean_document_task.delay(document_id, dataset_id) """ - logging.info(click.style('Start clean document when document deleted: {}'.format(document_id), fg='green')) + logging.info(click.style("Start clean document when document deleted: {}".format(document_id), fg="green")) start_at = time.perf_counter() try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - raise Exception('Document has no dataset') + raise Exception("Document has no dataset") segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() # check segment is exist @@ -44,9 +44,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i db.session.commit() if file_id: - file = db.session.query(UploadFile).filter( - UploadFile.id == file_id - ).first() + file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() if file: try: storage.delete(file.key) @@ -57,6 +55,10 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i end_at = time.perf_counter() logging.info( - click.style('Cleaned document when document deleted: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) + click.style( + "Cleaned document when document deleted: {} latency: {}".format(document_id, end_at - start_at), + fg="green", + ) + ) except Exception: logging.exception("Cleaned document when document deleted failed") diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index 9b697b63511a12..75d9e031306381 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -9,7 +9,7 @@ from models.dataset import Dataset, Document, DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def clean_notion_document_task(document_ids: list[str], dataset_id: str): """ Clean document when document deleted. @@ -18,20 +18,20 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): Usage: clean_notion_document_task.delay(document_ids, dataset_id) """ - logging.info(click.style('Start clean document when import form notion document deleted: {}'.format(dataset_id), fg='green')) + logging.info( + click.style("Start clean document when import form notion document deleted: {}".format(dataset_id), fg="green") + ) start_at = time.perf_counter() try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - raise Exception('Document has no dataset') + raise Exception("Document has no dataset") index_type = dataset.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() for document_id in document_ids: - document = db.session.query(Document).filter( - Document.id == document_id - ).first() + document = db.session.query(Document).filter(Document.id == document_id).first() db.session.delete(document) segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() @@ -44,8 +44,12 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): db.session.commit() end_at = time.perf_counter() logging.info( - click.style('Clean document when import form notion document deleted end :: {} latency: {}'.format( - dataset_id, end_at - start_at), - fg='green')) + click.style( + "Clean document when import form notion document deleted end :: {} latency: {}".format( + dataset_id, end_at - start_at + ), + fg="green", + ) + ) except Exception: logging.exception("Cleaned document when import form notion document deleted failed") diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index d31286e4cc4c42..26375743b68e25 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -14,7 +14,7 @@ from models.dataset import DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] = None): """ Async create segment to index @@ -22,23 +22,23 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] :param keywords: Usage: create_segment_to_index_task.delay(segment_id) """ - logging.info(click.style('Start create segment to index: {}'.format(segment_id), fg='green')) + logging.info(click.style("Start create segment to index: {}".format(segment_id), fg="green")) start_at = time.perf_counter() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() if not segment: - raise NotFound('Segment not found') + raise NotFound("Segment not found") - if segment.status != 'waiting': + if segment.status != "waiting": return - indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_indexing".format(segment.id) try: # update segment status to indexing update_params = { DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), } DocumentSegment.query.filter_by(id=segment.id).update(update_params) db.session.commit() @@ -49,23 +49,23 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) dataset = segment.dataset if not dataset: - logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan')) + logging.info(click.style("Segment {} has no dataset, pass.".format(segment.id), fg="cyan")) return dataset_document = segment.document if not dataset_document: - logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan')) + logging.info(click.style("Segment {} has no document, pass.".format(segment.id), fg="cyan")) return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': - logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Segment {} document status is invalid, pass.".format(segment.id), fg="cyan")) return index_type = dataset.doc_form @@ -75,18 +75,20 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] # update segment to completed update_params = { DocumentSegment.status: "completed", - DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), } DocumentSegment.query.filter_by(id=segment.id).update(update_params) db.session.commit() end_at = time.perf_counter() - logging.info(click.style('Segment created to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) + logging.info( + click.style("Segment created to index: {} latency: {}".format(segment.id, end_at - start_at), fg="green") + ) except Exception as e: logging.exception("create segment to index failed") segment.enabled = False segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment.status = 'error' + segment.status = "error" segment.error = str(e) db.session.commit() finally: diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index c1b0e7f1a4cf91..cfc54920e23caa 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -11,7 +11,7 @@ from models.dataset import Document as DatasetDocument -@shared_task(queue='dataset') +@shared_task(queue="dataset") def deal_dataset_vector_index_task(dataset_id: str, action: str): """ Async deal dataset from index @@ -19,87 +19,132 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): :param action: action Usage: deal_dataset_vector_index_task.delay(dataset_id, action) """ - logging.info(click.style('Start deal dataset vector index: {}'.format(dataset_id), fg='green')) + logging.info(click.style("Start deal dataset vector index: {}".format(dataset_id), fg="green")) start_at = time.perf_counter() try: - dataset = Dataset.query.filter_by( - id=dataset_id - ).first() + dataset = Dataset.query.filter_by(id=dataset_id).first() if not dataset: - raise Exception('Dataset not found') + raise Exception("Dataset not found") index_type = dataset.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() if action == "remove": index_processor.clean(dataset, None, with_keywords=False) elif action == "add": - dataset_documents = db.session.query(DatasetDocument).filter( - DatasetDocument.dataset_id == dataset_id, - DatasetDocument.indexing_status == 'completed', - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).all() + dataset_documents = ( + db.session.query(DatasetDocument) + .filter( + DatasetDocument.dataset_id == dataset_id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + .all() + ) if dataset_documents: - documents = [] + dataset_documents_ids = [doc.id for doc in dataset_documents] + db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update( + {"indexing_status": "indexing"}, synchronize_session=False + ) + db.session.commit() + for dataset_document in dataset_documents: - # delete from vector index - segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.enabled == True - ) .order_by(DocumentSegment.position.asc()).all() - for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - } + try: + # add from vector index + segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) + .order_by(DocumentSegment.position.asc()) + .all() ) + if segments: + documents = [] + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) - documents.append(document) - - # save vector index - index_processor.load(dataset, documents, with_keywords=False) - elif action == 'update': - # clean index - index_processor.clean(dataset, None, with_keywords=False) - dataset_documents = db.session.query(DatasetDocument).filter( - DatasetDocument.dataset_id == dataset_id, - DatasetDocument.indexing_status == 'completed', - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).all() + documents.append(document) + # save vector index + index_processor.load(dataset, documents, with_keywords=False) + db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "completed"}, synchronize_session=False + ) + db.session.commit() + except Exception as e: + db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "error", "error": str(e)}, synchronize_session=False + ) + db.session.commit() + elif action == "update": + dataset_documents = ( + db.session.query(DatasetDocument) + .filter( + DatasetDocument.dataset_id == dataset_id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + .all() + ) # add new index if dataset_documents: - documents = [] + # update document status + dataset_documents_ids = [doc.id for doc in dataset_documents] + db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update( + {"indexing_status": "indexing"}, synchronize_session=False + ) + db.session.commit() + + # clean index + index_processor.clean(dataset, None, with_keywords=False) + for dataset_document in dataset_documents: - # delete from vector index - segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.enabled == True - ).order_by(DocumentSegment.position.asc()).all() - for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - } + # update from vector index + try: + segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) + .order_by(DocumentSegment.position.asc()) + .all() ) + if segments: + documents = [] + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) - documents.append(document) - - # save vector index - index_processor.load(dataset, documents, with_keywords=False) + documents.append(document) + # save vector index + index_processor.load(dataset, documents, with_keywords=False) + db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "completed"}, synchronize_session=False + ) + db.session.commit() + except Exception as e: + db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "error", "error": str(e)}, synchronize_session=False + ) + db.session.commit() end_at = time.perf_counter() logging.info( - click.style('Deal dataset vector index: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) + click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green") + ) except Exception: logging.exception("Deal dataset vector index failed") diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index d79286cf3d8a3c..c3e0ea5d9fbb77 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -10,7 +10,7 @@ from models.dataset import Dataset, Document -@shared_task(queue='dataset') +@shared_task(queue="dataset") def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_id: str, document_id: str): """ Async Remove segment from index @@ -21,22 +21,22 @@ def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_ Usage: delete_segment_from_index_task.delay(segment_id) """ - logging.info(click.style('Start delete segment from index: {}'.format(segment_id), fg='green')) + logging.info(click.style("Start delete segment from index: {}".format(segment_id), fg="green")) start_at = time.perf_counter() - indexing_cache_key = 'segment_{}_delete_indexing'.format(segment_id) + indexing_cache_key = "segment_{}_delete_indexing".format(segment_id) try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - logging.info(click.style('Segment {} has no dataset, pass.'.format(segment_id), fg='cyan')) + logging.info(click.style("Segment {} has no dataset, pass.".format(segment_id), fg="cyan")) return dataset_document = db.session.query(Document).filter(Document.id == document_id).first() if not dataset_document: - logging.info(click.style('Segment {} has no document, pass.'.format(segment_id), fg='cyan')) + logging.info(click.style("Segment {} has no document, pass.".format(segment_id), fg="cyan")) return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': - logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment_id), fg='cyan')) + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Segment {} document status is invalid, pass.".format(segment_id), fg="cyan")) return index_type = dataset_document.doc_form @@ -44,7 +44,9 @@ def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_ index_processor.clean(dataset, [index_node_id]) end_at = time.perf_counter() - logging.info(click.style('Segment deleted from index: {} latency: {}'.format(segment_id, end_at - start_at), fg='green')) + logging.info( + click.style("Segment deleted from index: {} latency: {}".format(segment_id, end_at - start_at), fg="green") + ) except Exception: logging.exception("delete segment from index failed") finally: diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py index 4788bf4e4b5960..15e1e50076e8c9 100644 --- a/api/tasks/disable_segment_from_index_task.py +++ b/api/tasks/disable_segment_from_index_task.py @@ -11,7 +11,7 @@ from models.dataset import DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def disable_segment_from_index_task(segment_id: str): """ Async disable segment from index @@ -19,33 +19,33 @@ def disable_segment_from_index_task(segment_id: str): Usage: disable_segment_from_index_task.delay(segment_id) """ - logging.info(click.style('Start disable segment from index: {}'.format(segment_id), fg='green')) + logging.info(click.style("Start disable segment from index: {}".format(segment_id), fg="green")) start_at = time.perf_counter() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() if not segment: - raise NotFound('Segment not found') + raise NotFound("Segment not found") - if segment.status != 'completed': - raise NotFound('Segment is not completed , disable action is not allowed.') + if segment.status != "completed": + raise NotFound("Segment is not completed , disable action is not allowed.") - indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_indexing".format(segment.id) try: dataset = segment.dataset if not dataset: - logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan')) + logging.info(click.style("Segment {} has no dataset, pass.".format(segment.id), fg="cyan")) return dataset_document = segment.document if not dataset_document: - logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan')) + logging.info(click.style("Segment {} has no document, pass.".format(segment.id), fg="cyan")) return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': - logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Segment {} document status is invalid, pass.".format(segment.id), fg="cyan")) return index_type = dataset_document.doc_form @@ -53,7 +53,9 @@ def disable_segment_from_index_task(segment_id: str): index_processor.clean(dataset, [segment.index_node_id]) end_at = time.perf_counter() - logging.info(click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) + logging.info( + click.style("Segment removed from index: {} latency: {}".format(segment.id, end_at - start_at), fg="green") + ) except Exception: logging.exception("remove segment from index failed") segment.enabled = True diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 4cced36ecdd856..9ea4c99649cfe3 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -14,7 +14,7 @@ from models.source import DataSourceOauthBinding -@shared_task(queue='dataset') +@shared_task(queue="dataset") def document_indexing_sync_task(dataset_id: str, document_id: str): """ Async update document @@ -23,50 +23,50 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): Usage: document_indexing_sync_task.delay(dataset_id, document_id) """ - logging.info(click.style('Start sync document: {}'.format(document_id), fg='green')) + logging.info(click.style("Start sync document: {}".format(document_id), fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: - raise NotFound('Document not found') + raise NotFound("Document not found") data_source_info = document.data_source_info_dict - if document.data_source_type == 'notion_import': - if not data_source_info or 'notion_page_id' not in data_source_info \ - or 'notion_workspace_id' not in data_source_info: + if document.data_source_type == "notion_import": + if ( + not data_source_info + or "notion_page_id" not in data_source_info + or "notion_workspace_id" not in data_source_info + ): raise ValueError("no notion page found") - workspace_id = data_source_info['notion_workspace_id'] - page_id = data_source_info['notion_page_id'] - page_type = data_source_info['type'] - page_edited_time = data_source_info['last_edited_time'] + workspace_id = data_source_info["notion_workspace_id"] + page_id = data_source_info["notion_page_id"] + page_type = data_source_info["type"] + page_edited_time = data_source_info["last_edited_time"] data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == document.tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', ) ).first() if not data_source_binding: - raise ValueError('Data source binding not found.') + raise ValueError("Data source binding not found.") loader = NotionExtractor( notion_workspace_id=workspace_id, notion_obj_id=page_id, notion_page_type=page_type, notion_access_token=data_source_binding.access_token, - tenant_id=document.tenant_id + tenant_id=document.tenant_id, ) last_edited_time = loader.get_notion_last_edited_time() # check the page is updated if last_edited_time != page_edited_time: - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() @@ -74,7 +74,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - raise Exception('Dataset not found') + raise Exception("Dataset not found") index_type = document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -89,7 +89,13 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): end_at = time.perf_counter() logging.info( - click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) + click.style( + "Cleaned document when document update data source or process rule: {} latency: {}".format( + document_id, end_at - start_at + ), + fg="green", + ) + ) except Exception: logging.exception("Cleaned document when document update data source or process rule failed") @@ -97,8 +103,10 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): indexing_runner = IndexingRunner() indexing_runner.run([document]) end_at = time.perf_counter() - logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) + logging.info( + click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green") + ) except DocumentIsPausedException as ex: - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index cc93a1341e180d..e0da5f9ed032c7 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -12,7 +12,7 @@ from services.feature_service import FeatureService -@shared_task(queue='dataset') +@shared_task(queue="dataset") def document_indexing_task(dataset_id: str, document_ids: list): """ Async process document @@ -36,16 +36,17 @@ def document_indexing_task(dataset_id: str, document_ids: list): if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") if 0 < vector_space.limit <= vector_space.size: - raise ValueError("Your total number of documents plus the number of uploads have over the limit of " - "your subscription.") + raise ValueError( + "Your total number of documents plus the number of uploads have over the limit of " + "your subscription." + ) except Exception as e: for document_id in document_ids: - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) if document: - document.indexing_status = 'error' + document.indexing_status = "error" document.error = str(e) document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) @@ -53,15 +54,14 @@ def document_indexing_task(dataset_id: str, document_ids: list): return for document_id in document_ids: - logging.info(click.style('Start process document: {}'.format(document_id), fg='green')) + logging.info(click.style("Start process document: {}".format(document_id), fg="green")) - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) if document: - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) documents.append(document) db.session.add(document) @@ -71,8 +71,8 @@ def document_indexing_task(dataset_id: str, document_ids: list): indexing_runner = IndexingRunner() indexing_runner.run(documents) end_at = time.perf_counter() - logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) + logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) except DocumentIsPausedException as ex: - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index f129d93de8da5a..6e681bcf4f614a 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -12,7 +12,7 @@ from models.dataset import Dataset, Document, DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def document_indexing_update_task(dataset_id: str, document_id: str): """ Async update document @@ -21,18 +21,15 @@ def document_indexing_update_task(dataset_id: str, document_id: str): Usage: document_indexing_update_task.delay(dataset_id, document_id) """ - logging.info(click.style('Start update document: {}'.format(document_id), fg='green')) + logging.info(click.style("Start update document: {}".format(document_id), fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: - raise NotFound('Document not found') + raise NotFound("Document not found") - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() @@ -40,7 +37,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - raise Exception('Dataset not found') + raise Exception("Dataset not found") index_type = document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -57,7 +54,13 @@ def document_indexing_update_task(dataset_id: str, document_id: str): db.session.commit() end_at = time.perf_counter() logging.info( - click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) + click.style( + "Cleaned document when document update data source or process rule: {} latency: {}".format( + document_id, end_at - start_at + ), + fg="green", + ) + ) except Exception: logging.exception("Cleaned document when document update data source or process rule failed") @@ -65,8 +68,8 @@ def document_indexing_update_task(dataset_id: str, document_id: str): indexing_runner = IndexingRunner() indexing_runner.run([document]) end_at = time.perf_counter() - logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) + logging.info(click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green")) except DocumentIsPausedException as ex: - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 884e222d1b3eef..0a7568c385ba98 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -13,7 +13,7 @@ from services.feature_service import FeatureService -@shared_task(queue='dataset') +@shared_task(queue="dataset") def duplicate_document_indexing_task(dataset_id: str, document_ids: list): """ Async process document @@ -37,16 +37,17 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") if 0 < vector_space.limit <= vector_space.size: - raise ValueError("Your total number of documents plus the number of uploads have over the limit of " - "your subscription.") + raise ValueError( + "Your total number of documents plus the number of uploads have over the limit of " + "your subscription." + ) except Exception as e: for document_id in document_ids: - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) if document: - document.indexing_status = 'error' + document.indexing_status = "error" document.error = str(e) document.stopped_at = datetime.datetime.utcnow() db.session.add(document) @@ -54,12 +55,11 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): return for document_id in document_ids: - logging.info(click.style('Start process document: {}'.format(document_id), fg='green')) + logging.info(click.style("Start process document: {}".format(document_id), fg="green")) - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) if document: # clean old data @@ -77,7 +77,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): db.session.delete(segment) db.session.commit() - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.utcnow() documents.append(document) db.session.add(document) @@ -87,8 +87,8 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): indexing_runner = IndexingRunner() indexing_runner.run(documents) end_at = time.perf_counter() - logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) + logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) except DocumentIsPausedException as ex: - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index e37c06855d47de..1412ad9ec74c06 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -13,7 +13,7 @@ from models.dataset import DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def enable_segment_to_index_task(segment_id: str): """ Async enable segment to index @@ -21,17 +21,17 @@ def enable_segment_to_index_task(segment_id: str): Usage: enable_segment_to_index_task.delay(segment_id) """ - logging.info(click.style('Start enable segment to index: {}'.format(segment_id), fg='green')) + logging.info(click.style("Start enable segment to index: {}".format(segment_id), fg="green")) start_at = time.perf_counter() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() if not segment: - raise NotFound('Segment not found') + raise NotFound("Segment not found") - if segment.status != 'completed': - raise NotFound('Segment is not completed, enable action is not allowed.') + if segment.status != "completed": + raise NotFound("Segment is not completed, enable action is not allowed.") - indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_indexing".format(segment.id) try: document = Document( @@ -41,23 +41,23 @@ def enable_segment_to_index_task(segment_id: str): "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) dataset = segment.dataset if not dataset: - logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan')) + logging.info(click.style("Segment {} has no dataset, pass.".format(segment.id), fg="cyan")) return dataset_document = segment.document if not dataset_document: - logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan')) + logging.info(click.style("Segment {} has no document, pass.".format(segment.id), fg="cyan")) return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': - logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Segment {} document status is invalid, pass.".format(segment.id), fg="cyan")) return index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() @@ -65,12 +65,14 @@ def enable_segment_to_index_task(segment_id: str): index_processor.load(dataset, [document]) end_at = time.perf_counter() - logging.info(click.style('Segment enabled to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) + logging.info( + click.style("Segment enabled to index: {} latency: {}".format(segment.id, end_at - start_at), fg="green") + ) except Exception as e: logging.exception("enable segment to index failed") segment.enabled = False segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment.status = 'error' + segment.status = "error" segment.error = str(e) db.session.commit() finally: diff --git a/api/tasks/mail_invite_member_task.py b/api/tasks/mail_invite_member_task.py index a46eafa797907c..c7dfb9bf6063ff 100644 --- a/api/tasks/mail_invite_member_task.py +++ b/api/tasks/mail_invite_member_task.py @@ -9,7 +9,7 @@ from extensions.ext_mail import mail -@shared_task(queue='mail') +@shared_task(queue="mail") def send_invite_member_mail_task(language: str, to: str, token: str, inviter_name: str, workspace_name: str): """ Async Send invite member mail @@ -19,36 +19,43 @@ def send_invite_member_mail_task(language: str, to: str, token: str, inviter_nam :param inviter_name :param workspace_name - Usage: send_invite_member_mail_task.delay(langauge, to, token, inviter_name, workspace_name) + Usage: send_invite_member_mail_task.delay(language, to, token, inviter_name, workspace_name) """ if not mail.is_inited(): return - logging.info(click.style('Start send invite member mail to {} in workspace {}'.format(to, workspace_name), - fg='green')) + logging.info( + click.style("Start send invite member mail to {} in workspace {}".format(to, workspace_name), fg="green") + ) start_at = time.perf_counter() # send invite member mail using different languages try: - url = f'{dify_config.CONSOLE_WEB_URL}/activate?token={token}' - if language == 'zh-Hans': - html_content = render_template('invite_member_mail_template_zh-CN.html', - to=to, - inviter_name=inviter_name, - workspace_name=workspace_name, - url=url) + url = f"{dify_config.CONSOLE_WEB_URL}/activate?token={token}" + if language == "zh-Hans": + html_content = render_template( + "invite_member_mail_template_zh-CN.html", + to=to, + inviter_name=inviter_name, + workspace_name=workspace_name, + url=url, + ) mail.send(to=to, subject="立即加入 Dify 工作空间", html=html_content) else: - html_content = render_template('invite_member_mail_template_en-US.html', - to=to, - inviter_name=inviter_name, - workspace_name=workspace_name, - url=url) + html_content = render_template( + "invite_member_mail_template_en-US.html", + to=to, + inviter_name=inviter_name, + workspace_name=workspace_name, + url=url, + ) mail.send(to=to, subject="Join Dify Workspace Now", html=html_content) end_at = time.perf_counter() logging.info( - click.style('Send invite member mail to {} succeeded: latency: {}'.format(to, end_at - start_at), - fg='green')) + click.style( + "Send invite member mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green" + ) + ) except Exception: - logging.exception("Send invite member mail to {} failed".format(to)) \ No newline at end of file + logging.exception("Send invite member mail to {} failed".format(to)) diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py index 4e1b8a89135e7f..cbb78976cace4e 100644 --- a/api/tasks/mail_reset_password_task.py +++ b/api/tasks/mail_reset_password_task.py @@ -9,7 +9,7 @@ from extensions.ext_mail import mail -@shared_task(queue='mail') +@shared_task(queue="mail") def send_reset_password_mail_task(language: str, to: str, token: str): """ Async Send reset password mail @@ -20,26 +20,24 @@ def send_reset_password_mail_task(language: str, to: str, token: str): if not mail.is_inited(): return - logging.info(click.style('Start password reset mail to {}'.format(to), fg='green')) + logging.info(click.style("Start password reset mail to {}".format(to), fg="green")) start_at = time.perf_counter() # send reset password mail using different languages try: - url = f'{dify_config.CONSOLE_WEB_URL}/forgot-password?token={token}' - if language == 'zh-Hans': - html_content = render_template('reset_password_mail_template_zh-CN.html', - to=to, - url=url) + url = f"{dify_config.CONSOLE_WEB_URL}/forgot-password?token={token}" + if language == "zh-Hans": + html_content = render_template("reset_password_mail_template_zh-CN.html", to=to, url=url) mail.send(to=to, subject="重置您的 Dify 密码", html=html_content) else: - html_content = render_template('reset_password_mail_template_en-US.html', - to=to, - url=url) + html_content = render_template("reset_password_mail_template_en-US.html", to=to, url=url) mail.send(to=to, subject="Reset Your Dify Password", html=html_content) end_at = time.perf_counter() logging.info( - click.style('Send password reset mail to {} succeeded: latency: {}'.format(to, end_at - start_at), - fg='green')) + click.style( + "Send password reset mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green" + ) + ) except Exception: logging.exception("Send password reset mail to {} failed".format(to)) diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py index 6b4cab55b399ae..260069c6e2cd4e 100644 --- a/api/tasks/ops_trace_task.py +++ b/api/tasks/ops_trace_task.py @@ -10,7 +10,7 @@ from models.workflow import WorkflowRun -@shared_task(queue='ops_trace') +@shared_task(queue="ops_trace") def process_trace_tasks(tasks_data): """ Async process trace tasks @@ -20,17 +20,17 @@ def process_trace_tasks(tasks_data): """ from core.ops.ops_trace_manager import OpsTraceManager - trace_info = tasks_data.get('trace_info') - app_id = tasks_data.get('app_id') - trace_info_type = tasks_data.get('trace_info_type') + trace_info = tasks_data.get("trace_info") + app_id = tasks_data.get("app_id") + trace_info_type = tasks_data.get("trace_info_type") trace_instance = OpsTraceManager.get_ops_trace_instance(app_id) - if trace_info.get('message_data'): - trace_info['message_data'] = Message.from_dict(data=trace_info['message_data']) - if trace_info.get('workflow_data'): - trace_info['workflow_data'] = WorkflowRun.from_dict(data=trace_info['workflow_data']) - if trace_info.get('documents'): - trace_info['documents'] = [Document(**doc) for doc in trace_info['documents']] + if trace_info.get("message_data"): + trace_info["message_data"] = Message.from_dict(data=trace_info["message_data"]) + if trace_info.get("workflow_data"): + trace_info["workflow_data"] = WorkflowRun.from_dict(data=trace_info["workflow_data"]) + if trace_info.get("documents"): + trace_info["documents"] = [Document(**doc) for doc in trace_info["documents"]] try: if trace_instance: diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index 02278f512b74b6..18bae14ffac549 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -10,7 +10,7 @@ from models.dataset import Document -@shared_task(queue='dataset') +@shared_task(queue="dataset") def recover_document_indexing_task(dataset_id: str, document_id: str): """ Async recover document @@ -19,16 +19,13 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): Usage: recover_document_indexing_task.delay(dataset_id, document_id) """ - logging.info(click.style('Recover document: {}'.format(document_id), fg='green')) + logging.info(click.style("Recover document: {}".format(document_id), fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: - raise NotFound('Document not found') + raise NotFound("Document not found") try: indexing_runner = IndexingRunner() @@ -39,8 +36,10 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): elif document.indexing_status == "indexing": indexing_runner.run_in_indexing_status(document) end_at = time.perf_counter() - logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) + logging.info( + click.style("Processed document: {} latency: {}".format(document.id, end_at - start_at), fg="green") + ) except DocumentIsPausedException as ex: - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 378756e68c8202..66f78636ecca60 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -1,8 +1,10 @@ import logging import time +from collections.abc import Callable import click from celery import shared_task +from sqlalchemy import delete from sqlalchemy.exc import SQLAlchemyError from extensions.ext_database import db @@ -28,12 +30,12 @@ ) from models.tools import WorkflowToolProvider from models.web import PinnedConversation, SavedMessage -from models.workflow import Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun +from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun -@shared_task(queue='app_deletion', bind=True, max_retries=3) +@shared_task(queue="app_deletion", bind=True, max_retries=3) def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): - logging.info(click.style(f'Start deleting app and related data: {tenant_id}:{app_id}', fg='green')) + logging.info(click.style(f"Start deleting app and related data: {tenant_id}:{app_id}", fg="green")) start_at = time.perf_counter() try: # Delete related data @@ -54,15 +56,17 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): _delete_app_tag_bindings(tenant_id, app_id) _delete_end_users(tenant_id, app_id) _delete_trace_app_configs(tenant_id, app_id) + _delete_conversation_variables(app_id=app_id) end_at = time.perf_counter() - logging.info(click.style(f'App and related data deleted: {app_id} latency: {end_at - start_at}', fg='green')) + logging.info(click.style(f"App and related data deleted: {app_id} latency: {end_at - start_at}", fg="green")) except SQLAlchemyError as e: logging.exception( - click.style(f"Database error occurred while deleting app {app_id} and related data", fg='red')) + click.style(f"Database error occurred while deleting app {app_id} and related data", fg="red") + ) raise self.retry(exc=e, countdown=60) # Retry after 60 seconds except Exception as e: - logging.exception(click.style(f"Error occurred while deleting app {app_id} and related data", fg='red')) + logging.exception(click.style(f"Error occurred while deleting app {app_id} and related data", fg="red")) raise self.retry(exc=e, countdown=60) # Retry after 60 seconds @@ -74,7 +78,7 @@ def del_model_config(model_config_id: str): """select id from app_model_configs where app_id=:app_id limit 1000""", {"app_id": app_id}, del_model_config, - "app model config" + "app model config", ) @@ -82,12 +86,7 @@ def _delete_app_site(tenant_id: str, app_id: str): def del_site(site_id: str): db.session.query(Site).filter(Site.id == site_id).delete(synchronize_session=False) - _delete_records( - """select id from sites where app_id=:app_id limit 1000""", - {"app_id": app_id}, - del_site, - "site" - ) + _delete_records("""select id from sites where app_id=:app_id limit 1000""", {"app_id": app_id}, del_site, "site") def _delete_app_api_tokens(tenant_id: str, app_id: str): @@ -95,10 +94,7 @@ def del_api_token(api_token_id: str): db.session.query(ApiToken).filter(ApiToken.id == api_token_id).delete(synchronize_session=False) _delete_records( - """select id from api_tokens where app_id=:app_id limit 1000""", - {"app_id": app_id}, - del_api_token, - "api token" + """select id from api_tokens where app_id=:app_id limit 1000""", {"app_id": app_id}, del_api_token, "api token" ) @@ -110,44 +106,47 @@ def del_installed_app(installed_app_id: str): """select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_installed_app, - "installed app" + "installed app", ) def _delete_recommended_apps(tenant_id: str, app_id: str): def del_recommended_app(recommended_app_id: str): db.session.query(RecommendedApp).filter(RecommendedApp.id == recommended_app_id).delete( - synchronize_session=False) + synchronize_session=False + ) _delete_records( """select id from recommended_apps where app_id=:app_id limit 1000""", {"app_id": app_id}, del_recommended_app, - "recommended app" + "recommended app", ) def _delete_app_annotation_data(tenant_id: str, app_id: str): def del_annotation_hit_history(annotation_hit_history_id: str): db.session.query(AppAnnotationHitHistory).filter( - AppAnnotationHitHistory.id == annotation_hit_history_id).delete(synchronize_session=False) + AppAnnotationHitHistory.id == annotation_hit_history_id + ).delete(synchronize_session=False) _delete_records( """select id from app_annotation_hit_histories where app_id=:app_id limit 1000""", {"app_id": app_id}, del_annotation_hit_history, - "annotation hit history" + "annotation hit history", ) def del_annotation_setting(annotation_setting_id: str): db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.id == annotation_setting_id).delete( - synchronize_session=False) + synchronize_session=False + ) _delete_records( """select id from app_annotation_settings where app_id=:app_id limit 1000""", {"app_id": app_id}, del_annotation_setting, - "annotation setting" + "annotation setting", ) @@ -159,7 +158,7 @@ def del_dataset_join(dataset_join_id: str): """select id from app_dataset_joins where app_id=:app_id limit 1000""", {"app_id": app_id}, del_dataset_join, - "dataset join" + "dataset join", ) @@ -171,7 +170,7 @@ def del_workflow(workflow_id: str): """select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_workflow, - "workflow" + "workflow", ) @@ -183,82 +182,93 @@ def del_workflow_run(workflow_run_id: str): """select id from workflow_runs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_workflow_run, - "workflow run" + "workflow run", ) def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): def del_workflow_node_execution(workflow_node_execution_id: str): - db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == workflow_node_execution_id).delete(synchronize_session=False) + db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).delete( + synchronize_session=False + ) _delete_records( """select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_workflow_node_execution, - "workflow node execution" + "workflow node execution", ) def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def del_workflow_app_log(workflow_app_log_id: str): - db.session.query(WorkflowAppLog).filter(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False) + db.session.query(WorkflowAppLog).filter(WorkflowAppLog.id == workflow_app_log_id).delete( + synchronize_session=False + ) _delete_records( """select id from workflow_app_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_workflow_app_log, - "workflow app log" + "workflow app log", ) def _delete_app_conversations(tenant_id: str, app_id: str): def del_conversation(conversation_id: str): db.session.query(PinnedConversation).filter(PinnedConversation.conversation_id == conversation_id).delete( - synchronize_session=False) + synchronize_session=False + ) db.session.query(Conversation).filter(Conversation.id == conversation_id).delete(synchronize_session=False) _delete_records( """select id from conversations where app_id=:app_id limit 1000""", {"app_id": app_id}, del_conversation, - "conversation" + "conversation", ) +def _delete_conversation_variables(*, app_id: str): + stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id) + with db.engine.connect() as conn: + conn.execute(stmt) + conn.commit() + logging.info(click.style(f"Deleted conversation variables for app {app_id}", fg="green")) + + def _delete_app_messages(tenant_id: str, app_id: str): def del_message(message_id: str): db.session.query(MessageFeedback).filter(MessageFeedback.message_id == message_id).delete( - synchronize_session=False) + synchronize_session=False + ) db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == message_id).delete( - synchronize_session=False) - db.session.query(MessageChain).filter(MessageChain.message_id == message_id).delete( - synchronize_session=False) + synchronize_session=False + ) + db.session.query(MessageChain).filter(MessageChain.message_id == message_id).delete(synchronize_session=False) db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message_id).delete( - synchronize_session=False) + synchronize_session=False + ) db.session.query(MessageFile).filter(MessageFile.message_id == message_id).delete(synchronize_session=False) - db.session.query(SavedMessage).filter(SavedMessage.message_id == message_id).delete( - synchronize_session=False) + db.session.query(SavedMessage).filter(SavedMessage.message_id == message_id).delete(synchronize_session=False) db.session.query(Message).filter(Message.id == message_id).delete() _delete_records( - """select id from messages where app_id=:app_id limit 1000""", - {"app_id": app_id}, - del_message, - "message" + """select id from messages where app_id=:app_id limit 1000""", {"app_id": app_id}, del_message, "message" ) def _delete_workflow_tool_providers(tenant_id: str, app_id: str): def del_tool_provider(tool_provider_id: str): db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.id == tool_provider_id).delete( - synchronize_session=False) + synchronize_session=False + ) _delete_records( """select id from tool_workflow_providers where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_tool_provider, - "tool workflow provider" + "tool workflow provider", ) @@ -270,7 +280,7 @@ def del_tag_binding(tag_binding_id: str): """select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_tag_binding, - "tag binding" + "tag binding", ) @@ -282,24 +292,25 @@ def del_end_user(end_user_id: str): """select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_end_user, - "end user" + "end user", ) def _delete_trace_app_configs(tenant_id: str, app_id: str): def del_trace_app_config(trace_app_config_id: str): db.session.query(TraceAppConfig).filter(TraceAppConfig.id == trace_app_config_id).delete( - synchronize_session=False) + synchronize_session=False + ) _delete_records( """select id from trace_app_config where app_id=:app_id limit 1000""", {"app_id": app_id}, del_trace_app_config, - "trace app config" + "trace app config", ) -def _delete_records(query_sql: str, params: dict, delete_func: callable, name: str) -> None: +def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None: while True: with db.engine.begin() as conn: rs = conn.execute(db.text(query_sql), params) @@ -311,7 +322,7 @@ def _delete_records(query_sql: str, params: dict, delete_func: callable, name: s try: delete_func(record_id) db.session.commit() - logging.info(click.style(f"Deleted {name} {record_id}", fg='green')) + logging.info(click.style(f"Deleted {name} {record_id}", fg="green")) except Exception: logging.exception(f"Error occurred while deleting {name} {record_id}") continue diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index cff8dddc53c630..1909eaf3418517 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -11,7 +11,7 @@ from models.dataset import Document, DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def remove_document_from_index_task(document_id: str): """ Async Remove document from index @@ -19,23 +19,23 @@ def remove_document_from_index_task(document_id: str): Usage: remove_document_from_index.delay(document_id) """ - logging.info(click.style('Start remove document segments from index: {}'.format(document_id), fg='green')) + logging.info(click.style("Start remove document segments from index: {}".format(document_id), fg="green")) start_at = time.perf_counter() document = db.session.query(Document).filter(Document.id == document_id).first() if not document: - raise NotFound('Document not found') + raise NotFound("Document not found") - if document.indexing_status != 'completed': + if document.indexing_status != "completed": return - indexing_cache_key = 'document_{}_indexing'.format(document.id) + indexing_cache_key = "document_{}_indexing".format(document.id) try: dataset = document.dataset if not dataset: - raise Exception('Document has no dataset') + raise Exception("Document has no dataset") index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() @@ -49,7 +49,10 @@ def remove_document_from_index_task(document_id: str): end_at = time.perf_counter() logging.info( - click.style('Document removed from index: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) + click.style( + "Document removed from index: {} latency: {}".format(document.id, end_at - start_at), fg="green" + ) + ) except Exception: logging.exception("remove document from index failed") if not document.archived: diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index 1114809b3019a1..73471fd6e77c9b 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -13,7 +13,7 @@ from services.feature_service import FeatureService -@shared_task(queue='dataset') +@shared_task(queue="dataset") def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): """ Async process document @@ -27,22 +27,23 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() for document_id in document_ids: - retry_indexing_cache_key = 'document_{}_is_retried'.format(document_id) + retry_indexing_cache_key = "document_{}_is_retried".format(document_id) # check document limit features = FeatureService.get_features(dataset.tenant_id) try: if features.billing.enabled: vector_space = features.vector_space if 0 < vector_space.limit <= vector_space.size: - raise ValueError("Your total number of documents plus the number of uploads have over the limit of " - "your subscription.") + raise ValueError( + "Your total number of documents plus the number of uploads have over the limit of " + "your subscription." + ) except Exception as e: - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) if document: - document.indexing_status = 'error' + document.indexing_status = "error" document.error = str(e) document.stopped_at = datetime.datetime.utcnow() db.session.add(document) @@ -50,11 +51,10 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): redis_client.delete(retry_indexing_cache_key) return - logging.info(click.style('Start retry document: {}'.format(document_id), fg='green')) - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + logging.info(click.style("Start retry document: {}".format(document_id), fg="green")) + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) try: if document: # clean old data @@ -70,7 +70,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): db.session.delete(segment) db.session.commit() - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.utcnow() db.session.add(document) db.session.commit() @@ -79,13 +79,13 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): indexing_runner.run([document]) redis_client.delete(retry_indexing_cache_key) except Exception as ex: - document.indexing_status = 'error' + document.indexing_status = "error" document.error = str(ex) document.stopped_at = datetime.datetime.utcnow() db.session.add(document) db.session.commit() - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) redis_client.delete(retry_indexing_cache_key) pass end_at = time.perf_counter() - logging.info(click.style('Retry dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) + logging.info(click.style("Retry dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index 320da8718a12cb..99fb66e1f3d1cd 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -13,7 +13,7 @@ from services.feature_service import FeatureService -@shared_task(queue='dataset') +@shared_task(queue="dataset") def sync_website_document_indexing_task(dataset_id: str, document_id: str): """ Async process document @@ -26,22 +26,23 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() - sync_indexing_cache_key = 'document_{}_is_sync'.format(document_id) + sync_indexing_cache_key = "document_{}_is_sync".format(document_id) # check document limit features = FeatureService.get_features(dataset.tenant_id) try: if features.billing.enabled: vector_space = features.vector_space if 0 < vector_space.limit <= vector_space.size: - raise ValueError("Your total number of documents plus the number of uploads have over the limit of " - "your subscription.") + raise ValueError( + "Your total number of documents plus the number of uploads have over the limit of " + "your subscription." + ) except Exception as e: - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) if document: - document.indexing_status = 'error' + document.indexing_status = "error" document.error = str(e) document.stopped_at = datetime.datetime.utcnow() db.session.add(document) @@ -49,11 +50,8 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): redis_client.delete(sync_indexing_cache_key) return - logging.info(click.style('Start sync website document: {}'.format(document_id), fg='green')) - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + logging.info(click.style("Start sync website document: {}".format(document_id), fg="green")) + document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() try: if document: # clean old data @@ -69,7 +67,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): db.session.delete(segment) db.session.commit() - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.utcnow() db.session.add(document) db.session.commit() @@ -78,13 +76,13 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): indexing_runner.run([document]) redis_client.delete(sync_indexing_cache_key) except Exception as ex: - document.indexing_status = 'error' + document.indexing_status = "error" document.error = str(ex) document.stopped_at = datetime.datetime.utcnow() db.session.add(document) db.session.commit() - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) redis_client.delete(sync_indexing_cache_key) pass end_at = time.perf_counter() - logging.info(click.style('Sync document: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) + logging.info(click.style("Sync document: {} latency: {}".format(document_id, end_at - start_at), fg="green")) diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index f29e5ef4d61285..2d52399d29a995 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -79,4 +79,7 @@ CODE_EXECUTION_API_KEY= VOLC_API_KEY= VOLC_SECRET_KEY= VOLC_MODEL_ENDPOINT_ID= -VOLC_EMBEDDING_ENDPOINT_ID= \ No newline at end of file +VOLC_EMBEDDING_ENDPOINT_ID= + +# 360 AI Credentials +ZHINAO_API_KEY= diff --git a/api/tests/integration_tests/model_runtime/__mock/anthropic.py b/api/tests/integration_tests/model_runtime/__mock/anthropic.py index 3326f874b09505..79a3dc03941c9f 100644 --- a/api/tests/integration_tests/model_runtime/__mock/anthropic.py +++ b/api/tests/integration_tests/model_runtime/__mock/anthropic.py @@ -22,23 +22,20 @@ ) from anthropic.types.message_delta_event import Delta -MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false") == "true" class MockAnthropicClass: @staticmethod def mocked_anthropic_chat_create_sync(model: str) -> Message: return Message( - id='msg-123', - type='message', - role='assistant', - content=[ContentBlock(text='hello, I\'m a chatbot from anthropic', type='text')], + id="msg-123", + type="message", + role="assistant", + content=[ContentBlock(text="hello, I'm a chatbot from anthropic", type="text")], model=model, - stop_reason='stop_sequence', - usage=Usage( - input_tokens=1, - output_tokens=1 - ) + stop_reason="stop_sequence", + usage=Usage(input_tokens=1, output_tokens=1), ) @staticmethod @@ -46,52 +43,43 @@ def mocked_anthropic_chat_create_stream(model: str) -> Stream[MessageStreamEvent full_response_text = "hello, I'm a chatbot from anthropic" yield MessageStartEvent( - type='message_start', + type="message_start", message=Message( - id='msg-123', + id="msg-123", content=[], - role='assistant', + role="assistant", model=model, stop_reason=None, - type='message', - usage=Usage( - input_tokens=1, - output_tokens=1 - ) - ) + type="message", + usage=Usage(input_tokens=1, output_tokens=1), + ), ) index = 0 for i in range(0, len(full_response_text)): yield ContentBlockDeltaEvent( - type='content_block_delta', - delta=TextDelta(text=full_response_text[i], type='text_delta'), - index=index + type="content_block_delta", delta=TextDelta(text=full_response_text[i], type="text_delta"), index=index ) index += 1 yield MessageDeltaEvent( - type='message_delta', - delta=Delta( - stop_reason='stop_sequence' - ), - usage=MessageDeltaUsage( - output_tokens=1 - ) + type="message_delta", delta=Delta(stop_reason="stop_sequence"), usage=MessageDeltaUsage(output_tokens=1) ) - yield MessageStopEvent(type='message_stop') - - def mocked_anthropic(self: Messages, *, - max_tokens: int, - messages: Iterable[MessageParam], - model: str, - stream: Literal[True], - **kwargs: Any - ) -> Union[Message, Stream[MessageStreamEvent]]: + yield MessageStopEvent(type="message_stop") + + def mocked_anthropic( + self: Messages, + *, + max_tokens: int, + messages: Iterable[MessageParam], + model: str, + stream: Literal[True], + **kwargs: Any, + ) -> Union[Message, Stream[MessageStreamEvent]]: if len(self._client.api_key) < 18: - raise anthropic.AuthenticationError('Invalid API key') + raise anthropic.AuthenticationError("Invalid API key") if stream: return MockAnthropicClass.mocked_anthropic_chat_create_stream(model=model) @@ -102,7 +90,7 @@ def mocked_anthropic(self: Messages, *, @pytest.fixture def setup_anthropic_mock(request, monkeypatch: MonkeyPatch): if MOCK: - monkeypatch.setattr(Messages, 'create', MockAnthropicClass.mocked_anthropic) + monkeypatch.setattr(Messages, "create", MockAnthropicClass.mocked_anthropic) yield diff --git a/api/tests/integration_tests/model_runtime/__mock/google.py b/api/tests/integration_tests/model_runtime/__mock/google.py index d838e9890ff562..bc0684086f0b20 100644 --- a/api/tests/integration_tests/model_runtime/__mock/google.py +++ b/api/tests/integration_tests/model_runtime/__mock/google.py @@ -12,63 +12,46 @@ from google.generativeai.types import GenerateContentResponse from google.generativeai.types.generation_types import BaseGenerateContentResponse -current_api_key = '' +current_api_key = "" + class MockGoogleResponseClass: _done = False def __iter__(self): - full_response_text = 'it\'s google!' + full_response_text = "it's google!" for i in range(0, len(full_response_text) + 1, 1): if i == len(full_response_text): self._done = True yield GenerateContentResponse( - done=True, - iterator=None, - result=glm.GenerateContentResponse({ - - }), - chunks=[] + done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[] ) else: yield GenerateContentResponse( - done=False, - iterator=None, - result=glm.GenerateContentResponse({ - - }), - chunks=[] + done=False, iterator=None, result=glm.GenerateContentResponse({}), chunks=[] ) + class MockGoogleResponseCandidateClass: - finish_reason = 'stop' + finish_reason = "stop" @property def content(self) -> gag_content.Content: - return gag_content.Content( - parts=[ - gag_content.Part(text='it\'s google!') - ] - ) + return gag_content.Content(parts=[gag_content.Part(text="it's google!")]) + class MockGoogleClass: @staticmethod def generate_content_sync() -> GenerateContentResponse: - return GenerateContentResponse( - done=True, - iterator=None, - result=glm.GenerateContentResponse({ - - }), - chunks=[] - ) + return GenerateContentResponse(done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]) @staticmethod def generate_content_stream() -> Generator[GenerateContentResponse, None, None]: return MockGoogleResponseClass() - def generate_content(self: GenerativeModel, + def generate_content( + self: GenerativeModel, contents: content_types.ContentsType, *, generation_config: generation_config_types.GenerationConfigType | None = None, @@ -79,21 +62,21 @@ def generate_content(self: GenerativeModel, global current_api_key if len(current_api_key) < 16: - raise Exception('Invalid API key') + raise Exception("Invalid API key") if stream: return MockGoogleClass.generate_content_stream() - + return MockGoogleClass.generate_content_sync() - + @property def generative_response_text(self) -> str: - return 'it\'s google!' - + return "it's google!" + @property def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]: return [MockGoogleResponseCandidateClass()] - + def make_client(self: _ClientManager, name: str): global current_api_key @@ -121,7 +104,8 @@ def nop(self, *args, **kwargs): if not self.default_metadata: return client - + + @pytest.fixture def setup_google_mock(request, monkeypatch: MonkeyPatch): monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text) @@ -131,4 +115,4 @@ def setup_google_mock(request, monkeypatch: MonkeyPatch): yield - monkeypatch.undo() \ No newline at end of file + monkeypatch.undo() diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface.py b/api/tests/integration_tests/model_runtime/__mock/huggingface.py index a75b058d92c16a..97038ef5963e87 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface.py @@ -6,14 +6,15 @@ from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass -MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + @pytest.fixture def setup_huggingface_mock(request, monkeypatch: MonkeyPatch): if MOCK: monkeypatch.setattr(InferenceClient, "text_generation", MockHuggingfaceChatClass.text_generation) - + yield if MOCK: - monkeypatch.undo() \ No newline at end of file + monkeypatch.undo() diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py index 1607624c3c9056..9ee76c935c9873 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py @@ -22,10 +22,8 @@ def generate_create_sync(model: str) -> TextGenerationResponse: details=Details( finish_reason="length", generated_tokens=6, - tokens=[ - Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6) - ] - ) + tokens=[Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)], + ), ) return response @@ -36,26 +34,23 @@ def generate_create_stream(model: str) -> Generator[TextGenerationStreamResponse for i in range(0, len(full_text)): response = TextGenerationStreamResponse( - token = Token(id=i, text=full_text[i], logprob=0.0, special=False), + token=Token(id=i, text=full_text[i], logprob=0.0, special=False), ) response.generated_text = full_text[i] - response.details = StreamDetails(finish_reason='stop_sequence', generated_tokens=1) + response.details = StreamDetails(finish_reason="stop_sequence", generated_tokens=1) yield response - def text_generation(self: InferenceClient, prompt: str, *, - stream: Literal[False] = ..., - model: Optional[str] = None, - **kwargs: Any + def text_generation( + self: InferenceClient, prompt: str, *, stream: Literal[False] = ..., model: Optional[str] = None, **kwargs: Any ) -> Union[TextGenerationResponse, Generator[TextGenerationStreamResponse, None, None]]: # check if key is valid - if not re.match(r'Bearer\shf\-[a-zA-Z0-9]{16,}', self.headers['authorization']): - raise BadRequestError('Invalid API key') - + if not re.match(r"Bearer\shf\-[a-zA-Z0-9]{16,}", self.headers["authorization"]): + raise BadRequestError("Invalid API key") + if model is None: - raise BadRequestError('Invalid model') - + raise BadRequestError("Invalid model") + if stream: return MockHuggingfaceChatClass.generate_create_stream(model) return MockHuggingfaceChatClass.generate_create_sync(model) - diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py new file mode 100644 index 00000000000000..b37b109ebae58c --- /dev/null +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py @@ -0,0 +1,93 @@ +from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiModelExtraParameter + + +class MockTEIClass: + @staticmethod + def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter: + # During mock, we don't have a real server to query, so we just return a dummy value + if "rerank" in model_name: + model_type = "reranker" + else: + model_type = "embedding" + + return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1) + + @staticmethod + def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]: + # Use space as token separator, and split the text into tokens + tokenized_texts = [] + for text in texts: + tokens = text.split(" ") + current_index = 0 + tokenized_text = [] + for idx, token in enumerate(tokens): + s_token = { + "id": idx, + "text": token, + "special": False, + "start": current_index, + "stop": current_index + len(token), + } + current_index += len(token) + 1 + tokenized_text.append(s_token) + tokenized_texts.append(tokenized_text) + return tokenized_texts + + @staticmethod + def invoke_embeddings(server_url: str, texts: list[str]) -> dict: + # { + # "object": "list", + # "data": [ + # { + # "object": "embedding", + # "embedding": [...], + # "index": 0 + # } + # ], + # "model": "MODEL_NAME", + # "usage": { + # "prompt_tokens": 3, + # "total_tokens": 3 + # } + # } + embeddings = [] + for idx, text in enumerate(texts): + embedding = [0.1] * 768 + embeddings.append( + { + "object": "embedding", + "embedding": embedding, + "index": idx, + } + ) + return { + "object": "list", + "data": embeddings, + "model": "MODEL_NAME", + "usage": { + "prompt_tokens": sum(len(text.split(" ")) for text in texts), + "total_tokens": sum(len(text.split(" ")) for text in texts), + }, + } + + def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]: + # Example response: + # [ + # { + # "index": 0, + # "text": "Deep Learning is ...", + # "score": 0.9950755 + # } + # ] + reranked_docs = [] + for idx, text in enumerate(texts): + reranked_docs.append( + { + "index": idx, + "text": text, + "score": 0.9, + } + ) + # For mock, only return the first document + break + return reranked_docs diff --git a/api/tests/integration_tests/model_runtime/__mock/openai.py b/api/tests/integration_tests/model_runtime/__mock/openai.py index 0d3f0fbbeaab5f..6637f4f212a50e 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai.py @@ -21,13 +21,17 @@ from tests.integration_tests.model_runtime.__mock.openai_speech2text import MockSpeech2TextClass -def mock_openai(monkeypatch: MonkeyPatch, methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]]) -> Callable[[], None]: +def mock_openai( + monkeypatch: MonkeyPatch, + methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]], +) -> Callable[[], None]: """ - mock openai module + mock openai module - :param monkeypatch: pytest monkeypatch fixture - :return: unpatch function + :param monkeypatch: pytest monkeypatch fixture + :return: unpatch function """ + def unpatch() -> None: monkeypatch.undo() @@ -52,15 +56,16 @@ def unpatch() -> None: return unpatch -MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + @pytest.fixture def setup_openai_mock(request, monkeypatch): - methods = request.param if hasattr(request, 'param') else [] + methods = request.param if hasattr(request, "param") else [] if MOCK: unpatch = mock_openai(monkeypatch, methods=methods) - + yield if MOCK: - unpatch() \ No newline at end of file + unpatch() diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py index ba902e32ea6b22..d9cd7b046e001c 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py @@ -43,62 +43,64 @@ def generate_function_call( if not functions or len(functions) == 0: return None function: completion_create_params.Function = functions[0] - function_name = function['name'] - function_description = function['description'] - function_parameters = function['parameters'] - function_parameters_type = function_parameters['type'] - if function_parameters_type != 'object': + function_name = function["name"] + function_description = function["description"] + function_parameters = function["parameters"] + function_parameters_type = function_parameters["type"] + if function_parameters_type != "object": return None - function_parameters_properties = function_parameters['properties'] - function_parameters_required = function_parameters['required'] + function_parameters_properties = function_parameters["properties"] + function_parameters_required = function_parameters["required"] parameters = {} for parameter_name, parameter in function_parameters_properties.items(): if parameter_name not in function_parameters_required: continue - parameter_type = parameter['type'] - if parameter_type == 'string': - if 'enum' in parameter: - if len(parameter['enum']) == 0: + parameter_type = parameter["type"] + if parameter_type == "string": + if "enum" in parameter: + if len(parameter["enum"]) == 0: continue - parameters[parameter_name] = parameter['enum'][0] + parameters[parameter_name] = parameter["enum"][0] else: - parameters[parameter_name] = 'kawaii' - elif parameter_type == 'integer': + parameters[parameter_name] = "kawaii" + elif parameter_type == "integer": parameters[parameter_name] = 114514 - elif parameter_type == 'number': + elif parameter_type == "number": parameters[parameter_name] = 1919810.0 - elif parameter_type == 'boolean': + elif parameter_type == "boolean": parameters[parameter_name] = True return FunctionCall(name=function_name, arguments=dumps(parameters)) - + @staticmethod - def generate_tool_calls(tools = NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]: + def generate_tool_calls(tools=NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]: list_tool_calls = [] if not tools or len(tools) == 0: return None tool = tools[0] - if 'type' in tools and tools['type'] != 'function': + if "type" in tools and tools["type"] != "function": return None - function = tool['function'] + function = tool["function"] function_call = MockChatClass.generate_function_call(functions=[function]) if function_call is None: return None - - list_tool_calls.append(ChatCompletionMessageToolCall( - id='sakurajima-mai', - function=Function( - name=function_call.name, - arguments=function_call.arguments, - ), - type='function' - )) + + list_tool_calls.append( + ChatCompletionMessageToolCall( + id="sakurajima-mai", + function=Function( + name=function_call.name, + arguments=function_call.arguments, + ), + type="function", + ) + ) return list_tool_calls - + @staticmethod def mocked_openai_chat_create_sync( model: str, @@ -111,30 +113,27 @@ def mocked_openai_chat_create_sync( tool_calls = MockChatClass.generate_tool_calls(tools=tools) return _ChatCompletion( - id='cmpl-3QJQa5jXJ5Z5X', + id="cmpl-3QJQa5jXJ5Z5X", choices=[ _ChatCompletionChoice( - finish_reason='content_filter', + finish_reason="content_filter", index=0, message=ChatCompletionMessage( - content='elaina', - role='assistant', - function_call=function_call, - tool_calls=tool_calls - ) + content="elaina", role="assistant", function_call=function_call, tool_calls=tool_calls + ), ) ], created=int(time()), model=model, - object='chat.completion', - system_fingerprint='', + object="chat.completion", + system_fingerprint="", usage=CompletionUsage( prompt_tokens=2, completion_tokens=1, total_tokens=3, - ) + ), ) - + @staticmethod def mocked_openai_chat_create_stream( model: str, @@ -150,36 +149,40 @@ def mocked_openai_chat_create_stream( for i in range(0, len(full_text) + 1): if i == len(full_text): yield ChatCompletionChunk( - id='cmpl-3QJQa5jXJ5Z5X', + id="cmpl-3QJQa5jXJ5Z5X", choices=[ Choice( delta=ChoiceDelta( - content='', + content="", function_call=ChoiceDeltaFunctionCall( name=function_call.name, arguments=function_call.arguments, - ) if function_call else None, - role='assistant', + ) + if function_call + else None, + role="assistant", tool_calls=[ ChoiceDeltaToolCall( index=0, - id='misaka-mikoto', + id="misaka-mikoto", function=ChoiceDeltaToolCallFunction( name=tool_calls[0].function.name, arguments=tool_calls[0].function.arguments, ), - type='function' + type="function", ) - ] if tool_calls and len(tool_calls) > 0 else None + ] + if tool_calls and len(tool_calls) > 0 + else None, ), - finish_reason='function_call', + finish_reason="function_call", index=0, ) ], created=int(time()), model=model, - object='chat.completion.chunk', - system_fingerprint='', + object="chat.completion.chunk", + system_fingerprint="", usage=CompletionUsage( prompt_tokens=2, completion_tokens=17, @@ -188,30 +191,45 @@ def mocked_openai_chat_create_stream( ) else: yield ChatCompletionChunk( - id='cmpl-3QJQa5jXJ5Z5X', + id="cmpl-3QJQa5jXJ5Z5X", choices=[ Choice( delta=ChoiceDelta( content=full_text[i], - role='assistant', + role="assistant", ), - finish_reason='content_filter', + finish_reason="content_filter", index=0, ) ], created=int(time()), model=model, - object='chat.completion.chunk', - system_fingerprint='', + object="chat.completion.chunk", + system_fingerprint="", ) - def chat_create(self: Completions, *, + def chat_create( + self: Completions, + *, messages: list[ChatCompletionMessageParam], - model: Union[str,Literal[ - "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613", - "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", - "gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301", - "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613"], + model: Union[ + str, + Literal[ + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0301", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", + ], ], functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN, response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, @@ -220,24 +238,32 @@ def chat_create(self: Completions, *, **kwargs: Any, ): openai_models = [ - "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613", - "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", - "gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301", - "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", - ] - azure_openai_models = [ - "gpt35", "gpt-4v", "gpt-35-turbo" + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0301", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", ] - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()): - raise InvokeAuthorizationError('Invalid base url') + azure_openai_models = ["gpt35", "gpt-4v", "gpt-35-turbo"] + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + raise InvokeAuthorizationError("Invalid base url") if model in openai_models + azure_openai_models: - if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) == OpenAI: + if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI: # sometime, provider use OpenAI compatible API will not have api key or have different api key format # so we only check if model is in openai_models - raise InvokeAuthorizationError('Invalid api key') + raise InvokeAuthorizationError("Invalid api key") if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI: - raise InvokeAuthorizationError('Invalid api key') + raise InvokeAuthorizationError("Invalid api key") if stream: return MockChatClass.mocked_openai_chat_create_stream(model=model, functions=functions, tools=tools) - - return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools) \ No newline at end of file + + return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools) diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_completion.py b/api/tests/integration_tests/model_runtime/__mock/openai_completion.py index b0d26759055bff..c27e89248f4403 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_completion.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_completion.py @@ -17,9 +17,7 @@ class MockCompletionsClass: @staticmethod - def mocked_openai_completion_create_sync( - model: str - ) -> CompletionMessage: + def mocked_openai_completion_create_sync(model: str) -> CompletionMessage: return CompletionMessage( id="cmpl-3QJQa5jXJ5Z5X", object="text_completion", @@ -38,13 +36,11 @@ def mocked_openai_completion_create_sync( prompt_tokens=2, completion_tokens=1, total_tokens=3, - ) + ), ) - + @staticmethod - def mocked_openai_completion_create_stream( - model: str - ) -> Generator[CompletionMessage, None, None]: + def mocked_openai_completion_create_stream(model: str) -> Generator[CompletionMessage, None, None]: full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```" for i in range(0, len(full_text) + 1): if i == len(full_text): @@ -76,46 +72,59 @@ def mocked_openai_completion_create_stream( model=model, system_fingerprint="", choices=[ - CompletionChoice( - text=full_text[i], - index=0, - logprobs=None, - finish_reason="content_filter" - ) + CompletionChoice(text=full_text[i], index=0, logprobs=None, finish_reason="content_filter") ], ) - def completion_create(self: Completions, *, model: Union[ - str, Literal["babbage-002", "davinci-002", "gpt-3.5-turbo-instruct", - "text-davinci-003", "text-davinci-002", "text-davinci-001", - "code-davinci-002", "text-curie-001", "text-babbage-001", - "text-ada-001"], + def completion_create( + self: Completions, + *, + model: Union[ + str, + Literal[ + "babbage-002", + "davinci-002", + "gpt-3.5-turbo-instruct", + "text-davinci-003", + "text-davinci-002", + "text-davinci-001", + "code-davinci-002", + "text-curie-001", + "text-babbage-001", + "text-ada-001", + ], ], prompt: Union[str, list[str], list[int], list[list[int]], None], stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN, - **kwargs: Any + **kwargs: Any, ): openai_models = [ - "babbage-002", "davinci-002", "gpt-3.5-turbo-instruct", "text-davinci-003", "text-davinci-002", "text-davinci-001", - "code-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001", - ] - azure_openai_models = [ - "gpt-35-turbo-instruct" + "babbage-002", + "davinci-002", + "gpt-3.5-turbo-instruct", + "text-davinci-003", + "text-davinci-002", + "text-davinci-001", + "code-davinci-002", + "text-curie-001", + "text-babbage-001", + "text-ada-001", ] + azure_openai_models = ["gpt-35-turbo-instruct"] - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()): - raise InvokeAuthorizationError('Invalid base url') + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + raise InvokeAuthorizationError("Invalid base url") if model in openai_models + azure_openai_models: - if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) == OpenAI: + if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI: # sometime, provider use OpenAI compatible API will not have api key or have different api key format # so we only check if model is in openai_models - raise InvokeAuthorizationError('Invalid api key') + raise InvokeAuthorizationError("Invalid api key") if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI: - raise InvokeAuthorizationError('Invalid api key') - + raise InvokeAuthorizationError("Invalid api key") + if not prompt: - raise BadRequestError('Invalid prompt') + raise BadRequestError("Invalid prompt") if stream: return MockCompletionsClass.mocked_openai_completion_create_stream(model=model) - - return MockCompletionsClass.mocked_openai_completion_create_sync(model=model) \ No newline at end of file + + return MockCompletionsClass.mocked_openai_completion_create_sync(model=model) diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py b/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py index eccdbd34795c6e..4138cdd40d822c 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py @@ -12,48 +12,39 @@ class MockEmbeddingsClass: def create_embeddings( - self: Embeddings, *, + self: Embeddings, + *, input: Union[str, list[str], list[int], list[list[int]]], model: Union[str, Literal["text-embedding-ada-002"]], encoding_format: Literal["float", "base64"] | NotGiven = NOT_GIVEN, - **kwargs: Any + **kwargs: Any, ) -> CreateEmbeddingResponse: if isinstance(input, str): input = [input] - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()): - raise InvokeAuthorizationError('Invalid base url') - + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + raise InvokeAuthorizationError("Invalid base url") + if len(self._client.api_key) < 18: - raise InvokeAuthorizationError('Invalid API key') - - if encoding_format == 'float': + raise InvokeAuthorizationError("Invalid API key") + + if encoding_format == "float": return CreateEmbeddingResponse( data=[ - Embedding( - embedding=[0.23333 for _ in range(233)], - index=i, - object='embedding' - ) for i in range(len(input)) + Embedding(embedding=[0.23333 for _ in range(233)], index=i, object="embedding") + for i in range(len(input)) ], model=model, - object='list', + object="list", # marked: usage of embeddings should equal the number of testcase - usage=Usage( - prompt_tokens=2, - total_tokens=2 - ) + usage=Usage(prompt_tokens=2, total_tokens=2), ) - - embeddings = 'VEfNvMLUnrwFleO8hcj9vEE/yrzyjOA84E1MvNfoCrxjrI+8sZUKvNgrBT17uY07gJ/IvNvhHLrUemc8KXXGumalIT3YKwU7ZsnbPMhATrwTt6u8JEwRPNMmCjxGREW7TRKvu6/MG7zAyDU8wXLkuuMDZDsXsL28zHzaOw0IArzOiMO8LtASvPKM4Dul5l+80V0bPGVDZ7wYNrI89ucsvJZdYztzRm+8P8ysOyGbc7zrdgK9sdiEPKQ8sbulKdq7KIgdvKIMDj25dNc8k0AXPBn/oLzrdgK8IXe5uz0Dvrt50V68tTjLO4ZOcjoG9x29oGfZufiwmzwMDXy8EL6ZPHvdx7nKjzE8+LCbPG22hTs3EZq7TM+0POrRzTxVZo084wPkO8Nak7z8cpw8pDwxvA2T8LvBC7C72fltvC8Atjp3fYE8JHDLvEYgC7xAdls8YiabPPkEeTzPUbK8gOLCPEBSIbyt5Oy8CpreusNakzywUhA824vLPHRlr7zAhTs7IZtzvHd9AT2xY/O6ok8IvOihqrql5l88K4EvuknWorvYKwW9iXkbvGMTRLw5qPG7onPCPLgNIzwAbK67ftbZPMxYILvAyDW9TLB0vIid1buzCKi7u+d0u8iDSLxNVam8PZyJPNxnETvVANw8Oi5mu9nVszzl65I7DIKNvLGVirxsMJE7tPXQu2PvCT1zRm87p1l9uyRMkbsdfqe8U52ePHRlr7wt9Mw8/C8ivTu02rwJFGq8tpoFPWnC7blWumq7sfy+vG1zCzy9Nlg8iv+PuvxT3DuLU228kVhoOkmTqDrv1kg8ocmTu1WpBzsKml48DzglvI8ECzxwTd27I+pWvIWkQ7xUR007GqlPPBFEDrzGECu865q8PI7BkDwNxYc8tgG6ullMSLsIajs84lk1PNLjD70mv648ZmInO2tnIjzvb5Q8o5KCPLo9xrwKMyq9QqGEvI8ECzxO2508ATUdPRAlTry5kxc8KVGMPJyBHjxIUC476KGqvIU9DzwX87c88PUIParrWrzdlzS/G3K+uzEw2TxB2BU86AhfPAMiRj2dK808a85WPPCft7xU4Bg95Q9NPDxZjzwrpek7yNkZvHa0EjyQ0nM6Nq9fuyjvUbsRq8I7CAMHO3VSWLyuauE7U1qkvPkEeTxs7ZY7B6FMO48Eizy75/S7ieBPvB07rTxmyVu8onPCO5rc6Tu7XIa7oEMfPYngT7u24vk7/+W5PE8eGDxJ1iI9t4cuvBGHiLyH1GY7jfghu+oUSDwa7Mk7iXmbuut2grrq8I2563v8uyofdTxRTrs44lm1vMeWnzukf6s7r4khvEKhhDyhyZO8G5Z4Oy56wTz4sBs81Zknuz3fg7wnJuO74n1vvASEADu98128gUl3vBtyvrtZCU47yep8u5FYaDx2G0e8a85WO5cmUjz3kds8qgqbPCUaerx50d67WKIZPI7BkDua3Om74vKAvL3zXbzXpRA9CI51vLo9xryKzXg7tXtFO9RWLTwnJuM854LqPEIs8zuO5cq8d8V1u9P0cjrQ++C8cGwdPDdUlLoOGeW8auEtu8Z337nlzFK8aRg/vFCkDD0nRSM879bIvKUFID1iStU8EL6ZvLufgLtKgNE7KVEMvJOnSzwahRU895HbvJiIjLvc8n88bmC0PPLP2rywM9C7jTscOoS3mjy/Znu7dhvHuu5Q1Dyq61o6CI71u09hkry0jhw8gb6IPI8EC7uoVAM8gs9rvGM3fjx2G8e81FYtu/ojubyYRRK72Riuu83elDtNNmk70/TyuzUFsbvgKZI7onNCvAehzLumr8679R6+urr6SztX2So8Bl5SOwSEgLv5NpA8LwC2PGPvibzJ6vw7H2tQvOtXwrzXpRC8j0z/uxwcbTy2vr+8VWYNu+t2ArwKmt68NKN2O3XrIzw9A747UU47vaavzjwU+qW8YBqyvE02aTyEt5o8cCmjOxtyPrxs7ZY775NOu+SJWLxMJQY8/bWWu6IMDrzSSsQ7GSPbPLlQnbpVzcE7Pka4PJ96sLycxJg8v/9GPO2HZTyeW3C8Vpawtx2iYTwWBg87/qI/OviwGzxyWcY7M9WNPIA4FD32C2e8tNGWPJ43trxCoYS8FGHavItTbbu7n4C80NemPLm30Ty1OMu7vG1pvG3aPztBP0o75Q/NPJhFEj2V9i683PL/O97+aLz6iu27cdPRum/mKLwvVgc89fqDu3LA+jvm2Ls8mVZ1PIuFBD3ZGK47Cpreut7+aLziWTU8XSEgPMvSKzzO73e5040+vBlmVTxS1K+8mQ4BPZZ8o7w8FpW6OR0DPSSPCz21Vwu99fqDOjMYiDy7XAY8oYaZO+aVwTyX49c84OaXOqdZfTunEQk7B8AMvMDs7zo/D6e8OP5CvN9gIzwNCII8FefOPE026TpzIjU8XsvOO+J9b7rkIiQ8is34O+e0AbxBpv67hcj9uiPq1jtCoQQ8JfY/u86nAz0Wkf28LnrBPJlW9Tt8P4K7BbSjO9grhbyAOJS8G3K+vJLe3LzXpZA7NQUxPJs+JDz6vAS8QHZbvYNVYDrj3yk88PWIPOJ97zuSIVc8ZUPnPMqPsbx2cZi7QfzPOxYGDz2hqtO6H2tQO543NjyFPY+7JRUAOt0wgDyJeZu8MpKTu6AApTtg1ze82JI5vKllZjvrV0I7HX6nu7vndDxg1ze8jwQLu1ZTNjuJvBU7BXGpvAP+C7xJk6g8j2u/vBABlLzlqBi8M9WNutRWLTx0zGM9sHbKPLoZDDtmyVu8tpqFOvPumjyuRqe87lBUvFU0drxs7Za8ejMZOzJPGbyC7qu863v8PDPVjTxJ1iI7Ca01PLuAQLuNHFy7At9LOwP+i7tYxlO80NemO9elkDx45LU8h9TmuzxZjzz/5bk8p84OurvndLwAkGi7XL9luCSzRTwMgg08vrxMPKIwyDwdomG8K6VpPGPvCTxkmTi7M/lHPGxUSzxwKSM8wQuwvOqtkzrLFSa8SbdivAMixjw2r9+7xWt2vAyCDT1NEi87B8CMvG1zi7xpwm27MrbNO9R6Z7xJt+K7jNnhu9ZiFrve/ug55CKkvCwHJLqsOr47+ortvPwvIr2v8NW8YmmVOE+FTLywUhA8MTBZvMiDyLtx8hG8OEE9vMDsbzroCF88DelBOobnPbx+b6U8sbnEOywr3ro93wO9dMzjup2xwbwnRaO7cRZMu8Z337vS44+7VpYwvFWphzxKgNE8L1aHPLPFLbunzo66zFggPN+jHbs7tFo8nW7HO9JKRLyoeD28Fm1DPGZip7u5dNe7KMsXvFnlkzxQpAw7MrZNPHpX0zwSyoK7ayQovPR0Dz3gClK8/juLPDjaCLvqrZO7a4vcO9HEzzvife88KKzXvDmocbwpMkw7t2huvaIMjjznguo7Gy/EOzxZjzoLuZ48qi5VvCjLFzuDmNo654LquyrXgDy7XAa8e7mNvJ7QAb0Rq8K7ojBIvBN0MTuOfha8GoUVveb89bxMsHS8jV9WPPKM4LyAOJS8me9AvZv7qbsbcr47tuL5uaXmXzweKNa7rkYnPINV4Lxcv+W8tVcLvI8oxbzvbxS7oYaZu9+jHT0cHO08c7uAPCSzRTywUhA85xu2u+wBcTuJvJU8PBYVusTghzsnAim8acJtPFQE0zzFIwI9C7meO1DIRry7XAY8MKpkPJZd47suN0e5JTm6u6BDn7zfx1e8AJDoOr9CQbwaQps7x/1TPLTRFryqLtU8JybjPIXI/Tz6I7k6mVb1PMWKNryd1fs8Ok0mPHt2kzy9Ep48TTZpvPS3ibwGOpi8Ns4fPBqFlbr3Kqc8+QR5vHLA+rt7uY289YXyPI6iULxL4gu8Tv/XuycCKbwCnFG8C7kevVG1b7zIXw68GoWVO4rNeDnrM4i8MxgIPUNLs7zSoJW86ScfO+rRzbs6Cqw8NxGautP0cjw0wjY8CGq7vAkU6rxKgNG5+uA+vJXXbrwKM6o86vCNOu+yjjoQAZS8xATCOQVxKbynzo68wxcZvMhATjzS4488ArsRvNEaobwRh4i7t4euvAvd2DwnAik8UtQvvBFEDrz4sJs79gtnvOknnzy+vEy8D3sfPLH8vjzmLo28KVGMvOtXwjvpapm8HBxtPH3K8Lu753Q8/l9FvLvn9DomoG48fET8u9zy/7wMpke8zmQJu3oU2TzlD828KteAPAwNfLu+mBI5ldduPNZDVjq+vEy8eEvqvDHJpLwUPaC6qi7VPABsLjwFcSm72sJcu+bYO7v41NW8RiALvYB7DjzL0is7qLs3us1FSbzaf2K8MnNTuxABFDzF8Wo838fXvOBNzDzre3w8afQEvQE1nbulBaC78zEVvG5B9LzH/VM82Riuuwu5nrwsByQ8Y6yPvHXro7yQ0nM8nStNPJkyOzwnJmM80m7+O1VmjTzqrZM8dhvHOyAQBbz3baG8KTJMPOlqmbxsVEs8Pq3suy56QbzUVq08X3CDvAE1nTwUHuA7hue9vF8tCbvwOAO6F7A9ugd9kryqLtW7auEtu9ONPryPa7+8o9r2O570OzyFpEO8ntCBPOqtk7sykhO7lC1AOw2TcLswhiq6vx4HvP5fRbwuesG7Mk8ZvA4Z5TlfcAM9DrIwPL//xrzMm5q8JEwRPHBsnbxL4gu8jyjFu99gozrkZZ483GeRPLuAwDuYiIw8iv8PvK5Gpzx+b6W87Yflu3NGbzyE+hQ8a4tcPItT7bsoy5e8L1YHvWQyBDwrga86kPEzvBQ9oDxtl0W8lwKYvGpIYrxQ5wY8AJDovOLyALyw3f489JjJvMdTpTkKMyo8V9mqvH3K8LpyNYy8JHDLOixu2LpQ54Y8Q0uzu8LUnrs0wrY84vIAveihqjwfihA8DIKNvLDd/jywM1C7FB7gOxsLirxAUqE7sulnvH3K8DkAkGg8jsGQvO+TzrynWf287CCxvK4Drbwg8UQ8JRr6vFEqAbskjwu76q2TPNP0cjopDhK8dVJYvFIXKrxLn5G8AK8oPAb3HbxbOXE8Bvedun5Q5ThHyjk8QdiVvBXDlLw0o/Y7aLGKupkOgTxKPdc81kNWPtUAXLxUR827X1FDPf47izxsEVE8akhiPIhaWzxYX5+7hT0PPSrXgLxQC0E8i4WEvKUp2jtCLHM8DcWHO768zLxnK5a89R6+vH9czrorpem73h0pvAnwr7yKzXi8gDgUPf47Czq9zyO8728UOf34EDy6PUY76OSkvKZIGr2ZDgE8gzEmPG3av7v77Ce7/oP/O3MiNTtas/w8x1OlO/D1CDvDfs27ll1jO2Ufrbv1hXK8WINZuxN0sbuxlYq8OYS3uia/rjyiTwi9O7TaO+/WyDyiDA49E7erO3fF9bj6I7k7qHi9O3SoKbyBSfc7drSSvGPvCT2pQay7t2huPGnC7byUCQY8CEaBu6rHoDhx8hE8/fgQvCjLl7zdeHS8x/3TO0Isc7tas3y8jwQLvUKhhDz+foU8fCDCPC+ZgTywD5Y7ZR8tOla66rtCCLm8gWg3vDoKrLxbWDE76SefPBkj2zrlqJi7pebfuv6Df7zWQ9a7lHA6PGDXtzzMv1Q8mtxpOwJ4lzxKGZ28mGnMPDw6z7yxY/O7m2Leu7juYjwvVge8zFigPGpIYjtWumo5xs2wOgyCjbxrZ6K8bbaFvKzTCbsks8W7C7mePIU9DzxQyEY8posUvAW0ozrHlh88CyBTPJRwursxySQ757SBuqcRCbwNCIK8EL6ZvIG+iLsIRgE8rF74vOJZtbuUcDq8r/DVPMpMt7sL3Vi8eWqquww/kzqj2vY5auGtu85kiTwMPxM66KGqvBIxNzuwUpA8v2b7u09C0rx7ms08NUirvFYQPLxKPdc68mimvP5fRTtoPPm7XuqOOgOJ+jxfLYm7u58AvXz8B72PR4W6ldfuuys+tbvYKwW7pkiaPLB2SjvKj7G875POvA6yML7qFEg9Eu68O6Up2rz77Kc84CmSPP6ivzz4sJu6/C+iOaUpWjwq14A84E3MOYB7Dr2d1Xu775NOvC6e+7spUYw8PzPhO5TGizt29ww9yNkZPY7lyrz020M7QRsQu3z8BzwkCZe79YXyO8jZmTzvGUM8HgQcO9kYrrzxBmy8hLeaPLYBOjz+oj88flBlO6GqUzuiMMi8fxlUvCr7ujz41NU8DA38PBeMAzx7uY28TTZpvFG1bzxtc4s89ucsPEereTwfipC82p4iPKtNFbzo5KQ7pcKlOW5gtDzO73c7B6FMOzRbgjxCXoo8v0JBOSl1RrwxDJ+7XWSaPD3Aw7sOsjA8tuJ5vKw6Pry5k5c8ZUNnvG/H6DyVTAA8Shkdvd7+aDvtpiW9qUGsPFTgmDwbcr68TTbpO1DnhryNX9a7mrivvIqpPjxsqhy81HrnOzv31Dvth+U6UtQvPBz4MrvtpqW84OYXvRz4sjxwkFe8zSGPuycCqbyFPY8818nKOw84JTy8bWk8USqBvBGHiLtosQo8BOs0u9skl7xQ54Y8uvrLPOknn7w705o8Jny0PAd9EjxhoKa8Iv2tu2M3/jtsVEs8DcUHPQSEADs3eE48GkKbupRR+rvdeHQ7Xy2JvO1jKz0xMFm8sWPzux07LbyrTZW7bdq/O6Pa9r0ahRW9CyDTOjSjdjyQ8bO8yaIIPfupLTz/CfQ7xndfvJs+JD0zPEK8KO/RvMpw8bwObzY7fm+lPJtiXrz5BHm8WmsIvKlBrLuDdKA7hWHJOgd9Ers0o/Y7nlvwu5NAl7u8BrW6utYRO2SZuDxyNYw8CppevAY6GDxVqQe9oGdZPFa6ary3RLS70NcmO2PQSb36ZrM86q2TPML42LwewaE8k2RRPDmocTsi/S29o/k2PHRlr7zjnC+8gHsOPUpcFzxtl8W6tuL5vHw/gry/2wy9yaIIvINV4Dx3fQG7ISFoPO7pnzwGXlK8HPiyPGAaMjzBC7A7MQyfu+eC6jyV1+67pDyxvBWkVLxrJKg754LqOScCKbwpUQy8KIgdOJDSc7zDfk08tLLWvNZDVjyh7c28ShmdvMnlgjs2NdS8ISHovP5+hbxGIIs8ayQouyKnXDzBcmS6zw44u86IQ7yl5l+7cngGvWvOVrsEhIC7yNkZPJODkbuAn0g8XN6lPOaVwbuTgxG8OR2DPAb3HTzlqJi8nUoNvCAVf73Mmxo9afSEu4FotzveHSk8c0ZvOMFOqjwP9Sq87iwavIEBg7xIUK68IbozuozZ4btg17c7vx4Hvarr2rtp9IQ8Rt0QO+1jqzyeNzY8kNLzO8sVpry98108OCL9uyisV7vhr4Y8FgaPvLFjczw42og8gWg3vPX6gzsNk/C83GeRPCUVgDy0jpw7yNkZu2VD5zvh93o81h+cuw3Fhzyl5t+86Y7TvHa0EjyzCCi7WmsIPIy1Jzy00Ra6NUiru50rTTx50d47/HKcO2wwETw0f7y8sFIQvNxnkbzS4w855pVBu9FdGzx9yvC6TM80vFQjkzy/Zvs7BhtYPLjKKLqPa787A/6LOyiInbzooSq8728UPIFJ97wq+7q8R6v5u1tYMbwdomG6iSPKPAb3HTx3oTu7fGO8POqtk7ze/ug84wNkPMnq/DsB8iK9ogwOu6lBrDznguo8NQUxvHKcwDo28tm7yNmZPN1UurxCoYS80m7+Oy+9OzzGzTC836MdvCDNCrtaawi7dVLYPEfKuTxzRm88cCmjOyXSBbwGOpi879ZIO8dTJbtqnrO8NMI2vR1+J7xwTV087umfPFG17zsC30s8oYaZPKllZrzZGK47zss9vP21FryZywa9bbYFPVNapDt2G0e7E3SxPMUjgry5dNc895Hbu0H8z7ueN7a7OccxPFhfH7vC1B48n3owvEhQLrzu6Z+8HTutvEBSITw6Taa5g1XgPCzEqbxfLYk9OYQ3vBlm1bvPUTI8wIU7PIy1pzyFyP07gzGmO3NGb7yS3ty7O5CguyEhaLyWoF28pmxUOaZImrz+g/87mnU1vFbsgTxvo668PFmPO2KNTzy09VC8LG5YPHhL6rsvJPC7kTQuvEGCxDlhB9s6u58AvfCAd7z0t4k7kVjoOCkOkrxMjDq8iPOmPL0SnrxsMJG7OEG9vCUa+rvx4rE7cpxAPDCGqjukf6u8TEnAvNn57TweBBw7JdKFvIy1p7vIg8i7' + + embeddings = "VEfNvMLUnrwFleO8hcj9vEE/yrzyjOA84E1MvNfoCrxjrI+8sZUKvNgrBT17uY07gJ/IvNvhHLrUemc8KXXGumalIT3YKwU7ZsnbPMhATrwTt6u8JEwRPNMmCjxGREW7TRKvu6/MG7zAyDU8wXLkuuMDZDsXsL28zHzaOw0IArzOiMO8LtASvPKM4Dul5l+80V0bPGVDZ7wYNrI89ucsvJZdYztzRm+8P8ysOyGbc7zrdgK9sdiEPKQ8sbulKdq7KIgdvKIMDj25dNc8k0AXPBn/oLzrdgK8IXe5uz0Dvrt50V68tTjLO4ZOcjoG9x29oGfZufiwmzwMDXy8EL6ZPHvdx7nKjzE8+LCbPG22hTs3EZq7TM+0POrRzTxVZo084wPkO8Nak7z8cpw8pDwxvA2T8LvBC7C72fltvC8Atjp3fYE8JHDLvEYgC7xAdls8YiabPPkEeTzPUbK8gOLCPEBSIbyt5Oy8CpreusNakzywUhA824vLPHRlr7zAhTs7IZtzvHd9AT2xY/O6ok8IvOihqrql5l88K4EvuknWorvYKwW9iXkbvGMTRLw5qPG7onPCPLgNIzwAbK67ftbZPMxYILvAyDW9TLB0vIid1buzCKi7u+d0u8iDSLxNVam8PZyJPNxnETvVANw8Oi5mu9nVszzl65I7DIKNvLGVirxsMJE7tPXQu2PvCT1zRm87p1l9uyRMkbsdfqe8U52ePHRlr7wt9Mw8/C8ivTu02rwJFGq8tpoFPWnC7blWumq7sfy+vG1zCzy9Nlg8iv+PuvxT3DuLU228kVhoOkmTqDrv1kg8ocmTu1WpBzsKml48DzglvI8ECzxwTd27I+pWvIWkQ7xUR007GqlPPBFEDrzGECu865q8PI7BkDwNxYc8tgG6ullMSLsIajs84lk1PNLjD70mv648ZmInO2tnIjzvb5Q8o5KCPLo9xrwKMyq9QqGEvI8ECzxO2508ATUdPRAlTry5kxc8KVGMPJyBHjxIUC476KGqvIU9DzwX87c88PUIParrWrzdlzS/G3K+uzEw2TxB2BU86AhfPAMiRj2dK808a85WPPCft7xU4Bg95Q9NPDxZjzwrpek7yNkZvHa0EjyQ0nM6Nq9fuyjvUbsRq8I7CAMHO3VSWLyuauE7U1qkvPkEeTxs7ZY7B6FMO48Eizy75/S7ieBPvB07rTxmyVu8onPCO5rc6Tu7XIa7oEMfPYngT7u24vk7/+W5PE8eGDxJ1iI9t4cuvBGHiLyH1GY7jfghu+oUSDwa7Mk7iXmbuut2grrq8I2563v8uyofdTxRTrs44lm1vMeWnzukf6s7r4khvEKhhDyhyZO8G5Z4Oy56wTz4sBs81Zknuz3fg7wnJuO74n1vvASEADu98128gUl3vBtyvrtZCU47yep8u5FYaDx2G0e8a85WO5cmUjz3kds8qgqbPCUaerx50d67WKIZPI7BkDua3Om74vKAvL3zXbzXpRA9CI51vLo9xryKzXg7tXtFO9RWLTwnJuM854LqPEIs8zuO5cq8d8V1u9P0cjrQ++C8cGwdPDdUlLoOGeW8auEtu8Z337nlzFK8aRg/vFCkDD0nRSM879bIvKUFID1iStU8EL6ZvLufgLtKgNE7KVEMvJOnSzwahRU895HbvJiIjLvc8n88bmC0PPLP2rywM9C7jTscOoS3mjy/Znu7dhvHuu5Q1Dyq61o6CI71u09hkry0jhw8gb6IPI8EC7uoVAM8gs9rvGM3fjx2G8e81FYtu/ojubyYRRK72Riuu83elDtNNmk70/TyuzUFsbvgKZI7onNCvAehzLumr8679R6+urr6SztX2So8Bl5SOwSEgLv5NpA8LwC2PGPvibzJ6vw7H2tQvOtXwrzXpRC8j0z/uxwcbTy2vr+8VWYNu+t2ArwKmt68NKN2O3XrIzw9A747UU47vaavzjwU+qW8YBqyvE02aTyEt5o8cCmjOxtyPrxs7ZY775NOu+SJWLxMJQY8/bWWu6IMDrzSSsQ7GSPbPLlQnbpVzcE7Pka4PJ96sLycxJg8v/9GPO2HZTyeW3C8Vpawtx2iYTwWBg87/qI/OviwGzxyWcY7M9WNPIA4FD32C2e8tNGWPJ43trxCoYS8FGHavItTbbu7n4C80NemPLm30Ty1OMu7vG1pvG3aPztBP0o75Q/NPJhFEj2V9i683PL/O97+aLz6iu27cdPRum/mKLwvVgc89fqDu3LA+jvm2Ls8mVZ1PIuFBD3ZGK47Cpreut7+aLziWTU8XSEgPMvSKzzO73e5040+vBlmVTxS1K+8mQ4BPZZ8o7w8FpW6OR0DPSSPCz21Vwu99fqDOjMYiDy7XAY8oYaZO+aVwTyX49c84OaXOqdZfTunEQk7B8AMvMDs7zo/D6e8OP5CvN9gIzwNCII8FefOPE026TpzIjU8XsvOO+J9b7rkIiQ8is34O+e0AbxBpv67hcj9uiPq1jtCoQQ8JfY/u86nAz0Wkf28LnrBPJlW9Tt8P4K7BbSjO9grhbyAOJS8G3K+vJLe3LzXpZA7NQUxPJs+JDz6vAS8QHZbvYNVYDrj3yk88PWIPOJ97zuSIVc8ZUPnPMqPsbx2cZi7QfzPOxYGDz2hqtO6H2tQO543NjyFPY+7JRUAOt0wgDyJeZu8MpKTu6AApTtg1ze82JI5vKllZjvrV0I7HX6nu7vndDxg1ze8jwQLu1ZTNjuJvBU7BXGpvAP+C7xJk6g8j2u/vBABlLzlqBi8M9WNutRWLTx0zGM9sHbKPLoZDDtmyVu8tpqFOvPumjyuRqe87lBUvFU0drxs7Za8ejMZOzJPGbyC7qu863v8PDPVjTxJ1iI7Ca01PLuAQLuNHFy7At9LOwP+i7tYxlO80NemO9elkDx45LU8h9TmuzxZjzz/5bk8p84OurvndLwAkGi7XL9luCSzRTwMgg08vrxMPKIwyDwdomG8K6VpPGPvCTxkmTi7M/lHPGxUSzxwKSM8wQuwvOqtkzrLFSa8SbdivAMixjw2r9+7xWt2vAyCDT1NEi87B8CMvG1zi7xpwm27MrbNO9R6Z7xJt+K7jNnhu9ZiFrve/ug55CKkvCwHJLqsOr47+ortvPwvIr2v8NW8YmmVOE+FTLywUhA8MTBZvMiDyLtx8hG8OEE9vMDsbzroCF88DelBOobnPbx+b6U8sbnEOywr3ro93wO9dMzjup2xwbwnRaO7cRZMu8Z337vS44+7VpYwvFWphzxKgNE8L1aHPLPFLbunzo66zFggPN+jHbs7tFo8nW7HO9JKRLyoeD28Fm1DPGZip7u5dNe7KMsXvFnlkzxQpAw7MrZNPHpX0zwSyoK7ayQovPR0Dz3gClK8/juLPDjaCLvqrZO7a4vcO9HEzzvife88KKzXvDmocbwpMkw7t2huvaIMjjznguo7Gy/EOzxZjzoLuZ48qi5VvCjLFzuDmNo654LquyrXgDy7XAa8e7mNvJ7QAb0Rq8K7ojBIvBN0MTuOfha8GoUVveb89bxMsHS8jV9WPPKM4LyAOJS8me9AvZv7qbsbcr47tuL5uaXmXzweKNa7rkYnPINV4Lxcv+W8tVcLvI8oxbzvbxS7oYaZu9+jHT0cHO08c7uAPCSzRTywUhA85xu2u+wBcTuJvJU8PBYVusTghzsnAim8acJtPFQE0zzFIwI9C7meO1DIRry7XAY8MKpkPJZd47suN0e5JTm6u6BDn7zfx1e8AJDoOr9CQbwaQps7x/1TPLTRFryqLtU8JybjPIXI/Tz6I7k6mVb1PMWKNryd1fs8Ok0mPHt2kzy9Ep48TTZpvPS3ibwGOpi8Ns4fPBqFlbr3Kqc8+QR5vHLA+rt7uY289YXyPI6iULxL4gu8Tv/XuycCKbwCnFG8C7kevVG1b7zIXw68GoWVO4rNeDnrM4i8MxgIPUNLs7zSoJW86ScfO+rRzbs6Cqw8NxGautP0cjw0wjY8CGq7vAkU6rxKgNG5+uA+vJXXbrwKM6o86vCNOu+yjjoQAZS8xATCOQVxKbynzo68wxcZvMhATjzS4488ArsRvNEaobwRh4i7t4euvAvd2DwnAik8UtQvvBFEDrz4sJs79gtnvOknnzy+vEy8D3sfPLH8vjzmLo28KVGMvOtXwjvpapm8HBxtPH3K8Lu753Q8/l9FvLvn9DomoG48fET8u9zy/7wMpke8zmQJu3oU2TzlD828KteAPAwNfLu+mBI5ldduPNZDVjq+vEy8eEvqvDHJpLwUPaC6qi7VPABsLjwFcSm72sJcu+bYO7v41NW8RiALvYB7DjzL0is7qLs3us1FSbzaf2K8MnNTuxABFDzF8Wo838fXvOBNzDzre3w8afQEvQE1nbulBaC78zEVvG5B9LzH/VM82Riuuwu5nrwsByQ8Y6yPvHXro7yQ0nM8nStNPJkyOzwnJmM80m7+O1VmjTzqrZM8dhvHOyAQBbz3baG8KTJMPOlqmbxsVEs8Pq3suy56QbzUVq08X3CDvAE1nTwUHuA7hue9vF8tCbvwOAO6F7A9ugd9kryqLtW7auEtu9ONPryPa7+8o9r2O570OzyFpEO8ntCBPOqtk7sykhO7lC1AOw2TcLswhiq6vx4HvP5fRbwuesG7Mk8ZvA4Z5TlfcAM9DrIwPL//xrzMm5q8JEwRPHBsnbxL4gu8jyjFu99gozrkZZ483GeRPLuAwDuYiIw8iv8PvK5Gpzx+b6W87Yflu3NGbzyE+hQ8a4tcPItT7bsoy5e8L1YHvWQyBDwrga86kPEzvBQ9oDxtl0W8lwKYvGpIYrxQ5wY8AJDovOLyALyw3f489JjJvMdTpTkKMyo8V9mqvH3K8LpyNYy8JHDLOixu2LpQ54Y8Q0uzu8LUnrs0wrY84vIAveihqjwfihA8DIKNvLDd/jywM1C7FB7gOxsLirxAUqE7sulnvH3K8DkAkGg8jsGQvO+TzrynWf287CCxvK4Drbwg8UQ8JRr6vFEqAbskjwu76q2TPNP0cjopDhK8dVJYvFIXKrxLn5G8AK8oPAb3HbxbOXE8Bvedun5Q5ThHyjk8QdiVvBXDlLw0o/Y7aLGKupkOgTxKPdc81kNWPtUAXLxUR827X1FDPf47izxsEVE8akhiPIhaWzxYX5+7hT0PPSrXgLxQC0E8i4WEvKUp2jtCLHM8DcWHO768zLxnK5a89R6+vH9czrorpem73h0pvAnwr7yKzXi8gDgUPf47Czq9zyO8728UOf34EDy6PUY76OSkvKZIGr2ZDgE8gzEmPG3av7v77Ce7/oP/O3MiNTtas/w8x1OlO/D1CDvDfs27ll1jO2Ufrbv1hXK8WINZuxN0sbuxlYq8OYS3uia/rjyiTwi9O7TaO+/WyDyiDA49E7erO3fF9bj6I7k7qHi9O3SoKbyBSfc7drSSvGPvCT2pQay7t2huPGnC7byUCQY8CEaBu6rHoDhx8hE8/fgQvCjLl7zdeHS8x/3TO0Isc7tas3y8jwQLvUKhhDz+foU8fCDCPC+ZgTywD5Y7ZR8tOla66rtCCLm8gWg3vDoKrLxbWDE76SefPBkj2zrlqJi7pebfuv6Df7zWQ9a7lHA6PGDXtzzMv1Q8mtxpOwJ4lzxKGZ28mGnMPDw6z7yxY/O7m2Leu7juYjwvVge8zFigPGpIYjtWumo5xs2wOgyCjbxrZ6K8bbaFvKzTCbsks8W7C7mePIU9DzxQyEY8posUvAW0ozrHlh88CyBTPJRwursxySQ757SBuqcRCbwNCIK8EL6ZvIG+iLsIRgE8rF74vOJZtbuUcDq8r/DVPMpMt7sL3Vi8eWqquww/kzqj2vY5auGtu85kiTwMPxM66KGqvBIxNzuwUpA8v2b7u09C0rx7ms08NUirvFYQPLxKPdc68mimvP5fRTtoPPm7XuqOOgOJ+jxfLYm7u58AvXz8B72PR4W6ldfuuys+tbvYKwW7pkiaPLB2SjvKj7G875POvA6yML7qFEg9Eu68O6Up2rz77Kc84CmSPP6ivzz4sJu6/C+iOaUpWjwq14A84E3MOYB7Dr2d1Xu775NOvC6e+7spUYw8PzPhO5TGizt29ww9yNkZPY7lyrz020M7QRsQu3z8BzwkCZe79YXyO8jZmTzvGUM8HgQcO9kYrrzxBmy8hLeaPLYBOjz+oj88flBlO6GqUzuiMMi8fxlUvCr7ujz41NU8DA38PBeMAzx7uY28TTZpvFG1bzxtc4s89ucsPEereTwfipC82p4iPKtNFbzo5KQ7pcKlOW5gtDzO73c7B6FMOzRbgjxCXoo8v0JBOSl1RrwxDJ+7XWSaPD3Aw7sOsjA8tuJ5vKw6Pry5k5c8ZUNnvG/H6DyVTAA8Shkdvd7+aDvtpiW9qUGsPFTgmDwbcr68TTbpO1DnhryNX9a7mrivvIqpPjxsqhy81HrnOzv31Dvth+U6UtQvPBz4MrvtpqW84OYXvRz4sjxwkFe8zSGPuycCqbyFPY8818nKOw84JTy8bWk8USqBvBGHiLtosQo8BOs0u9skl7xQ54Y8uvrLPOknn7w705o8Jny0PAd9EjxhoKa8Iv2tu2M3/jtsVEs8DcUHPQSEADs3eE48GkKbupRR+rvdeHQ7Xy2JvO1jKz0xMFm8sWPzux07LbyrTZW7bdq/O6Pa9r0ahRW9CyDTOjSjdjyQ8bO8yaIIPfupLTz/CfQ7xndfvJs+JD0zPEK8KO/RvMpw8bwObzY7fm+lPJtiXrz5BHm8WmsIvKlBrLuDdKA7hWHJOgd9Ers0o/Y7nlvwu5NAl7u8BrW6utYRO2SZuDxyNYw8CppevAY6GDxVqQe9oGdZPFa6ary3RLS70NcmO2PQSb36ZrM86q2TPML42LwewaE8k2RRPDmocTsi/S29o/k2PHRlr7zjnC+8gHsOPUpcFzxtl8W6tuL5vHw/gry/2wy9yaIIvINV4Dx3fQG7ISFoPO7pnzwGXlK8HPiyPGAaMjzBC7A7MQyfu+eC6jyV1+67pDyxvBWkVLxrJKg754LqOScCKbwpUQy8KIgdOJDSc7zDfk08tLLWvNZDVjyh7c28ShmdvMnlgjs2NdS8ISHovP5+hbxGIIs8ayQouyKnXDzBcmS6zw44u86IQ7yl5l+7cngGvWvOVrsEhIC7yNkZPJODkbuAn0g8XN6lPOaVwbuTgxG8OR2DPAb3HTzlqJi8nUoNvCAVf73Mmxo9afSEu4FotzveHSk8c0ZvOMFOqjwP9Sq87iwavIEBg7xIUK68IbozuozZ4btg17c7vx4Hvarr2rtp9IQ8Rt0QO+1jqzyeNzY8kNLzO8sVpry98108OCL9uyisV7vhr4Y8FgaPvLFjczw42og8gWg3vPX6gzsNk/C83GeRPCUVgDy0jpw7yNkZu2VD5zvh93o81h+cuw3Fhzyl5t+86Y7TvHa0EjyzCCi7WmsIPIy1Jzy00Ra6NUiru50rTTx50d47/HKcO2wwETw0f7y8sFIQvNxnkbzS4w855pVBu9FdGzx9yvC6TM80vFQjkzy/Zvs7BhtYPLjKKLqPa787A/6LOyiInbzooSq8728UPIFJ97wq+7q8R6v5u1tYMbwdomG6iSPKPAb3HTx3oTu7fGO8POqtk7ze/ug84wNkPMnq/DsB8iK9ogwOu6lBrDznguo8NQUxvHKcwDo28tm7yNmZPN1UurxCoYS80m7+Oy+9OzzGzTC836MdvCDNCrtaawi7dVLYPEfKuTxzRm88cCmjOyXSBbwGOpi879ZIO8dTJbtqnrO8NMI2vR1+J7xwTV087umfPFG17zsC30s8oYaZPKllZrzZGK47zss9vP21FryZywa9bbYFPVNapDt2G0e7E3SxPMUjgry5dNc895Hbu0H8z7ueN7a7OccxPFhfH7vC1B48n3owvEhQLrzu6Z+8HTutvEBSITw6Taa5g1XgPCzEqbxfLYk9OYQ3vBlm1bvPUTI8wIU7PIy1pzyFyP07gzGmO3NGb7yS3ty7O5CguyEhaLyWoF28pmxUOaZImrz+g/87mnU1vFbsgTxvo668PFmPO2KNTzy09VC8LG5YPHhL6rsvJPC7kTQuvEGCxDlhB9s6u58AvfCAd7z0t4k7kVjoOCkOkrxMjDq8iPOmPL0SnrxsMJG7OEG9vCUa+rvx4rE7cpxAPDCGqjukf6u8TEnAvNn57TweBBw7JdKFvIy1p7vIg8i7" data = [] for i, text in enumerate(input): - obj = Embedding( - embedding=[], - index=i, - object='embedding' - ) + obj = Embedding(embedding=[], index=i, object="embedding") obj.embedding = embeddings data.append(obj) @@ -61,10 +52,7 @@ def create_embeddings( return CreateEmbeddingResponse( data=data, model=model, - object='list', + object="list", # marked: usage of embeddings should equal the number of testcase - usage=Usage( - prompt_tokens=2, - total_tokens=2 - ) - ) \ No newline at end of file + usage=Usage(prompt_tokens=2, total_tokens=2), + ) diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py b/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py index 9466f4bfb8e794..270a88e85ffedd 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py @@ -10,58 +10,92 @@ class MockModerationClass: - def moderation_create(self: Moderations,*, + def moderation_create( + self: Moderations, + *, input: Union[str, list[str]], model: Union[str, Literal["text-moderation-latest", "text-moderation-stable"]] | NotGiven = NOT_GIVEN, - **kwargs: Any + **kwargs: Any, ) -> ModerationCreateResponse: if isinstance(input, str): input = [input] - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()): - raise InvokeAuthorizationError('Invalid base url') - + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + raise InvokeAuthorizationError("Invalid base url") + if len(self._client.api_key) < 18: - raise InvokeAuthorizationError('Invalid API key') + raise InvokeAuthorizationError("Invalid API key") for text in input: result = [] - if 'kill' in text: + if "kill" in text: moderation_categories = { - 'harassment': False, 'harassment/threatening': False, 'hate': False, 'hate/threatening': False, - 'self-harm': False, 'self-harm/instructions': False, 'self-harm/intent': False, 'sexual': False, - 'sexual/minors': False, 'violence': False, 'violence/graphic': False + "harassment": False, + "harassment/threatening": False, + "hate": False, + "hate/threatening": False, + "self-harm": False, + "self-harm/instructions": False, + "self-harm/intent": False, + "sexual": False, + "sexual/minors": False, + "violence": False, + "violence/graphic": False, } moderation_categories_scores = { - 'harassment': 1.0, 'harassment/threatening': 1.0, 'hate': 1.0, 'hate/threatening': 1.0, - 'self-harm': 1.0, 'self-harm/instructions': 1.0, 'self-harm/intent': 1.0, 'sexual': 1.0, - 'sexual/minors': 1.0, 'violence': 1.0, 'violence/graphic': 1.0 + "harassment": 1.0, + "harassment/threatening": 1.0, + "hate": 1.0, + "hate/threatening": 1.0, + "self-harm": 1.0, + "self-harm/instructions": 1.0, + "self-harm/intent": 1.0, + "sexual": 1.0, + "sexual/minors": 1.0, + "violence": 1.0, + "violence/graphic": 1.0, } - result.append(Moderation( - flagged=True, - categories=Categories(**moderation_categories), - category_scores=CategoryScores(**moderation_categories_scores) - )) + result.append( + Moderation( + flagged=True, + categories=Categories(**moderation_categories), + category_scores=CategoryScores(**moderation_categories_scores), + ) + ) else: moderation_categories = { - 'harassment': False, 'harassment/threatening': False, 'hate': False, 'hate/threatening': False, - 'self-harm': False, 'self-harm/instructions': False, 'self-harm/intent': False, 'sexual': False, - 'sexual/minors': False, 'violence': False, 'violence/graphic': False + "harassment": False, + "harassment/threatening": False, + "hate": False, + "hate/threatening": False, + "self-harm": False, + "self-harm/instructions": False, + "self-harm/intent": False, + "sexual": False, + "sexual/minors": False, + "violence": False, + "violence/graphic": False, } moderation_categories_scores = { - 'harassment': 0.0, 'harassment/threatening': 0.0, 'hate': 0.0, 'hate/threatening': 0.0, - 'self-harm': 0.0, 'self-harm/instructions': 0.0, 'self-harm/intent': 0.0, 'sexual': 0.0, - 'sexual/minors': 0.0, 'violence': 0.0, 'violence/graphic': 0.0 + "harassment": 0.0, + "harassment/threatening": 0.0, + "hate": 0.0, + "hate/threatening": 0.0, + "self-harm": 0.0, + "self-harm/instructions": 0.0, + "self-harm/intent": 0.0, + "sexual": 0.0, + "sexual/minors": 0.0, + "violence": 0.0, + "violence/graphic": 0.0, } - result.append(Moderation( - flagged=False, - categories=Categories(**moderation_categories), - category_scores=CategoryScores(**moderation_categories_scores) - )) + result.append( + Moderation( + flagged=False, + categories=Categories(**moderation_categories), + category_scores=CategoryScores(**moderation_categories_scores), + ) + ) - return ModerationCreateResponse( - id='shiroii kuloko', - model=model, - results=result - ) \ No newline at end of file + return ModerationCreateResponse(id="shiroii kuloko", model=model, results=result) diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_remote.py b/api/tests/integration_tests/model_runtime/__mock/openai_remote.py index 0124ac045b6877..cb8f2495438783 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_remote.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_remote.py @@ -6,17 +6,18 @@ class MockModelClass: """ - mock class for openai.models.Models + mock class for openai.models.Models """ + def list( self, **kwargs, ) -> list[Model]: return [ Model( - id='ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ', + id="ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ", created=int(time()), - object='model', - owned_by='organization:org-123', + object="model", + owned_by="organization:org-123", ) - ] \ No newline at end of file + ] diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py b/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py index 755fec4c1fbc9e..ef361e86139427 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py @@ -9,7 +9,8 @@ class MockSpeech2TextClass: - def speech2text_create(self: Transcriptions, + def speech2text_create( + self: Transcriptions, *, file: FileTypes, model: Union[str, Literal["whisper-1"]], @@ -17,14 +18,12 @@ def speech2text_create(self: Transcriptions, prompt: str | NotGiven = NOT_GIVEN, response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] | NotGiven = NOT_GIVEN, temperature: float | NotGiven = NOT_GIVEN, - **kwargs: Any + **kwargs: Any, ) -> Transcription: - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()): - raise InvokeAuthorizationError('Invalid base url') - + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + raise InvokeAuthorizationError("Invalid base url") + if len(self._client.api_key) < 18: - raise InvokeAuthorizationError('Invalid API key') - - return Transcription( - text='1, 2, 3, 4, 5, 6, 7, 8, 9, 10' - ) \ No newline at end of file + raise InvokeAuthorizationError("Invalid API key") + + return Transcription(text="1, 2, 3, 4, 5, 6, 7, 8, 9, 10") diff --git a/api/tests/integration_tests/model_runtime/__mock/xinference.py b/api/tests/integration_tests/model_runtime/__mock/xinference.py index 7cb0a1318e9b8d..777737187e259c 100644 --- a/api/tests/integration_tests/model_runtime/__mock/xinference.py +++ b/api/tests/integration_tests/model_runtime/__mock/xinference.py @@ -19,40 +19,43 @@ class MockXinferenceClass: - def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]: - if not re.match(r'https?:\/\/[^\s\/$.?#].[^\s]*$', self.base_url): - raise RuntimeError('404 Not Found') - - if 'generate' == model_uid: + def get_chat_model( + self: Client, model_uid: str + ) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]: + if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url): + raise RuntimeError("404 Not Found") + + if "generate" == model_uid: return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={}) - if 'chat' == model_uid: + if "chat" == model_uid: return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={}) - if 'embedding' == model_uid: + if "embedding" == model_uid: return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={}) - if 'rerank' == model_uid: + if "rerank" == model_uid: return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={}) - raise RuntimeError('404 Not Found') - + raise RuntimeError("404 Not Found") + def get(self: Session, url: str, **kwargs): response = Response() - if 'v1/models/' in url: + if "v1/models/" in url: # get model uid - model_uid = url.split('/')[-1] or '' - if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \ - model_uid not in ['generate', 'chat', 'embedding', 'rerank']: + model_uid = url.split("/")[-1] or "" + if not re.match( + r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", model_uid + ) and model_uid not in ["generate", "chat", "embedding", "rerank"]: response.status_code = 404 - response._content = b'{}' + response._content = b"{}" return response # check if url is valid - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url): + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", url): response.status_code = 404 - response._content = b'{}' + response._content = b"{}" return response - - if model_uid in ['generate', 'chat']: + + if model_uid in ["generate", "chat"]: response.status_code = 200 - response._content = b'''{ + response._content = b"""{ "model_type": "LLM", "address": "127.0.0.1:43877", "accelerators": [ @@ -75,12 +78,12 @@ def get(self: Session, url: str, **kwargs): "revision": null, "context_length": 2048, "replica": 1 - }''' + }""" return response - - elif model_uid == 'embedding': + + elif model_uid == "embedding": response.status_code = 200 - response._content = b'''{ + response._content = b"""{ "model_type": "embedding", "address": "127.0.0.1:43877", "accelerators": [ @@ -93,51 +96,48 @@ def get(self: Session, url: str, **kwargs): ], "revision": null, "max_tokens": 512 - }''' + }""" return response - - elif 'v1/cluster/auth' in url: + + elif "v1/cluster/auth" in url: response.status_code = 200 - response._content = b'''{ + response._content = b"""{ "auth": true - }''' + }""" return response - + def _check_cluster_authenticated(self): self._cluster_authed = True - - def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool) -> dict: + + def rerank( + self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool + ) -> dict: # check if self._model_uid is a valid uuid - if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \ - self._model_uid != 'rerank': - raise RuntimeError('404 Not Found') - - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._base_url): - raise RuntimeError('404 Not Found') + if ( + not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid) + and self._model_uid != "rerank" + ): + raise RuntimeError("404 Not Found") + + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._base_url): + raise RuntimeError("404 Not Found") if top_n is None: top_n = 1 return { - 'results': [ - { - 'index': i, - 'document': doc, - 'relevance_score': 0.9 - } - for i, doc in enumerate(documents[:top_n]) + "results": [ + {"index": i, "document": doc, "relevance_score": 0.9} for i, doc in enumerate(documents[:top_n]) ] } - - def create_embedding( - self: RESTfulGenerateModelHandle, - input: Union[str, list[str]], - **kwargs - ) -> dict: + + def create_embedding(self: RESTfulGenerateModelHandle, input: Union[str, list[str]], **kwargs) -> dict: # check if self._model_uid is a valid uuid - if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \ - self._model_uid != 'embedding': - raise RuntimeError('404 Not Found') + if ( + not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid) + and self._model_uid != "embedding" + ): + raise RuntimeError("404 Not Found") if isinstance(input, str): input = [input] @@ -147,32 +147,27 @@ def create_embedding( object="list", model=self._model_uid, data=[ - EmbeddingData( - index=i, - object="embedding", - embedding=[1919.810 for _ in range(768)] - ) + EmbeddingData(index=i, object="embedding", embedding=[1919.810 for _ in range(768)]) for i in range(ipt_len) ], - usage=EmbeddingUsage( - prompt_tokens=ipt_len, - total_tokens=ipt_len - ) + usage=EmbeddingUsage(prompt_tokens=ipt_len, total_tokens=ipt_len), ) return embedding -MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + @pytest.fixture def setup_xinference_mock(request, monkeypatch: MonkeyPatch): if MOCK: - monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model) - monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated) - monkeypatch.setattr(Session, 'get', MockXinferenceClass.get) - monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding) - monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank) + monkeypatch.setattr(Client, "get_model", MockXinferenceClass.get_chat_model) + monkeypatch.setattr(Client, "_check_cluster_authenticated", MockXinferenceClass._check_cluster_authenticated) + monkeypatch.setattr(Session, "get", MockXinferenceClass.get) + monkeypatch.setattr(RESTfulEmbeddingModelHandle, "create_embedding", MockXinferenceClass.create_embedding) + monkeypatch.setattr(RESTfulRerankModelHandle, "rerank", MockXinferenceClass.rerank) yield if MOCK: - monkeypatch.undo() \ No newline at end of file + monkeypatch.undo() diff --git a/api/tests/integration_tests/model_runtime/anthropic/test_llm.py b/api/tests/integration_tests/model_runtime/anthropic/test_llm.py index 0d54d97daad24e..8f7e9ec48743bf 100644 --- a/api/tests/integration_tests/model_runtime/anthropic/test_llm.py +++ b/api/tests/integration_tests/model_runtime/anthropic/test_llm.py @@ -10,79 +10,60 @@ from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock -@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) def test_validate_credentials(setup_anthropic_mock): model = AnthropicLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='claude-instant-1.2', - credentials={ - 'anthropic_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="claude-instant-1.2", credentials={"anthropic_api_key": "invalid_key"}) model.validate_credentials( - model='claude-instant-1.2', - credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') - } + model="claude-instant-1.2", credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")} ) -@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) def test_invoke_model(setup_anthropic_mock): model = AnthropicLargeLanguageModel() response = model.invoke( - model='claude-instant-1.2', + model="claude-instant-1.2", credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY'), - 'anthropic_api_url': os.environ.get('ANTHROPIC_API_URL') + "anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY"), + "anthropic_api_url": os.environ.get("ANTHROPIC_API_URL"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'top_p': 1.0, - 'max_tokens': 10 - }, - stop=['How'], + model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 -@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) def test_invoke_stream_model(setup_anthropic_mock): model = AnthropicLargeLanguageModel() response = model.invoke( - model='claude-instant-1.2', - credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') - }, + model="claude-instant-1.2", + credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -98,18 +79,14 @@ def test_get_num_tokens(): model = AnthropicLargeLanguageModel() num_tokens = model.get_num_tokens( - model='claude-instant-1.2', - credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') - }, + model="claude-instant-1.2", + credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 18 diff --git a/api/tests/integration_tests/model_runtime/anthropic/test_provider.py b/api/tests/integration_tests/model_runtime/anthropic/test_provider.py index 7eaa40dfddc2af..6f1e50f431849f 100644 --- a/api/tests/integration_tests/model_runtime/anthropic/test_provider.py +++ b/api/tests/integration_tests/model_runtime/anthropic/test_provider.py @@ -7,17 +7,11 @@ from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock -@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) def test_validate_provider_credentials(setup_anthropic_mock): provider = AnthropicProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/azure_ai_studio/__init__.py b/api/tests/integration_tests/model_runtime/azure_ai_studio/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py new file mode 100644 index 00000000000000..8655b43d8f0b3d --- /dev/null +++ b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py @@ -0,0 +1,113 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.azure_ai_studio.llm.llm import AzureAIStudioLargeLanguageModel +from tests.integration_tests.model_runtime.__mock.azure_ai_studio import setup_azure_ai_studio_mock + + +@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True) +def test_validate_credentials(setup_azure_ai_studio_mock): + model = AzureAIStudioLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="gpt-35-turbo", + credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")}, + ) + + model.validate_credentials( + model="gpt-35-turbo", + credentials={ + "api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"), + "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"), + }, + ) + + +@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True) +def test_invoke_model(setup_azure_ai_studio_mock): + model = AzureAIStudioLargeLanguageModel() + + result = model.invoke( + model="gpt-35-turbo", + credentials={ + "api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"), + "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=False, + user="abc-123", + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True) +def test_invoke_stream_model(setup_azure_ai_studio_mock): + model = AzureAIStudioLargeLanguageModel() + + result = model.invoke( + model="gpt-35-turbo", + credentials={ + "api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"), + "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=True, + user="abc-123", + ) + + assert isinstance(result, Generator) + + for chunk in result: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + if chunk.delta.finish_reason is not None: + assert chunk.delta.usage is not None + assert chunk.delta.usage.completion_tokens > 0 + + +def test_get_num_tokens(): + model = AzureAIStudioLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="gpt-35-turbo", + credentials={ + "api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"), + "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_provider.py b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_provider.py new file mode 100644 index 00000000000000..8afe38b09b9f02 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_provider.py @@ -0,0 +1,17 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.azure_ai_studio.azure_ai_studio import AzureAIStudioProvider + + +def test_validate_provider_credentials(): + provider = AzureAIStudioProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={}) + + provider.validate_provider_credentials( + credentials={"api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"), "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")} + ) diff --git a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py new file mode 100644 index 00000000000000..466facc5fffcf6 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py @@ -0,0 +1,50 @@ +import os + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureAIStudioRerankModel + + +def test_validate_credentials(): + model = AzureAIStudioRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="azure-ai-studio-rerank-v1", + credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")}, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8, + ) + + +def test_invoke_model(): + model = AzureAIStudioRerankModel() + + result = model.invoke( + model="azure-ai-studio-rerank-v1", + credentials={ + "api_key": os.getenv("AZURE_AI_STUDIO_JWT_TOKEN"), + "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"), + }, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8, + ) + + assert isinstance(result, RerankResult) + assert len(result.docs) == 1 + assert result.docs[0].index == 1 + assert result.docs[0].score >= 0.8 diff --git a/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py b/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py index 6afec540ade181..8f50ebf7a6d03f 100644 --- a/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py @@ -17,101 +17,90 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_credentials_for_chat_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='gpt35', + model="gpt35", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': 'invalid_key', - 'base_model_name': 'gpt-35-turbo' - } + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": "invalid_key", + "base_model_name": "gpt-35-turbo", + }, ) model.validate_credentials( - model='gpt35', + model="gpt35", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo' - } + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo", + }, ) -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_validate_credentials_for_completion_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='gpt-35-turbo-instruct', + model="gpt-35-turbo-instruct", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': 'invalid_key', - 'base_model_name': 'gpt-35-turbo-instruct' - } + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": "invalid_key", + "base_model_name": "gpt-35-turbo-instruct", + }, ) model.validate_credentials( - model='gpt-35-turbo-instruct', + model="gpt-35-turbo-instruct", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo-instruct' - } + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo-instruct", + }, ) -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_invoke_completion_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() result = model.invoke( - model='gpt-35-turbo-instruct', + model="gpt-35-turbo-instruct", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo-instruct' - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 1 + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo-instruct", }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 1}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_invoke_stream_completion_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() result = model.invoke( - model='gpt-35-turbo-instruct', + model="gpt-35-turbo-instruct", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo-instruct' - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo-instruct", }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -122,66 +111,60 @@ def test_invoke_stream_completion_model(setup_openai_mock): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() result = model.invoke( - model='gpt35', + model="gpt35", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo' + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.0, - 'top_p': 1.0, - 'presence_penalty': 0.0, - 'frequency_penalty': 0.0, - 'max_tokens': 10 + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_stream_chat_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() result = model.invoke( - model='gpt35', + model="gpt35", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo' + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -194,109 +177,87 @@ def test_invoke_stream_chat_model(setup_openai_mock): assert chunk.delta.usage is not None assert chunk.delta.usage.completion_tokens > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model_with_vision(setup_openai_mock): model = AzureOpenAILargeLanguageModel() result = model.invoke( - model='gpt-4v', + model="gpt-4v", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-4-vision-preview' + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-4-vision-preview", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content=[ TextPromptMessageContent( - data='Hello World!', + data="Hello World!", ), ImagePromptMessageContent( - data='' - ) + data="" + ), ] - ) + ), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model_with_tools(setup_openai_mock): model = AzureOpenAILargeLanguageModel() result = model.invoke( - model='gpt-35-turbo', + model="gpt-35-turbo", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo' + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content="what's the weather today in London?", - ) + ), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), PromptMessageTool( - name='get_stock_price', - description='Get the current stock price', + name="get_stock_price", + description="Get the current stock price", parameters={ "type": "object", - "properties": { - "symbol": { - "type": "string", - "description": "The stock symbol" - } - }, - "required": [ - "symbol" - ] - } - ) + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), ], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) @@ -308,32 +269,22 @@ def test_get_num_tokens(): model = AzureOpenAILargeLanguageModel() num_tokens = model.get_num_tokens( - model='gpt-35-turbo-instruct', - credentials={ - 'base_model_name': 'gpt-35-turbo-instruct' - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ] + model="gpt-35-turbo-instruct", + credentials={"base_model_name": "gpt-35-turbo-instruct"}, + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert num_tokens == 3 num_tokens = model.get_num_tokens( - model='gpt35', - credentials={ - 'base_model_name': 'gpt-35-turbo' - }, + model="gpt35", + credentials={"base_model_name": "gpt-35-turbo"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py index 8b838eb8fc8183..a1ae2b2e5b740c 100644 --- a/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py @@ -8,45 +8,43 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) def test_validate_credentials(setup_openai_mock): model = AzureOpenAITextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='embedding', + model="embedding", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': 'invalid_key', - 'base_model_name': 'text-embedding-ada-002' - } + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": "invalid_key", + "base_model_name": "text-embedding-ada-002", + }, ) model.validate_credentials( - model='embedding', + model="embedding", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'text-embedding-ada-002' - } + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "text-embedding-ada-002", + }, ) -@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) def test_invoke_model(setup_openai_mock): model = AzureOpenAITextEmbeddingModel() result = model.invoke( - model='embedding', + model="embedding", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'text-embedding-ada-002' + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "text-embedding-ada-002", }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -58,14 +56,7 @@ def test_get_num_tokens(): model = AzureOpenAITextEmbeddingModel() num_tokens = model.get_num_tokens( - model='embedding', - credentials={ - 'base_model_name': 'text-embedding-ada-002' - }, - texts=[ - "hello", - "world" - ] + model="embedding", credentials={"base_model_name": "text-embedding-ada-002"}, texts=["hello", "world"] ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_llm.py b/api/tests/integration_tests/model_runtime/baichuan/test_llm.py index 1cae9a6dd0962e..ad586102879a10 100644 --- a/api/tests/integration_tests/model_runtime/baichuan/test_llm.py +++ b/api/tests/integration_tests/model_runtime/baichuan/test_llm.py @@ -17,111 +17,99 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) + def test_validate_credentials_for_chat_model(): sleep(3) model = BaichuanLarguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='baichuan2-turbo', - credentials={ - 'api_key': 'invalid_key', - 'secret_key': 'invalid_key' - } + model="baichuan2-turbo", credentials={"api_key": "invalid_key", "secret_key": "invalid_key"} ) model.validate_credentials( - model='baichuan2-turbo', + model="baichuan2-turbo", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), - 'secret_key': os.environ.get('BAICHUAN_SECRET_KEY') - } + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), + }, ) + def test_invoke_model(): sleep(3) model = BaichuanLarguageModel() response = model.invoke( - model='baichuan2-turbo', + model="baichuan2-turbo", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), - 'secret_key': os.environ.get('BAICHUAN_SECRET_KEY') + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_model_with_system_message(): sleep(3) model = BaichuanLarguageModel() response = model.invoke( - model='baichuan2-turbo', + model="baichuan2-turbo", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), - 'secret_key': os.environ.get('BAICHUAN_SECRET_KEY') + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), }, prompt_messages=[ - SystemPromptMessage( - content='请记住你是Kasumi。' - ), - UserPromptMessage( - content='现在告诉我你是谁?' - ) + SystemPromptMessage(content="请记住你是Kasumi。"), + UserPromptMessage(content="现在告诉我你是谁?"), ], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_stream_model(): sleep(3) model = BaichuanLarguageModel() response = model.invoke( - model='baichuan2-turbo', + model="baichuan2-turbo", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), - 'secret_key': os.environ.get('BAICHUAN_SECRET_KEY') + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -131,34 +119,31 @@ def test_invoke_stream_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_invoke_with_search(): sleep(3) model = BaichuanLarguageModel() response = model.invoke( - model='baichuan2-turbo', + model="baichuan2-turbo", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), - 'secret_key': os.environ.get('BAICHUAN_SECRET_KEY') + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='北京今天的天气怎么样' - ) - ], + prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, - 'with_search_enhance': True, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, + "with_search_enhance": True, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) - total_message = '' + total_message = "" for chunk in response: assert isinstance(chunk, LLMResultChunk) assert isinstance(chunk.delta, LLMResultChunkDelta) @@ -166,25 +151,22 @@ def test_invoke_with_search(): assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True total_message += chunk.delta.message.content - assert '不' not in total_message + assert "不" not in total_message + def test_get_num_tokens(): sleep(3) model = BaichuanLarguageModel() response = model.get_num_tokens( - model='baichuan2-turbo', + model="baichuan2-turbo", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), - 'secret_key': os.environ.get('BAICHUAN_SECRET_KEY') + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - tools=[] + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], ) assert isinstance(response, int) - assert response == 9 \ No newline at end of file + assert response == 9 diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_provider.py b/api/tests/integration_tests/model_runtime/baichuan/test_provider.py index 87b3d9a6099839..4036edfb7a7062 100644 --- a/api/tests/integration_tests/model_runtime/baichuan/test_provider.py +++ b/api/tests/integration_tests/model_runtime/baichuan/test_provider.py @@ -10,14 +10,6 @@ def test_validate_provider_credentials(): provider = BaichuanProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={ - 'api_key': 'hahahaha' - } - ) + provider.validate_provider_credentials(credentials={"api_key": "hahahaha"}) - provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("BAICHUAN_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py b/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py index 1210ebc53d96f8..cbc63f3978fb99 100644 --- a/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py @@ -11,18 +11,10 @@ def test_validate_credentials(): model = BaichuanTextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='baichuan-text-embedding', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="baichuan-text-embedding", credentials={"api_key": "invalid_key"}) model.validate_credentials( - model='baichuan-text-embedding', - credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY') - } + model="baichuan-text-embedding", credentials={"api_key": os.environ.get("BAICHUAN_API_KEY")} ) @@ -30,44 +22,40 @@ def test_invoke_model(): model = BaichuanTextEmbeddingModel() result = model.invoke( - model='baichuan-text-embedding', + model="baichuan-text-embedding", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), + "api_key": os.environ.get("BAICHUAN_API_KEY"), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 assert result.usage.total_tokens == 6 + def test_get_num_tokens(): model = BaichuanTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='baichuan-text-embedding', + model="baichuan-text-embedding", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), + "api_key": os.environ.get("BAICHUAN_API_KEY"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 + def test_max_chunks(): model = BaichuanTextEmbeddingModel() result = model.invoke( - model='baichuan-text-embedding', + model="baichuan-text-embedding", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), + "api_key": os.environ.get("BAICHUAN_API_KEY"), }, texts=[ "hello", @@ -92,8 +80,8 @@ def test_max_chunks(): "world", "hello", "world", - ] + ], ) assert isinstance(result, TextEmbeddingResult) - assert len(result.embeddings) == 22 \ No newline at end of file + assert len(result.embeddings) == 22 diff --git a/api/tests/integration_tests/model_runtime/bedrock/test_llm.py b/api/tests/integration_tests/model_runtime/bedrock/test_llm.py index 20dc11151a70a7..c19ec35a6e45fc 100644 --- a/api/tests/integration_tests/model_runtime/bedrock/test_llm.py +++ b/api/tests/integration_tests/model_runtime/bedrock/test_llm.py @@ -13,77 +13,63 @@ def test_validate_credentials(): model = BedrockLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='meta.llama2-13b-chat-v1', - credentials={ - 'anthropic_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="meta.llama2-13b-chat-v1", credentials={"anthropic_api_key": "invalid_key"}) model.validate_credentials( - model='meta.llama2-13b-chat-v1', + model="meta.llama2-13b-chat-v1", credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") - } + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), + }, ) + def test_invoke_model(): model = BedrockLargeLanguageModel() response = model.invoke( - model='meta.llama2-13b-chat-v1', + model="meta.llama2-13b-chat-v1", credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'top_p': 1.0, - 'max_tokens_to_sample': 10 - }, - stop=['How'], + model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens_to_sample": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 + def test_invoke_stream_model(): model = BedrockLargeLanguageModel() response = model.invoke( - model='meta.llama2-13b-chat-v1', + model="meta.llama2-13b-chat-v1", credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens_to_sample': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens_to_sample": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -100,20 +86,18 @@ def test_get_num_tokens(): model = BedrockLargeLanguageModel() num_tokens = model.get_num_tokens( - model='meta.llama2-13b-chat-v1', - credentials = { + model="meta.llama2-13b-chat-v1", + credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), }, messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 18 diff --git a/api/tests/integration_tests/model_runtime/bedrock/test_provider.py b/api/tests/integration_tests/model_runtime/bedrock/test_provider.py index e53d4c1db2133b..080727829e9e2f 100644 --- a/api/tests/integration_tests/model_runtime/bedrock/test_provider.py +++ b/api/tests/integration_tests/model_runtime/bedrock/test_provider.py @@ -10,14 +10,12 @@ def test_validate_provider_credentials(): provider = BedrockProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) provider.validate_provider_credentials( credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), } ) diff --git a/api/tests/integration_tests/model_runtime/chatglm/test_llm.py b/api/tests/integration_tests/model_runtime/chatglm/test_llm.py index e32f01a315b7d5..418e88874d1572 100644 --- a/api/tests/integration_tests/model_runtime/chatglm/test_llm.py +++ b/api/tests/integration_tests/model_runtime/chatglm/test_llm.py @@ -23,79 +23,64 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_credentials_for_chat_model(setup_openai_mock): model = ChatGLMLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='chatglm2-6b', - credentials={ - 'api_base': 'invalid_key' - } - ) - - model.validate_credentials( - model='chatglm2-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - } - ) + model.validate_credentials(model="chatglm2-6b", credentials={"api_base": "invalid_key"}) + + model.validate_credentials(model="chatglm2-6b", credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}) + -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_model(setup_openai_mock): model = ChatGLMLargeLanguageModel() response = model.invoke( - model='chatglm2-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - }, + model="chatglm2-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_stream_model(setup_openai_mock): model = ChatGLMLargeLanguageModel() response = model.invoke( - model='chatglm2-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - }, + model="chatglm2-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -105,56 +90,45 @@ def test_invoke_stream_model(setup_openai_mock): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_stream_model_with_functions(setup_openai_mock): model = ChatGLMLargeLanguageModel() response = model.invoke( - model='chatglm3-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - }, + model="chatglm3-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, prompt_messages=[ SystemPromptMessage( - content='你是一个天气机器人,你不知道今天的天气怎么样,你需要通过调用一个函数来获取天气信息。' + content="你是一个天气机器人,你不知道今天的天气怎么样,你需要通过调用一个函数来获取天气信息。" ), - UserPromptMessage( - content='波士顿天气如何?' - ) + UserPromptMessage(content="波士顿天气如何?"), ], model_parameters={ - 'temperature': 0, - 'top_p': 1.0, + "temperature": 0, + "top_p": 1.0, }, - stop=['you'], - user='abc-123', + stop=["you"], + user="abc-123", stream=True, tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) - ] + ], ) assert isinstance(response, Generator) - + call: LLMResultChunk = None chunks = [] @@ -170,122 +144,87 @@ def test_invoke_stream_model_with_functions(setup_openai_mock): break assert call is not None - assert call.delta.message.tool_calls[0].function.name == 'get_current_weather' + assert call.delta.message.tool_calls[0].function.name == "get_current_weather" -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_model_with_functions(setup_openai_mock): model = ChatGLMLargeLanguageModel() response = model.invoke( - model='chatglm3-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - }, - prompt_messages=[ - UserPromptMessage( - content='What is the weather like in San Francisco?' - ) - ], + model="chatglm3-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, + prompt_messages=[UserPromptMessage(content="What is the weather like in San Francisco?")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], - user='abc-123', + stop=["you"], + user="abc-123", stream=False, tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) - ] + ], ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 - assert response.message.tool_calls[0].function.name == 'get_current_weather' + assert response.message.tool_calls[0].function.name == "get_current_weather" def test_get_num_tokens(): model = ChatGLMLargeLanguageModel() num_tokens = model.get_num_tokens( - model='chatglm2-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - }, + model="chatglm2-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) - ] + ], ) assert isinstance(num_tokens, int) assert num_tokens == 77 num_tokens = model.get_num_tokens( - model='chatglm2-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - }, + model="chatglm2-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], ) assert isinstance(num_tokens, int) - assert num_tokens == 21 \ No newline at end of file + assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/chatglm/test_provider.py b/api/tests/integration_tests/model_runtime/chatglm/test_provider.py index e9c5c4da751b75..7907805d072772 100644 --- a/api/tests/integration_tests/model_runtime/chatglm/test_provider.py +++ b/api/tests/integration_tests/model_runtime/chatglm/test_provider.py @@ -7,19 +7,11 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_provider_credentials(setup_openai_mock): provider = ChatGLMProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={ - 'api_base': 'hahahaha' - } - ) + provider.validate_provider_credentials(credentials={"api_base": "hahahaha"}) - provider.validate_provider_credentials( - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - } - ) + provider.validate_provider_credentials(credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}) diff --git a/api/tests/integration_tests/model_runtime/cohere/test_llm.py b/api/tests/integration_tests/model_runtime/cohere/test_llm.py index 5ce4f8ecfe874f..b7f707e935dbea 100644 --- a/api/tests/integration_tests/model_runtime/cohere/test_llm.py +++ b/api/tests/integration_tests/model_runtime/cohere/test_llm.py @@ -13,87 +13,49 @@ def test_validate_credentials_for_chat_model(): model = CohereLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='command-light-chat', - credentials={ - 'api_key': 'invalid_key' - } - ) - - model.validate_credentials( - model='command-light-chat', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - } - ) + model.validate_credentials(model="command-light-chat", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="command-light-chat", credentials={"api_key": os.environ.get("COHERE_API_KEY")}) def test_validate_credentials_for_completion_model(): model = CohereLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='command-light', - credentials={ - 'api_key': 'invalid_key' - } - ) - - model.validate_credentials( - model='command-light', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - } - ) + model.validate_credentials(model="command-light", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="command-light", credentials={"api_key": os.environ.get("COHERE_API_KEY")}) def test_invoke_completion_model(): model = CohereLargeLanguageModel() - credentials = { - 'api_key': os.environ.get('COHERE_API_KEY') - } + credentials = {"api_key": os.environ.get("COHERE_API_KEY")} result = model.invoke( - model='command-light', + model="command-light", credentials=credentials, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 1 - }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 1}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 - assert model._num_tokens_from_string('command-light', credentials, result.message.content) == 1 + assert model._num_tokens_from_string("command-light", credentials, result.message.content) == 1 def test_invoke_stream_completion_model(): model = CohereLargeLanguageModel() result = model.invoke( - model='command-light', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model="command-light", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -109,28 +71,24 @@ def test_invoke_chat_model(): model = CohereLargeLanguageModel() result = model.invoke( - model='command-light-chat', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, + model="command-light-chat", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.0, - 'p': 0.99, - 'presence_penalty': 0.0, - 'frequency_penalty': 0.0, - 'max_tokens': 10 + "temperature": 0.0, + "p": 0.99, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) @@ -141,24 +99,17 @@ def test_invoke_stream_chat_model(): model = CohereLargeLanguageModel() result = model.invoke( - model='command-light-chat', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, + model="command-light-chat", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -177,32 +128,22 @@ def test_get_num_tokens(): model = CohereLargeLanguageModel() num_tokens = model.get_num_tokens( - model='command-light', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ] + model="command-light", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert num_tokens == 3 num_tokens = model.get_num_tokens( - model='command-light-chat', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, + model="command-light-chat", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 15 @@ -213,25 +154,17 @@ def test_fine_tuned_model(): # test invoke result = model.invoke( - model='85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY'), - 'mode': 'completion' - }, + model="85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft", + credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "completion"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) @@ -242,25 +175,17 @@ def test_fine_tuned_chat_model(): # test invoke result = model.invoke( - model='94f2d55a-4c79-4c00-bde4-23962e74b170-ft', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY'), - 'mode': 'chat' - }, + model="94f2d55a-4c79-4c00-bde4-23962e74b170-ft", + credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "chat"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) diff --git a/api/tests/integration_tests/model_runtime/cohere/test_provider.py b/api/tests/integration_tests/model_runtime/cohere/test_provider.py index a8f56b61943c8b..fb7e6d34984a61 100644 --- a/api/tests/integration_tests/model_runtime/cohere/test_provider.py +++ b/api/tests/integration_tests/model_runtime/cohere/test_provider.py @@ -10,12 +10,6 @@ def test_validate_provider_credentials(): provider = CohereProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("COHERE_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/cohere/test_rerank.py b/api/tests/integration_tests/model_runtime/cohere/test_rerank.py index 415c5fbfda56d0..a1b6922128570e 100644 --- a/api/tests/integration_tests/model_runtime/cohere/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/cohere/test_rerank.py @@ -11,29 +11,17 @@ def test_validate_credentials(): model = CohereRerankModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='rerank-english-v2.0', - credentials={ - 'api_key': 'invalid_key' - } - ) - - model.validate_credentials( - model='rerank-english-v2.0', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - } - ) + model.validate_credentials(model="rerank-english-v2.0", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="rerank-english-v2.0", credentials={"api_key": os.environ.get("COHERE_API_KEY")}) def test_invoke_model(): model = CohereRerankModel() result = model.invoke( - model='rerank-english-v2.0', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, + model="rerank-english-v2.0", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, query="What is the capital of the United States?", docs=[ "Carson City is the capital city of the American state of Nevada. At the 2010 United States " @@ -41,9 +29,9 @@ def test_invoke_model(): "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) " "is the capital of the United States. It is a federal district. The President of the USA and many major " "national government offices are in the territory. This makes it the political center of the United " - "States of America." + "States of America.", ], - score_threshold=0.8 + score_threshold=0.8, ) assert isinstance(result, RerankResult) diff --git a/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py b/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py index 5017ba47e11033..ae26d36635d1b5 100644 --- a/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py @@ -11,18 +11,10 @@ def test_validate_credentials(): model = CohereTextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='embed-multilingual-v3.0', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="embed-multilingual-v3.0", credentials={"api_key": "invalid_key"}) model.validate_credentials( - model='embed-multilingual-v3.0', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - } + model="embed-multilingual-v3.0", credentials={"api_key": os.environ.get("COHERE_API_KEY")} ) @@ -30,17 +22,10 @@ def test_invoke_model(): model = CohereTextEmbeddingModel() result = model.invoke( - model='embed-multilingual-v3.0', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, - texts=[ - "hello", - "world", - " ".join(["long_text"] * 100), - " ".join(["another_long_text"] * 100) - ], - user="abc-123" + model="embed-multilingual-v3.0", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, + texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -52,14 +37,9 @@ def test_get_num_tokens(): model = CohereTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='embed-multilingual-v3.0', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, - texts=[ - "hello", - "world" - ] + model="embed-multilingual-v3.0", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, + texts=["hello", "world"], ) assert num_tokens == 3 diff --git a/api/tests/integration_tests/model_runtime/google/test_llm.py b/api/tests/integration_tests/model_runtime/google/test_llm.py index 00d907d19ef7ef..4d9d490a8720c2 100644 --- a/api/tests/integration_tests/model_runtime/google/test_llm.py +++ b/api/tests/integration_tests/model_runtime/google/test_llm.py @@ -16,103 +16,73 @@ from tests.integration_tests.model_runtime.__mock.google import setup_google_mock -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) def test_validate_credentials(setup_google_mock): model = GoogleLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='gemini-pro', - credentials={ - 'google_api_key': 'invalid_key' - } - ) - - model.validate_credentials( - model='gemini-pro', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - } - ) + model.validate_credentials(model="gemini-pro", credentials={"google_api_key": "invalid_key"}) + + model.validate_credentials(model="gemini-pro", credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}) + -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) def test_invoke_model(setup_google_mock): model = GoogleLargeLanguageModel() response = model.invoke( - model='gemini-pro', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, + model="gemini-pro", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', - ), - UserPromptMessage( - content='Give me your worst dad joke or i will unplug you' + content="You are a helpful AI assistant.", ), + UserPromptMessage(content="Give me your worst dad joke or i will unplug you"), AssistantPromptMessage( - content='Why did the scarecrow win an award? Because he was outstanding in his field!' + content="Why did the scarecrow win an award? Because he was outstanding in his field!" ), UserPromptMessage( content=[ - TextPromptMessageContent( - data="ok something snarkier pls" - ), - TextPromptMessageContent( - data="i may still unplug you" - )] - ) + TextPromptMessageContent(data="ok something snarkier pls"), + TextPromptMessageContent(data="i may still unplug you"), + ] + ), ], - model_parameters={ - 'temperature': 0.5, - 'top_p': 1.0, - 'max_tokens_to_sample': 2048 - }, - stop=['How'], + model_parameters={"temperature": 0.5, "top_p": 1.0, "max_tokens_to_sample": 2048}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) def test_invoke_stream_model(setup_google_mock): model = GoogleLargeLanguageModel() response = model.invoke( - model='gemini-pro', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, + model="gemini-pro", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', - ), - UserPromptMessage( - content='Give me your worst dad joke or i will unplug you' + content="You are a helpful AI assistant.", ), + UserPromptMessage(content="Give me your worst dad joke or i will unplug you"), AssistantPromptMessage( - content='Why did the scarecrow win an award? Because he was outstanding in his field!' + content="Why did the scarecrow win an award? Because he was outstanding in his field!" ), UserPromptMessage( content=[ - TextPromptMessageContent( - data="ok something snarkier pls" - ), - TextPromptMessageContent( - data="i may still unplug you" - )] - ) + TextPromptMessageContent(data="ok something snarkier pls"), + TextPromptMessageContent(data="i may still unplug you"), + ] + ), ], - model_parameters={ - 'temperature': 0.2, - 'top_k': 5, - 'max_tokens_to_sample': 2048 - }, + model_parameters={"temperature": 0.2, "top_k": 5, "max_tokens_to_sample": 2048}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -123,88 +93,66 @@ def test_invoke_stream_model(setup_google_mock): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) def test_invoke_chat_model_with_vision(setup_google_mock): model = GoogleLargeLanguageModel() result = model.invoke( - model='gemini-pro-vision', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, + model="gemini-pro-vision", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( - content=[ - TextPromptMessageContent( - data="what do you see?" - ), + content=[ + TextPromptMessageContent(data="what do you see?"), ImagePromptMessageContent( - data='' - ) + data="" + ), ] - ) + ), ], - model_parameters={ - 'temperature': 0.3, - 'top_p': 0.2, - 'top_k': 3, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.3, "top_p": 0.2, "top_k": 3, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock): model = GoogleLargeLanguageModel() result = model.invoke( - model='gemini-pro-vision', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, + model="gemini-pro-vision", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, prompt_messages=[ - SystemPromptMessage( - content='You are a helpful AI assistant.' - ), + SystemPromptMessage(content="You are a helpful AI assistant."), UserPromptMessage( - content=[ - TextPromptMessageContent( - data="what do you see?" - ), + content=[ + TextPromptMessageContent(data="what do you see?"), ImagePromptMessageContent( - data='' - ) + data="" + ), ] ), - AssistantPromptMessage( - content="I see a blue letter 'D' with a gradient from light blue to dark blue." - ), + AssistantPromptMessage(content="I see a blue letter 'D' with a gradient from light blue to dark blue."), UserPromptMessage( content=[ - TextPromptMessageContent( - data="what about now?" - ), + TextPromptMessageContent(data="what about now?"), ImagePromptMessageContent( - data='' - ) + data="" + ), ] - ) + ), ], - model_parameters={ - 'temperature': 0.3, - 'top_p': 0.2, - 'top_k': 3, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.3, "top_p": 0.2, "top_k": 3, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) print(f"resultz: {result.message.content}") @@ -212,23 +160,18 @@ def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock): assert len(result.message.content) > 0 - def test_get_num_tokens(): model = GoogleLargeLanguageModel() num_tokens = model.get_num_tokens( - model='gemini-pro', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, + model="gemini-pro", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens > 0 # The exact number of tokens may vary based on the model's tokenization diff --git a/api/tests/integration_tests/model_runtime/google/test_provider.py b/api/tests/integration_tests/model_runtime/google/test_provider.py index 103107ed5ae6c5..c217e4fe058870 100644 --- a/api/tests/integration_tests/model_runtime/google/test_provider.py +++ b/api/tests/integration_tests/model_runtime/google/test_provider.py @@ -7,17 +7,11 @@ from tests.integration_tests.model_runtime.__mock.google import setup_google_mock -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) def test_validate_provider_credentials(setup_google_mock): provider = GoogleProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py b/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py index 28cd0955b33109..6a6cc874fa2f30 100644 --- a/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py +++ b/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py @@ -10,87 +10,75 @@ from tests.integration_tests.model_runtime.__mock.huggingface import setup_huggingface_mock -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_hosted_inference_api_validate_credentials(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='HuggingFaceH4/zephyr-7b-beta', - credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': 'invalid_key' - } + model="HuggingFaceH4/zephyr-7b-beta", + credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"}, ) with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='fake-model', - credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': 'invalid_key' - } + model="fake-model", + credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"}, ) model.validate_credentials( - model='HuggingFaceH4/zephyr-7b-beta', + model="HuggingFaceH4/zephyr-7b-beta", credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY') - } + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + }, ) -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_hosted_inference_api_invoke_model(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() response = model.invoke( - model='HuggingFaceH4/zephyr-7b-beta', + model="HuggingFaceH4/zephyr-7b-beta", credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY') + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], + prompt_messages=[UserPromptMessage(content="Who are you?")], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() response = model.invoke( - model='HuggingFaceH4/zephyr-7b-beta', + model="HuggingFaceH4/zephyr-7b-beta", credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY') + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], + prompt_messages=[UserPromptMessage(content="Who are you?")], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -101,86 +89,81 @@ def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_inference_endpoints_text_generation_validate_credentials(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='openchat/openchat_3.5', + model="openchat/openchat_3.5", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': 'invalid_key', - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text-generation' - } + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": "invalid_key", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"), + "task_type": "text-generation", + }, ) model.validate_credentials( - model='openchat/openchat_3.5', + model="openchat/openchat_3.5", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text-generation' - } + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"), + "task_type": "text-generation", + }, ) -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_inference_endpoints_text_generation_invoke_model(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() response = model.invoke( - model='openchat/openchat_3.5', + model="openchat/openchat_3.5", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text-generation' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"), + "task_type": "text-generation", }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], + prompt_messages=[UserPromptMessage(content="Who are you?")], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() response = model.invoke( - model='openchat/openchat_3.5', + model="openchat/openchat_3.5", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text-generation' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"), + "task_type": "text-generation", }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], + prompt_messages=[UserPromptMessage(content="Who are you?")], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -191,86 +174,81 @@ def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingfa assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_inference_endpoints_text2text_generation_validate_credentials(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='google/mt5-base', + model="google/mt5-base", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': 'invalid_key', - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text2text-generation' - } + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": "invalid_key", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", + }, ) model.validate_credentials( - model='google/mt5-base', + model="google/mt5-base", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text2text-generation' - } + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", + }, ) -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_inference_endpoints_text2text_generation_invoke_model(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() response = model.invoke( - model='google/mt5-base', + model="google/mt5-base", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text2text-generation' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], + prompt_messages=[UserPromptMessage(content="Who are you?")], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_inference_endpoints_text2text_generation_invoke_stream_model(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() response = model.invoke( - model='google/mt5-base', + model="google/mt5-base", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text2text-generation' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], + prompt_messages=[UserPromptMessage(content="Who are you?")], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -286,18 +264,14 @@ def test_get_num_tokens(): model = HuggingfaceHubLargeLanguageModel() num_tokens = model.get_num_tokens( - model='google/mt5-base', + model="google/mt5-base", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text2text-generation' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ] + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert num_tokens == 7 diff --git a/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py b/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py index d03b3186cb4657..0ee593f38a494a 100644 --- a/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py @@ -14,19 +14,19 @@ def test_hosted_inference_api_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='facebook/bart-base', + model="facebook/bart-base", credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': 'invalid_key', - } + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": "invalid_key", + }, ) model.validate_credentials( - model='facebook/bart-base', + model="facebook/bart-base", credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - } + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + }, ) @@ -34,15 +34,12 @@ def test_hosted_inference_api_invoke_model(): model = HuggingfaceHubTextEmbeddingModel() result = model.invoke( - model='facebook/bart-base', + model="facebook/bart-base", credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert isinstance(result, TextEmbeddingResult) @@ -55,25 +52,25 @@ def test_inference_endpoints_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='all-MiniLM-L6-v2', + model="all-MiniLM-L6-v2", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': 'invalid_key', - 'huggingface_namespace': 'Dify-AI', - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'), - 'task_type': 'feature-extraction' - } + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": "invalid_key", + "huggingface_namespace": "Dify-AI", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"), + "task_type": "feature-extraction", + }, ) model.validate_credentials( - model='all-MiniLM-L6-v2', + model="all-MiniLM-L6-v2", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingface_namespace': 'Dify-AI', - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'), - 'task_type': 'feature-extraction' - } + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingface_namespace": "Dify-AI", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"), + "task_type": "feature-extraction", + }, ) @@ -81,18 +78,15 @@ def test_inference_endpoints_invoke_model(): model = HuggingfaceHubTextEmbeddingModel() result = model.invoke( - model='all-MiniLM-L6-v2', + model="all-MiniLM-L6-v2", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingface_namespace': 'Dify-AI', - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'), - 'task_type': 'feature-extraction' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingface_namespace": "Dify-AI", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"), + "task_type": "feature-extraction", }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert isinstance(result, TextEmbeddingResult) @@ -104,18 +98,15 @@ def test_get_num_tokens(): model = HuggingfaceHubTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='all-MiniLM-L6-v2', + model="all-MiniLM-L6-v2", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingface_namespace': 'Dify-AI', - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'), - 'task_type': 'feature-extraction' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingface_namespace": "Dify-AI", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"), + "task_type": "feature-extraction", }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/huggingface_tei/__init__.py b/api/tests/integration_tests/model_runtime/huggingface_tei/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py b/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py new file mode 100644 index 00000000000000..b1fa9d5ca5097f --- /dev/null +++ b/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py @@ -0,0 +1,70 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import ( + HuggingfaceTeiTextEmbeddingModel, + TeiHelper, +) +from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch): + if MOCK: + monkeypatch.setattr(TeiHelper, "get_tei_extra_parameter", MockTEIClass.get_tei_extra_parameter) + monkeypatch.setattr(TeiHelper, "invoke_tokenize", MockTEIClass.invoke_tokenize) + monkeypatch.setattr(TeiHelper, "invoke_embeddings", MockTEIClass.invoke_embeddings) + monkeypatch.setattr(TeiHelper, "invoke_rerank", MockTEIClass.invoke_rerank) + yield + + if MOCK: + monkeypatch.undo() + + +@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True) +def test_validate_credentials(setup_tei_mock): + model = HuggingfaceTeiTextEmbeddingModel() + # model name is only used in mock + model_name = "embedding" + + if MOCK: + # TEI Provider will check model type by API endpoint, at real server, the model type is correct. + # So we dont need to check model type here. Only check in mock + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="reranker", + credentials={ + "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), + }, + ) + + model.validate_credentials( + model=model_name, + credentials={ + "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), + }, + ) + + +@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True) +def test_invoke_model(setup_tei_mock): + model = HuggingfaceTeiTextEmbeddingModel() + model_name = "embedding" + + result = model.invoke( + model=model_name, + credentials={ + "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), + }, + texts=["hello", "world"], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens > 0 diff --git a/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py b/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py new file mode 100644 index 00000000000000..45370d9fba41b0 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py @@ -0,0 +1,78 @@ +import os + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.huggingface_tei.rerank.rerank import ( + HuggingfaceTeiRerankModel, +) +from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper +from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch): + if MOCK: + monkeypatch.setattr(TeiHelper, "get_tei_extra_parameter", MockTEIClass.get_tei_extra_parameter) + monkeypatch.setattr(TeiHelper, "invoke_tokenize", MockTEIClass.invoke_tokenize) + monkeypatch.setattr(TeiHelper, "invoke_embeddings", MockTEIClass.invoke_embeddings) + monkeypatch.setattr(TeiHelper, "invoke_rerank", MockTEIClass.invoke_rerank) + yield + + if MOCK: + monkeypatch.undo() + + +@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True) +def test_validate_credentials(setup_tei_mock): + model = HuggingfaceTeiRerankModel() + # model name is only used in mock + model_name = "reranker" + + if MOCK: + # TEI Provider will check model type by API endpoint, at real server, the model type is correct. + # So we dont need to check model type here. Only check in mock + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="embedding", + credentials={ + "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), + }, + ) + + model.validate_credentials( + model=model_name, + credentials={ + "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), + }, + ) + + +@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True) +def test_invoke_model(setup_tei_mock): + model = HuggingfaceTeiRerankModel() + # model name is only used in mock + model_name = "reranker" + + result = model.invoke( + model=model_name, + credentials={ + "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), + }, + query="Who is Kasumi?", + docs=[ + 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', + "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", + "and she leads a team named PopiParty.", + ], + score_threshold=0.8, + ) + + assert isinstance(result, RerankResult) + assert len(result.docs) == 1 + assert result.docs[0].index == 0 + assert result.docs[0].score >= 0.8 diff --git a/api/tests/integration_tests/model_runtime/hunyuan/test_llm.py b/api/tests/integration_tests/model_runtime/hunyuan/test_llm.py index 305f967ef0a785..b3049a06d9b98a 100644 --- a/api/tests/integration_tests/model_runtime/hunyuan/test_llm.py +++ b/api/tests/integration_tests/model_runtime/hunyuan/test_llm.py @@ -14,19 +14,15 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='hunyuan-standard', - credentials={ - 'secret_id': 'invalid_key', - 'secret_key': 'invalid_key' - } + model="hunyuan-standard", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"} ) model.validate_credentials( - model='hunyuan-standard', + model="hunyuan-standard", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') - } + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), + }, ) @@ -34,23 +30,16 @@ def test_invoke_model(): model = HunyuanLargeLanguageModel() response = model.invoke( - model='hunyuan-standard', + model="hunyuan-standard", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hi' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 10 + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), }, - stop=['How'], + prompt_messages=[UserPromptMessage(content="Hi")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -61,23 +50,15 @@ def test_invoke_stream_model(): model = HunyuanLargeLanguageModel() response = model.invoke( - model='hunyuan-standard', + model="hunyuan-standard", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hi' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 100, - 'seed': 1234 + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), }, + prompt_messages=[UserPromptMessage(content="Hi")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -93,19 +74,17 @@ def test_get_num_tokens(): model = HunyuanLargeLanguageModel() num_tokens = model.get_num_tokens( - model='hunyuan-standard', + model="hunyuan-standard", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 14 diff --git a/api/tests/integration_tests/model_runtime/hunyuan/test_provider.py b/api/tests/integration_tests/model_runtime/hunyuan/test_provider.py index bdec3d0e22d6d5..e3748c2ce713d4 100644 --- a/api/tests/integration_tests/model_runtime/hunyuan/test_provider.py +++ b/api/tests/integration_tests/model_runtime/hunyuan/test_provider.py @@ -10,16 +10,11 @@ def test_validate_provider_credentials(): provider = HunyuanProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={ - 'secret_id': 'invalid_key', - 'secret_key': 'invalid_key' - } - ) + provider.validate_provider_credentials(credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"}) provider.validate_provider_credentials( credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), } ) diff --git a/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py b/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py index 7ae6c0e45635db..69d14dffeebf35 100644 --- a/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py @@ -12,19 +12,15 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='hunyuan-embedding', - credentials={ - 'secret_id': 'invalid_key', - 'secret_key': 'invalid_key' - } + model="hunyuan-embedding", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"} ) model.validate_credentials( - model='hunyuan-embedding', + model="hunyuan-embedding", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') - } + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), + }, ) @@ -32,47 +28,43 @@ def test_invoke_model(): model = HunyuanTextEmbeddingModel() result = model.invoke( - model='hunyuan-embedding', + model="hunyuan-embedding", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 assert result.usage.total_tokens == 6 + def test_get_num_tokens(): model = HunyuanTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='hunyuan-embedding', + model="hunyuan-embedding", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 + def test_max_chunks(): model = HunyuanTextEmbeddingModel() result = model.invoke( - model='hunyuan-embedding', + model="hunyuan-embedding", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), }, texts=[ "hello", @@ -97,8 +89,8 @@ def test_max_chunks(): "world", "hello", "world", - ] + ], ) assert isinstance(result, TextEmbeddingResult) - assert len(result.embeddings) == 22 \ No newline at end of file + assert len(result.embeddings) == 22 diff --git a/api/tests/integration_tests/model_runtime/jina/test_provider.py b/api/tests/integration_tests/model_runtime/jina/test_provider.py index 2b43248388e845..e3b6128c59d8df 100644 --- a/api/tests/integration_tests/model_runtime/jina/test_provider.py +++ b/api/tests/integration_tests/model_runtime/jina/test_provider.py @@ -10,14 +10,6 @@ def test_validate_provider_credentials(): provider = JinaProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={ - 'api_key': 'hahahaha' - } - ) + provider.validate_provider_credentials(credentials={"api_key": "hahahaha"}) - provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('JINA_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("JINA_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py b/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py index ac175661746a17..290735ec49e625 100644 --- a/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py @@ -11,18 +11,10 @@ def test_validate_credentials(): model = JinaTextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='jina-embeddings-v2-base-en', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="jina-embeddings-v2-base-en", credentials={"api_key": "invalid_key"}) model.validate_credentials( - model='jina-embeddings-v2-base-en', - credentials={ - 'api_key': os.environ.get('JINA_API_KEY') - } + model="jina-embeddings-v2-base-en", credentials={"api_key": os.environ.get("JINA_API_KEY")} ) @@ -30,15 +22,12 @@ def test_invoke_model(): model = JinaTextEmbeddingModel() result = model.invoke( - model='jina-embeddings-v2-base-en', + model="jina-embeddings-v2-base-en", credentials={ - 'api_key': os.environ.get('JINA_API_KEY'), + "api_key": os.environ.get("JINA_API_KEY"), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -50,14 +39,11 @@ def test_get_num_tokens(): model = JinaTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='jina-embeddings-v2-base-en', + model="jina-embeddings-v2-base-en", credentials={ - 'api_key': os.environ.get('JINA_API_KEY'), + "api_key": os.environ.get("JINA_API_KEY"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 6 diff --git a/api/tests/integration_tests/model_runtime/localai/test_embedding.py b/api/tests/integration_tests/model_runtime/localai/test_embedding.py index e05345ee56e67d..7fd9f2b3000a31 100644 --- a/api/tests/integration_tests/model_runtime/localai/test_embedding.py +++ b/api/tests/integration_tests/model_runtime/localai/test_embedding.py @@ -1,4 +1,4 @@ """ - LocalAI Embedding Interface is temporarily unavailable due to - we could not find a way to test it for now. -""" \ No newline at end of file +LocalAI Embedding Interface is temporarily unavailable due to +we could not find a way to test it for now. +""" diff --git a/api/tests/integration_tests/model_runtime/localai/test_llm.py b/api/tests/integration_tests/model_runtime/localai/test_llm.py index 6f421403d4d517..aa5436c34fc7dd 100644 --- a/api/tests/integration_tests/model_runtime/localai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/localai/test_llm.py @@ -21,99 +21,78 @@ def test_validate_credentials_for_chat_model(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='chinese-llama-2-7b', + model="chinese-llama-2-7b", credentials={ - 'server_url': 'hahahaha', - 'completion_type': 'completion', - } + "server_url": "hahahaha", + "completion_type": "completion", + }, ) model.validate_credentials( - model='chinese-llama-2-7b', + model="chinese-llama-2-7b", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'completion', - } + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "completion", + }, ) + def test_invoke_completion_model(): model = LocalAILanguageModel() response = model.invoke( - model='chinese-llama-2-7b', + model="chinese-llama-2-7b", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'completion', - }, - prompt_messages=[ - UserPromptMessage( - content='ping' - ) - ], - model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'max_tokens': 10 + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "completion", }, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, stop=[], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_chat_model(): model = LocalAILanguageModel() response = model.invoke( - model='chinese-llama-2-7b', + model="chinese-llama-2-7b", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'chat_completion', - }, - prompt_messages=[ - UserPromptMessage( - content='ping' - ) - ], - model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'max_tokens': 10 + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "chat_completion", }, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, stop=[], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_stream_completion_model(): model = LocalAILanguageModel() response = model.invoke( - model='chinese-llama-2-7b', + model="chinese-llama-2-7b", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'completion', + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "completion", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'max_tokens': 10 - }, - stop=['you'], + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -123,28 +102,21 @@ def test_invoke_stream_completion_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_invoke_stream_chat_model(): model = LocalAILanguageModel() response = model.invoke( - model='chinese-llama-2-7b', + model="chinese-llama-2-7b", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'chat_completion', - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'max_tokens': 10 + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "chat_completion", }, - stop=['you'], + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -154,64 +126,48 @@ def test_invoke_stream_chat_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_get_num_tokens(): model = LocalAILanguageModel() num_tokens = model.get_num_tokens( - model='????', + model="????", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'chat_completion', + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "chat_completion", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) - ] + ], ) assert isinstance(num_tokens, int) assert num_tokens == 77 num_tokens = model.get_num_tokens( - model='????', + model="????", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'chat_completion', + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "chat_completion", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert isinstance(num_tokens, int) diff --git a/api/tests/integration_tests/model_runtime/localai/test_rerank.py b/api/tests/integration_tests/model_runtime/localai/test_rerank.py index 99847bc8528a0a..13c7df6d1473b0 100644 --- a/api/tests/integration_tests/model_runtime/localai/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/localai/test_rerank.py @@ -12,30 +12,29 @@ def test_validate_credentials_for_chat_model(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='bge-reranker-v2-m3', + model="bge-reranker-v2-m3", credentials={ - 'server_url': 'hahahaha', - 'completion_type': 'completion', - } + "server_url": "hahahaha", + "completion_type": "completion", + }, ) model.validate_credentials( - model='bge-reranker-base', + model="bge-reranker-base", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'completion', - } + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "completion", + }, ) + def test_invoke_rerank_model(): model = LocalaiRerankModel() response = model.invoke( - model='bge-reranker-base', - credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL') - }, - query='Organic skincare products for sensitive skin', + model="bge-reranker-base", + credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")}, + query="Organic skincare products for sensitive skin", docs=[ "Eco-friendly kitchenware for modern homes", "Biodegradable cleaning supplies for eco-conscious consumers", @@ -45,43 +44,38 @@ def test_invoke_rerank_model(): "Sustainable gardening tools and compost solutions", "Sensitive skin-friendly facial cleansers and toners", "Organic food wraps and storage solutions", - "Yoga mats made from recycled materials" + "Yoga mats made from recycled materials", ], top_n=3, score_threshold=0.75, - user="abc-123" + user="abc-123", ) assert isinstance(response, RerankResult) assert len(response.docs) == 3 + def test__invoke(): model = LocalaiRerankModel() # Test case 1: Empty docs result = model._invoke( - model='bge-reranker-base', - credentials={ - 'server_url': 'https://example.com', - 'api_key': '1234567890' - }, - query='Organic skincare products for sensitive skin', + model="bge-reranker-base", + credentials={"server_url": "https://example.com", "api_key": "1234567890"}, + query="Organic skincare products for sensitive skin", docs=[], top_n=3, score_threshold=0.75, - user="abc-123" + user="abc-123", ) assert isinstance(result, RerankResult) assert len(result.docs) == 0 # Test case 2: Valid invocation result = model._invoke( - model='bge-reranker-base', - credentials={ - 'server_url': 'https://example.com', - 'api_key': '1234567890' - }, - query='Organic skincare products for sensitive skin', + model="bge-reranker-base", + credentials={"server_url": "https://example.com", "api_key": "1234567890"}, + query="Organic skincare products for sensitive skin", docs=[ "Eco-friendly kitchenware for modern homes", "Biodegradable cleaning supplies for eco-conscious consumers", @@ -91,12 +85,12 @@ def test__invoke(): "Sustainable gardening tools and compost solutions", "Sensitive skin-friendly facial cleansers and toners", "Organic food wraps and storage solutions", - "Yoga mats made from recycled materials" + "Yoga mats made from recycled materials", ], top_n=3, score_threshold=0.75, - user="abc-123" + user="abc-123", ) assert isinstance(result, RerankResult) assert len(result.docs) == 3 - assert all(isinstance(doc, RerankDocument) for doc in result.docs) \ No newline at end of file + assert all(isinstance(doc, RerankDocument) for doc in result.docs) diff --git a/api/tests/integration_tests/model_runtime/localai/test_speech2text.py b/api/tests/integration_tests/model_runtime/localai/test_speech2text.py index 3fd2ebed4f0be1..91b7a5752ce973 100644 --- a/api/tests/integration_tests/model_runtime/localai/test_speech2text.py +++ b/api/tests/integration_tests/model_runtime/localai/test_speech2text.py @@ -10,19 +10,9 @@ def test_validate_credentials(): model = LocalAISpeech2text() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='whisper-1', - credentials={ - 'server_url': 'invalid_url' - } - ) + model.validate_credentials(model="whisper-1", credentials={"server_url": "invalid_url"}) - model.validate_credentials( - model='whisper-1', - credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL') - } - ) + model.validate_credentials(model="whisper-1", credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")}) def test_invoke_model(): @@ -32,23 +22,21 @@ def test_invoke_model(): current_dir = os.path.dirname(os.path.abspath(__file__)) # Get assets directory - assets_dir = os.path.join(os.path.dirname(current_dir), 'assets') + assets_dir = os.path.join(os.path.dirname(current_dir), "assets") # Construct the path to the audio file - audio_file_path = os.path.join(assets_dir, 'audio.mp3') + audio_file_path = os.path.join(assets_dir, "audio.mp3") # Open the file and get the file object - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: file = audio_file result = model.invoke( - model='whisper-1', - credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL') - }, + model="whisper-1", + credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")}, file=file, - user="abc-123" + user="abc-123", ) assert isinstance(result, str) - assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10' \ No newline at end of file + assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10" diff --git a/api/tests/integration_tests/model_runtime/minimax/test_embedding.py b/api/tests/integration_tests/model_runtime/minimax/test_embedding.py index 6f4b8a163f96da..cf2a28eb9eb2fe 100644 --- a/api/tests/integration_tests/model_runtime/minimax/test_embedding.py +++ b/api/tests/integration_tests/model_runtime/minimax/test_embedding.py @@ -12,54 +12,47 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='embo-01', - credentials={ - 'minimax_api_key': 'invalid_key', - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') - } + model="embo-01", + credentials={"minimax_api_key": "invalid_key", "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID")}, ) model.validate_credentials( - model='embo-01', + model="embo-01", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') - } + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), + }, ) + def test_invoke_model(): model = MinimaxTextEmbeddingModel() result = model.invoke( - model='embo-01', + model="embo-01", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 assert result.usage.total_tokens == 16 + def test_get_num_tokens(): model = MinimaxTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='embo-01', + model="embo-01", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/minimax/test_llm.py b/api/tests/integration_tests/model_runtime/minimax/test_llm.py index 570e4901a9e7e2..aacde04d326caf 100644 --- a/api/tests/integration_tests/model_runtime/minimax/test_llm.py +++ b/api/tests/integration_tests/model_runtime/minimax/test_llm.py @@ -17,79 +17,70 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) + def test_validate_credentials_for_chat_model(): sleep(3) model = MinimaxLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='abab5.5-chat', - credentials={ - 'minimax_api_key': 'invalid_key', - 'minimax_group_id': 'invalid_key' - } + model="abab5.5-chat", credentials={"minimax_api_key": "invalid_key", "minimax_group_id": "invalid_key"} ) model.validate_credentials( - model='abab5.5-chat', + model="abab5.5-chat", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') - } + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), + }, ) + def test_invoke_model(): sleep(3) model = MinimaxLargeLanguageModel() response = model.invoke( - model='abab5-chat', + model="abab5-chat", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_stream_model(): sleep(3) model = MinimaxLargeLanguageModel() response = model.invoke( - model='abab5.5-chat', + model="abab5.5-chat", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -99,34 +90,31 @@ def test_invoke_stream_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_invoke_with_search(): sleep(3) model = MinimaxLargeLanguageModel() response = model.invoke( - model='abab5.5-chat', + model="abab5.5-chat", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - prompt_messages=[ - UserPromptMessage( - content='北京今天的天气怎么样' - ) - ], + prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, - 'plugin_web_search': True, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, + "plugin_web_search": True, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) - total_message = '' + total_message = "" for chunk in response: assert isinstance(chunk, LLMResultChunk) assert isinstance(chunk.delta, LLMResultChunkDelta) @@ -134,25 +122,22 @@ def test_invoke_with_search(): total_message += chunk.delta.message.content assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True - assert '参考资料' in total_message + assert "参考资料" in total_message + def test_get_num_tokens(): sleep(3) model = MinimaxLargeLanguageModel() response = model.get_num_tokens( - model='abab5.5-chat', + model="abab5.5-chat", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - tools=[] + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], ) assert isinstance(response, int) - assert response == 30 \ No newline at end of file + assert response == 30 diff --git a/api/tests/integration_tests/model_runtime/minimax/test_provider.py b/api/tests/integration_tests/model_runtime/minimax/test_provider.py index 4c5462c6dff551..575ed13eef124a 100644 --- a/api/tests/integration_tests/model_runtime/minimax/test_provider.py +++ b/api/tests/integration_tests/model_runtime/minimax/test_provider.py @@ -12,14 +12,14 @@ def test_validate_provider_credentials(): with pytest.raises(CredentialsValidateFailedError): provider.validate_provider_credentials( credentials={ - 'minimax_api_key': 'hahahaha', - 'minimax_group_id': '123', + "minimax_api_key": "hahahaha", + "minimax_group_id": "123", } ) provider.validate_provider_credentials( credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID'), + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), } ) diff --git a/api/tests/integration_tests/model_runtime/novita/test_llm.py b/api/tests/integration_tests/model_runtime/novita/test_llm.py index 4ebc68493f26d5..35fa0dc1904f7b 100644 --- a/api/tests/integration_tests/model_runtime/novita/test_llm.py +++ b/api/tests/integration_tests/model_runtime/novita/test_llm.py @@ -19,19 +19,12 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='meta-llama/llama-3-8b-instruct', - credentials={ - 'api_key': 'invalid_key', - 'mode': 'chat' - } + model="meta-llama/llama-3-8b-instruct", credentials={"api_key": "invalid_key", "mode": "chat"} ) model.validate_credentials( - model='meta-llama/llama-3-8b-instruct', - credentials={ - 'api_key': os.environ.get('NOVITA_API_KEY'), - 'mode': 'chat' - } + model="meta-llama/llama-3-8b-instruct", + credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "chat"}, ) @@ -39,27 +32,22 @@ def test_invoke_model(): model = NovitaLargeLanguageModel() response = model.invoke( - model='meta-llama/llama-3-8b-instruct', - credentials={ - 'api_key': os.environ.get('NOVITA_API_KEY'), - 'mode': 'completion' - }, + model="meta-llama/llama-3-8b-instruct", + credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "completion"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_p': 0.5, - 'max_tokens': 10, + "temperature": 1.0, + "top_p": 0.5, + "max_tokens": 10, }, - stop=['How'], + stop=["How"], stream=False, - user="novita" + user="novita", ) assert isinstance(response, LLMResult) @@ -70,27 +58,17 @@ def test_invoke_stream_model(): model = NovitaLargeLanguageModel() response = model.invoke( - model='meta-llama/llama-3-8b-instruct', - credentials={ - 'api_key': os.environ.get('NOVITA_API_KEY'), - 'mode': 'chat' - }, + model="meta-llama/llama-3-8b-instruct", + credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "chat"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], - model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, - 'max_tokens': 100 - }, + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "max_tokens": 100}, stream=True, - user="novita" + user="novita", ) assert isinstance(response, Generator) @@ -105,18 +83,16 @@ def test_get_num_tokens(): model = NovitaLargeLanguageModel() num_tokens = model.get_num_tokens( - model='meta-llama/llama-3-8b-instruct', + model="meta-llama/llama-3-8b-instruct", credentials={ - 'api_key': os.environ.get('NOVITA_API_KEY'), + "api_key": os.environ.get("NOVITA_API_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert isinstance(num_tokens, int) diff --git a/api/tests/integration_tests/model_runtime/novita/test_provider.py b/api/tests/integration_tests/model_runtime/novita/test_provider.py index bb3f19dc851ea5..191af99db20bd9 100644 --- a/api/tests/integration_tests/model_runtime/novita/test_provider.py +++ b/api/tests/integration_tests/model_runtime/novita/test_provider.py @@ -10,12 +10,10 @@ def test_validate_provider_credentials(): provider = NovitaProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) provider.validate_provider_credentials( credentials={ - 'api_key': os.environ.get('NOVITA_API_KEY'), + "api_key": os.environ.get("NOVITA_API_KEY"), } ) diff --git a/api/tests/integration_tests/model_runtime/ollama/test_llm.py b/api/tests/integration_tests/model_runtime/ollama/test_llm.py index 272e639a8ac11e..58a1339f506458 100644 --- a/api/tests/integration_tests/model_runtime/ollama/test_llm.py +++ b/api/tests/integration_tests/model_runtime/ollama/test_llm.py @@ -20,23 +20,23 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': 'http://localhost:21434', - 'mode': 'chat', - 'context_size': 2048, - 'max_tokens': 2048, - } + "base_url": "http://localhost:21434", + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, + }, ) model.validate_credentials( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 2048, - 'max_tokens': 2048, - } + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, + }, ) @@ -44,26 +44,17 @@ def test_invoke_model(): model = OllamaLargeLanguageModel() response = model.invoke( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, - 'num_predict': 10 - }, - stop=['How'], - stream=False + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "num_predict": 10}, + stop=["How"], + stream=False, ) assert isinstance(response, LLMResult) @@ -74,29 +65,22 @@ def test_invoke_stream_model(): model = OllamaLargeLanguageModel() response = model.invoke( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], - model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, - 'num_predict': 10 - }, - stop=['How'], - stream=True + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "num_predict": 10}, + stop=["How"], + stream=True, ) assert isinstance(response, Generator) @@ -111,26 +95,17 @@ def test_invoke_completion_model(): model = OllamaLargeLanguageModel() response = model.invoke( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'completion', - 'context_size': 2048, - 'max_tokens': 2048, - }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, - 'num_predict': 10 + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "completion", + "context_size": 2048, + "max_tokens": 2048, }, - stop=['How'], - stream=False + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "num_predict": 10}, + stop=["How"], + stream=False, ) assert isinstance(response, LLMResult) @@ -141,29 +116,22 @@ def test_invoke_stream_completion_model(): model = OllamaLargeLanguageModel() response = model.invoke( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'completion', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "completion", + "context_size": 2048, + "max_tokens": 2048, }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], - model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, - 'num_predict': 10 - }, - stop=['How'], - stream=True + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "num_predict": 10}, + stop=["How"], + stream=True, ) assert isinstance(response, Generator) @@ -178,29 +146,26 @@ def test_invoke_completion_model_with_vision(): model = OllamaLargeLanguageModel() result = model.invoke( - model='llava', + model="llava", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'completion', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "completion", + "context_size": 2048, + "max_tokens": 2048, }, prompt_messages=[ UserPromptMessage( content=[ TextPromptMessageContent( - data='What is this in this picture?', + data="What is this in this picture?", ), ImagePromptMessageContent( - data='' - ) + data="" + ), ] ) ], - model_parameters={ - 'temperature': 0.1, - 'num_predict': 100 - }, + model_parameters={"temperature": 0.1, "num_predict": 100}, stream=False, ) @@ -212,29 +177,26 @@ def test_invoke_chat_model_with_vision(): model = OllamaLargeLanguageModel() result = model.invoke( - model='llava', + model="llava", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, }, prompt_messages=[ UserPromptMessage( content=[ TextPromptMessageContent( - data='What is this in this picture?', + data="What is this in this picture?", ), ImagePromptMessageContent( - data='' - ) + data="" + ), ] ) ], - model_parameters={ - 'temperature': 0.1, - 'num_predict': 100 - }, + model_parameters={"temperature": 0.1, "num_predict": 100}, stream=False, ) @@ -246,18 +208,14 @@ def test_get_num_tokens(): model = OllamaLargeLanguageModel() num_tokens = model.get_num_tokens( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ] + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert isinstance(num_tokens, int) diff --git a/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py b/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py index c5f5918235d3a2..3c4f740a4fd09c 100644 --- a/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py @@ -12,21 +12,21 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': 'http://localhost:21434', - 'mode': 'chat', - 'context_size': 4096, - } + "base_url": "http://localhost:21434", + "mode": "chat", + "context_size": 4096, + }, ) model.validate_credentials( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 4096, - } + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 4096, + }, ) @@ -34,17 +34,14 @@ def test_invoke_model(): model = OllamaEmbeddingModel() result = model.invoke( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 4096, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 4096, }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -56,16 +53,13 @@ def test_get_num_tokens(): model = OllamaEmbeddingModel() num_tokens = model.get_num_tokens( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 4096, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 4096, }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/openai/test_llm.py b/api/tests/integration_tests/model_runtime/openai/test_llm.py index bf4ac53579fb6e..3b3ea9ec800cbb 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/openai/test_llm.py @@ -28,92 +28,61 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_credentials_for_chat_model(setup_openai_mock): model = OpenAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='gpt-3.5-turbo', - credentials={ - 'openai_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="gpt-3.5-turbo", credentials={"openai_api_key": "invalid_key"}) - model.validate_credentials( - model='gpt-3.5-turbo', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } - ) + model.validate_credentials(model="gpt-3.5-turbo", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_validate_credentials_for_completion_model(setup_openai_mock): model = OpenAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='text-davinci-003', - credentials={ - 'openai_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="text-davinci-003", credentials={"openai_api_key": "invalid_key"}) model.validate_credentials( - model='text-davinci-003', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + model="text-davinci-003", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")} ) -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_invoke_completion_model(setup_openai_mock): model = OpenAILargeLanguageModel() result = model.invoke( - model='gpt-3.5-turbo-instruct', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY'), - 'openai_api_base': 'https://api.openai.com' - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 1 - }, + model="gpt-3.5-turbo-instruct", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY"), "openai_api_base": "https://api.openai.com"}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 1}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 - assert model._num_tokens_from_string('gpt-3.5-turbo-instruct', result.message.content) == 1 + assert model._num_tokens_from_string("gpt-3.5-turbo-instruct", result.message.content) == 1 -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_invoke_stream_completion_model(setup_openai_mock): model = OpenAILargeLanguageModel() result = model.invoke( - model='gpt-3.5-turbo-instruct', + model="gpt-3.5-turbo-instruct", credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY'), - 'openai_organization': os.environ.get('OPENAI_ORGANIZATION'), - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 + "openai_api_key": os.environ.get("OPENAI_API_KEY"), + "openai_organization": os.environ.get("OPENAI_ORGANIZATION"), }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -124,166 +93,131 @@ def test_invoke_stream_completion_model(setup_openai_mock): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model(setup_openai_mock): model = OpenAILargeLanguageModel() result = model.invoke( - model='gpt-3.5-turbo', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="gpt-3.5-turbo", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.0, - 'top_p': 1.0, - 'presence_penalty': 0.0, - 'frequency_penalty': 0.0, - 'max_tokens': 10 + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model_with_vision(setup_openai_mock): model = OpenAILargeLanguageModel() result = model.invoke( - model='gpt-4-vision-preview', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="gpt-4-vision-preview", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content=[ TextPromptMessageContent( - data='Hello World!', + data="Hello World!", ), ImagePromptMessageContent( - data='' - ) + data="" + ), ] - ) + ), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model_with_tools(setup_openai_mock): model = OpenAILargeLanguageModel() result = model.invoke( - model='gpt-3.5-turbo', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="gpt-3.5-turbo", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content="what's the weather today in London?", - ) + ), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), PromptMessageTool( - name='get_stock_price', - description='Get the current stock price', + name="get_stock_price", + description="Get the current stock price", parameters={ "type": "object", - "properties": { - "symbol": { - "type": "string", - "description": "The stock symbol" - } - }, - "required": [ - "symbol" - ] - } - ) + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), ], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert isinstance(result.message, AssistantPromptMessage) assert len(result.message.tool_calls) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_stream_chat_model(setup_openai_mock): model = OpenAILargeLanguageModel() result = model.invoke( - model='gpt-3.5-turbo', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="gpt-3.5-turbo", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -302,68 +236,46 @@ def test_get_num_tokens(): model = OpenAILargeLanguageModel() num_tokens = model.get_num_tokens( - model='gpt-3.5-turbo-instruct', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ] + model="gpt-3.5-turbo-instruct", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert num_tokens == 3 num_tokens = model.get_num_tokens( - model='gpt-3.5-turbo', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="gpt-3.5-turbo", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), - ] + ], ) assert num_tokens == 72 -@pytest.mark.parametrize('setup_openai_mock', [['chat', 'remote']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat", "remote"]], indirect=True) def test_fine_tuned_models(setup_openai_mock): model = OpenAILargeLanguageModel() - remote_models = model.remote_models(credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }) + remote_models = model.remote_models(credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) if not remote_models: assert isinstance(remote_models, list) @@ -379,29 +291,23 @@ def test_fine_tuned_models(setup_openai_mock): # test invoke result = model.invoke( model=llm_model.model, - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) + def test__get_num_tokens_by_gpt2(): model = OpenAILargeLanguageModel() - num_tokens = model._get_num_tokens_by_gpt2('Hello World!') + num_tokens = model._get_num_tokens_by_gpt2("Hello World!") assert num_tokens == 3 diff --git a/api/tests/integration_tests/model_runtime/openai/test_moderation.py b/api/tests/integration_tests/model_runtime/openai/test_moderation.py index 04f9b9f33b3537..6de262471798ad 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_moderation.py +++ b/api/tests/integration_tests/model_runtime/openai/test_moderation.py @@ -7,48 +7,37 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['moderation']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["moderation"]], indirect=True) def test_validate_credentials(setup_openai_mock): model = OpenAIModerationModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='text-moderation-stable', - credentials={ - 'openai_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="text-moderation-stable", credentials={"openai_api_key": "invalid_key"}) model.validate_credentials( - model='text-moderation-stable', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + model="text-moderation-stable", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")} ) -@pytest.mark.parametrize('setup_openai_mock', [['moderation']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["moderation"]], indirect=True) def test_invoke_model(setup_openai_mock): model = OpenAIModerationModel() result = model.invoke( - model='text-moderation-stable', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="text-moderation-stable", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, text="hello", - user="abc-123" + user="abc-123", ) assert isinstance(result, bool) assert result is False result = model.invoke( - model='text-moderation-stable', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="text-moderation-stable", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, text="i will kill you", - user="abc-123" + user="abc-123", ) assert isinstance(result, bool) diff --git a/api/tests/integration_tests/model_runtime/openai/test_provider.py b/api/tests/integration_tests/model_runtime/openai/test_provider.py index 5314bffbdf37b1..4d56cfcf3c25f0 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_provider.py +++ b/api/tests/integration_tests/model_runtime/openai/test_provider.py @@ -7,17 +7,11 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_provider_credentials(setup_openai_mock): provider = OpenAIProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/openai/test_speech2text.py b/api/tests/integration_tests/model_runtime/openai/test_speech2text.py index f1a5c4fd23f091..aa92c8b61fb684 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_speech2text.py +++ b/api/tests/integration_tests/model_runtime/openai/test_speech2text.py @@ -7,26 +7,17 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['speech2text']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["speech2text"]], indirect=True) def test_validate_credentials(setup_openai_mock): model = OpenAISpeech2TextModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='whisper-1', - credentials={ - 'openai_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="whisper-1", credentials={"openai_api_key": "invalid_key"}) + + model.validate_credentials(model="whisper-1", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) - model.validate_credentials( - model='whisper-1', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } - ) -@pytest.mark.parametrize('setup_openai_mock', [['speech2text']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["speech2text"]], indirect=True) def test_invoke_model(setup_openai_mock): model = OpenAISpeech2TextModel() @@ -34,23 +25,21 @@ def test_invoke_model(setup_openai_mock): current_dir = os.path.dirname(os.path.abspath(__file__)) # Get assets directory - assets_dir = os.path.join(os.path.dirname(current_dir), 'assets') + assets_dir = os.path.join(os.path.dirname(current_dir), "assets") # Construct the path to the audio file - audio_file_path = os.path.join(assets_dir, 'audio.mp3') + audio_file_path = os.path.join(assets_dir, "audio.mp3") # Open the file and get the file object - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: file = audio_file result = model.invoke( - model='whisper-1', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="whisper-1", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, file=file, - user="abc-123" + user="abc-123", ) assert isinstance(result, str) - assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10' + assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10" diff --git a/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py index e2c4c74ee7636b..f5dd73f2d4cd60 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py @@ -8,42 +8,27 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) def test_validate_credentials(setup_openai_mock): model = OpenAITextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='text-embedding-ada-002', - credentials={ - 'openai_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="text-embedding-ada-002", credentials={"openai_api_key": "invalid_key"}) model.validate_credentials( - model='text-embedding-ada-002', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + model="text-embedding-ada-002", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")} ) -@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) def test_invoke_model(setup_openai_mock): model = OpenAITextEmbeddingModel() result = model.invoke( - model='text-embedding-ada-002', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY'), - 'openai_api_base': 'https://api.openai.com' - }, - texts=[ - "hello", - "world", - " ".join(["long_text"] * 100), - " ".join(["another_long_text"] * 100) - ], - user="abc-123" + model="text-embedding-ada-002", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY"), "openai_api_base": "https://api.openai.com"}, + texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -55,15 +40,9 @@ def test_get_num_tokens(): model = OpenAITextEmbeddingModel() num_tokens = model.get_num_tokens( - model='text-embedding-ada-002', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY'), - 'openai_api_base': 'https://api.openai.com' - }, - texts=[ - "hello", - "world" - ] + model="text-embedding-ada-002", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY"), "openai_api_base": "https://api.openai.com"}, + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py index c8335085695fd5..f2302ef05e06de 100644 --- a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py +++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py @@ -23,21 +23,17 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', - credentials={ - 'api_key': 'invalid_key', - 'endpoint_url': 'https://api.together.xyz/v1/', - 'mode': 'chat' - } + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": "invalid_key", "endpoint_url": "https://api.together.xyz/v1/", "mode": "chat"}, ) model.validate_credentials( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'endpoint_url': 'https://api.together.xyz/v1/', - 'mode': 'chat' - } + "api_key": os.environ.get("TOGETHER_API_KEY"), + "endpoint_url": "https://api.together.xyz/v1/", + "mode": "chat", + }, ) @@ -45,28 +41,26 @@ def test_invoke_model(): model = OAIAPICompatLargeLanguageModel() response = model.invoke( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'endpoint_url': 'https://api.together.xyz/v1/', - 'mode': 'completion' + "api_key": os.environ.get("TOGETHER_API_KEY"), + "endpoint_url": "https://api.together.xyz/v1/", + "mode": "completion", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -77,29 +71,27 @@ def test_invoke_stream_model(): model = OAIAPICompatLargeLanguageModel() response = model.invoke( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'endpoint_url': 'https://api.together.xyz/v1/', - 'mode': 'chat', - 'stream_mode_delimiter': '\\n\\n' + "api_key": os.environ.get("TOGETHER_API_KEY"), + "endpoint_url": "https://api.together.xyz/v1/", + "mode": "chat", + "stream_mode_delimiter": "\\n\\n", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -114,28 +106,26 @@ def test_invoke_stream_model_without_delimiter(): model = OAIAPICompatLargeLanguageModel() response = model.invoke( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'endpoint_url': 'https://api.together.xyz/v1/', - 'mode': 'chat' + "api_key": os.environ.get("TOGETHER_API_KEY"), + "endpoint_url": "https://api.together.xyz/v1/", + "mode": "chat", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -151,51 +141,37 @@ def test_invoke_chat_model_with_tools(): model = OAIAPICompatLargeLanguageModel() result = model.invoke( - model='gpt-3.5-turbo', + model="gpt-3.5-turbo", credentials={ - 'api_key': os.environ.get('OPENAI_API_KEY'), - 'endpoint_url': 'https://api.openai.com/v1/', - 'mode': 'chat' + "api_key": os.environ.get("OPENAI_API_KEY"), + "endpoint_url": "https://api.openai.com/v1/", + "mode": "chat", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content="what's the weather today in London?", - ) + ), ], tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "celsius", - "fahrenheit" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 1024 - }, + model_parameters={"temperature": 0.0, "max_tokens": 1024}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) @@ -207,19 +183,14 @@ def test_get_num_tokens(): model = OAIAPICompatLargeLanguageModel() num_tokens = model.get_num_tokens( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', - credentials={ - 'api_key': os.environ.get('OPENAI_API_KEY'), - 'endpoint_url': 'https://api.openai.com/v1/' - }, + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": os.environ.get("OPENAI_API_KEY"), "endpoint_url": "https://api.openai.com/v1/"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert isinstance(num_tokens, int) diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py new file mode 100644 index 00000000000000..cf805eafff4968 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py @@ -0,0 +1,50 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import ( + OAICompatSpeech2TextModel, +) + + +def test_validate_credentials(): + model = OAICompatSpeech2TextModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="whisper-1", + credentials={"api_key": "invalid_key", "endpoint_url": "https://api.openai.com/v1/"}, + ) + + model.validate_credentials( + model="whisper-1", + credentials={"api_key": os.environ.get("OPENAI_API_KEY"), "endpoint_url": "https://api.openai.com/v1/"}, + ) + + +def test_invoke_model(): + model = OAICompatSpeech2TextModel() + + # Get the directory of the current file + current_dir = os.path.dirname(os.path.abspath(__file__)) + + # Get assets directory + assets_dir = os.path.join(os.path.dirname(current_dir), "assets") + + # Construct the path to the audio file + audio_file_path = os.path.join(assets_dir, "audio.mp3") + + # Open the file and get the file object + with open(audio_file_path, "rb") as audio_file: + file = audio_file + + result = model.invoke( + model="whisper-1", + credentials={"api_key": os.environ.get("OPENAI_API_KEY"), "endpoint_url": "https://api.openai.com/v1/"}, + file=file, + user="abc-123", + ) + + assert isinstance(result, str) + assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10" diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py index 77d27ec1615fe0..052b41605f6da2 100644 --- a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py @@ -12,27 +12,23 @@ Using OpenAI's API as testing endpoint """ + def test_validate_credentials(): model = OAICompatEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='text-embedding-ada-002', - credentials={ - 'api_key': 'invalid_key', - 'endpoint_url': 'https://api.openai.com/v1/', - 'context_size': 8184 - - } + model="text-embedding-ada-002", + credentials={"api_key": "invalid_key", "endpoint_url": "https://api.openai.com/v1/", "context_size": 8184}, ) model.validate_credentials( - model='text-embedding-ada-002', + model="text-embedding-ada-002", credentials={ - 'api_key': os.environ.get('OPENAI_API_KEY'), - 'endpoint_url': 'https://api.openai.com/v1/', - 'context_size': 8184 - } + "api_key": os.environ.get("OPENAI_API_KEY"), + "endpoint_url": "https://api.openai.com/v1/", + "context_size": 8184, + }, ) @@ -40,19 +36,14 @@ def test_invoke_model(): model = OAICompatEmbeddingModel() result = model.invoke( - model='text-embedding-ada-002', + model="text-embedding-ada-002", credentials={ - 'api_key': os.environ.get('OPENAI_API_KEY'), - 'endpoint_url': 'https://api.openai.com/v1/', - 'context_size': 8184 + "api_key": os.environ.get("OPENAI_API_KEY"), + "endpoint_url": "https://api.openai.com/v1/", + "context_size": 8184, }, - texts=[ - "hello", - "world", - " ".join(["long_text"] * 100), - " ".join(["another_long_text"] * 100) - ], - user="abc-123" + texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -64,16 +55,13 @@ def test_get_num_tokens(): model = OAICompatEmbeddingModel() num_tokens = model.get_num_tokens( - model='text-embedding-ada-002', + model="text-embedding-ada-002", credentials={ - 'api_key': os.environ.get('OPENAI_API_KEY'), - 'endpoint_url': 'https://api.openai.com/v1/embeddings', - 'context_size': 8184 + "api_key": os.environ.get("OPENAI_API_KEY"), + "endpoint_url": "https://api.openai.com/v1/embeddings", + "context_size": 8184, }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) - assert num_tokens == 2 \ No newline at end of file + assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/openllm/test_embedding.py b/api/tests/integration_tests/model_runtime/openllm/test_embedding.py index 9eb05a111d9340..14d47217af62c8 100644 --- a/api/tests/integration_tests/model_runtime/openllm/test_embedding.py +++ b/api/tests/integration_tests/model_runtime/openllm/test_embedding.py @@ -12,17 +12,17 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': 'ww' + os.environ.get('OPENLLM_SERVER_URL'), - } + "server_url": "ww" + os.environ.get("OPENLLM_SERVER_URL"), + }, ) model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), - } + "server_url": os.environ.get("OPENLLM_SERVER_URL"), + }, ) @@ -30,33 +30,28 @@ def test_invoke_model(): model = OpenLLMTextEmbeddingModel() result = model.invoke( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), + "server_url": os.environ.get("OPENLLM_SERVER_URL"), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 assert result.usage.total_tokens > 0 + def test_get_num_tokens(): model = OpenLLMTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), + "server_url": os.environ.get("OPENLLM_SERVER_URL"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/openllm/test_llm.py b/api/tests/integration_tests/model_runtime/openllm/test_llm.py index 853a0fbe3c9a16..35939e3cfe8bfd 100644 --- a/api/tests/integration_tests/model_runtime/openllm/test_llm.py +++ b/api/tests/integration_tests/model_runtime/openllm/test_llm.py @@ -14,67 +14,61 @@ def test_validate_credentials_for_chat_model(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': 'invalid_key', - } + "server_url": "invalid_key", + }, ) model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), - } + "server_url": os.environ.get("OPENLLM_SERVER_URL"), + }, ) + def test_invoke_model(): model = OpenLLMLargeLanguageModel() response = model.invoke( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), + "server_url": os.environ.get("OPENLLM_SERVER_URL"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_stream_model(): model = OpenLLMLargeLanguageModel() response = model.invoke( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), + "server_url": os.environ.get("OPENLLM_SERVER_URL"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -84,21 +78,18 @@ def test_invoke_stream_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_get_num_tokens(): model = OpenLLMLargeLanguageModel() response = model.get_num_tokens( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), + "server_url": os.environ.get("OPENLLM_SERVER_URL"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - tools=[] + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], ) assert isinstance(response, int) - assert response == 3 \ No newline at end of file + assert response == 3 diff --git a/api/tests/integration_tests/model_runtime/openrouter/test_llm.py b/api/tests/integration_tests/model_runtime/openrouter/test_llm.py index 8f1fb4c4ad7990..ce4876a73a740e 100644 --- a/api/tests/integration_tests/model_runtime/openrouter/test_llm.py +++ b/api/tests/integration_tests/model_runtime/openrouter/test_llm.py @@ -19,19 +19,12 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='mistralai/mixtral-8x7b-instruct', - credentials={ - 'api_key': 'invalid_key', - 'mode': 'chat' - } + model="mistralai/mixtral-8x7b-instruct", credentials={"api_key": "invalid_key", "mode": "chat"} ) model.validate_credentials( - model='mistralai/mixtral-8x7b-instruct', - credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'mode': 'chat' - } + model="mistralai/mixtral-8x7b-instruct", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"}, ) @@ -39,27 +32,22 @@ def test_invoke_model(): model = OpenRouterLargeLanguageModel() response = model.invoke( - model='mistralai/mixtral-8x7b-instruct', - credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'mode': 'completion' - }, + model="mistralai/mixtral-8x7b-instruct", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "completion"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -70,27 +58,22 @@ def test_invoke_stream_model(): model = OpenRouterLargeLanguageModel() response = model.invoke( - model='mistralai/mixtral-8x7b-instruct', - credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'mode': 'chat' - }, + model="mistralai/mixtral-8x7b-instruct", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -105,18 +88,16 @@ def test_get_num_tokens(): model = OpenRouterLargeLanguageModel() num_tokens = model.get_num_tokens( - model='mistralai/mixtral-8x7b-instruct', + model="mistralai/mixtral-8x7b-instruct", credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), + "api_key": os.environ.get("TOGETHER_API_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert isinstance(num_tokens, int) diff --git a/api/tests/integration_tests/model_runtime/replicate/test_llm.py b/api/tests/integration_tests/model_runtime/replicate/test_llm.py index e248f064c05de3..b940005b715760 100644 --- a/api/tests/integration_tests/model_runtime/replicate/test_llm.py +++ b/api/tests/integration_tests/model_runtime/replicate/test_llm.py @@ -14,19 +14,19 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='meta/llama-2-13b-chat', + model="meta/llama-2-13b-chat", credentials={ - 'replicate_api_token': 'invalid_key', - 'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d' - } + "replicate_api_token": "invalid_key", + "model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d", + }, ) model.validate_credentials( - model='meta/llama-2-13b-chat', + model="meta/llama-2-13b-chat", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d' - } + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d", + }, ) @@ -34,27 +34,25 @@ def test_invoke_model(): model = ReplicateLargeLanguageModel() response = model.invoke( - model='meta/llama-2-13b-chat', + model="meta/llama-2-13b-chat", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -65,27 +63,25 @@ def test_invoke_stream_model(): model = ReplicateLargeLanguageModel() response = model.invoke( - model='mistralai/mixtral-8x7b-instruct-v0.1', + model="mistralai/mixtral-8x7b-instruct-v0.1", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': '2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -100,19 +96,17 @@ def test_get_num_tokens(): model = ReplicateLargeLanguageModel() num_tokens = model.get_num_tokens( - model='', + model="", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': '2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 14 diff --git a/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py b/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py index 5708ec9e5a219e..397715f2252083 100644 --- a/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py @@ -12,19 +12,19 @@ def test_validate_credentials_one(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='replicate/all-mpnet-base-v2', + model="replicate/all-mpnet-base-v2", credentials={ - 'replicate_api_token': 'invalid_key', - 'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305' - } + "replicate_api_token": "invalid_key", + "model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305", + }, ) model.validate_credentials( - model='replicate/all-mpnet-base-v2', + model="replicate/all-mpnet-base-v2", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305' - } + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305", + }, ) @@ -33,19 +33,19 @@ def test_validate_credentials_two(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='nateraw/bge-large-en-v1.5', + model="nateraw/bge-large-en-v1.5", credentials={ - 'replicate_api_token': 'invalid_key', - 'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1' - } + "replicate_api_token": "invalid_key", + "model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1", + }, ) model.validate_credentials( - model='nateraw/bge-large-en-v1.5', + model="nateraw/bge-large-en-v1.5", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1' - } + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1", + }, ) @@ -53,16 +53,13 @@ def test_invoke_model_one(): model = ReplicateEmbeddingModel() result = model.invoke( - model='nateraw/bge-large-en-v1.5', + model="nateraw/bge-large-en-v1.5", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1", }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -74,16 +71,13 @@ def test_invoke_model_two(): model = ReplicateEmbeddingModel() result = model.invoke( - model='andreasjansson/clip-features', + model="andreasjansson/clip-features", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': '75b33f253f7714a281ad3e9b28f63e3232d583716ef6718f2e46641077ea040a' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "75b33f253f7714a281ad3e9b28f63e3232d583716ef6718f2e46641077ea040a", }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -95,16 +89,13 @@ def test_invoke_model_three(): model = ReplicateEmbeddingModel() result = model.invoke( - model='replicate/all-mpnet-base-v2', + model="replicate/all-mpnet-base-v2", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305", }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -116,16 +107,13 @@ def test_invoke_model_four(): model = ReplicateEmbeddingModel() result = model.invoke( - model='nateraw/jina-embeddings-v2-base-en', + model="nateraw/jina-embeddings-v2-base-en", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': 'f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e", }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -137,15 +125,12 @@ def test_get_num_tokens(): model = ReplicateEmbeddingModel() num_tokens = model.get_num_tokens( - model='nateraw/jina-embeddings-v2-base-en', + model="nateraw/jina-embeddings-v2-base-en", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': 'f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e", }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py b/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py index 639227e7450343..9f0b439d6c32a1 100644 --- a/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py +++ b/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py @@ -10,10 +10,6 @@ def test_validate_provider_credentials(): provider = SageMakerProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) diff --git a/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py b/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py index c67849dd798883..d5a6798a1ef735 100644 --- a/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py @@ -12,11 +12,11 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='bge-m3-rerank-v2', + model="bge-m3-rerank-v2", credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), }, query="What is the capital of the United States?", docs=[ @@ -25,7 +25,7 @@ def test_validate_credentials(): "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) @@ -33,11 +33,11 @@ def test_invoke_model(): model = SageMakerRerankModel() result = model.invoke( - model='bge-m3-rerank-v2', + model="bge-m3-rerank-v2", credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), }, query="What is the capital of the United States?", docs=[ @@ -46,7 +46,7 @@ def test_invoke_model(): "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) assert isinstance(result, RerankResult) diff --git a/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py b/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py index e817e8f04ab67c..e4e404c7a86ae6 100644 --- a/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py @@ -11,45 +11,23 @@ def test_validate_credentials(): model = SageMakerEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='bge-m3', - credentials={ - } - ) + model.validate_credentials(model="bge-m3", credentials={}) - model.validate_credentials( - model='bge-m3-embedding', - credentials={ - } - ) + model.validate_credentials(model="bge-m3-embedding", credentials={}) def test_invoke_model(): model = SageMakerEmbeddingModel() - result = model.invoke( - model='bge-m3-embedding', - credentials={ - }, - texts=[ - "hello", - "world" - ], - user="abc-123" - ) + result = model.invoke(model="bge-m3-embedding", credentials={}, texts=["hello", "world"], user="abc-123") assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 + def test_get_num_tokens(): model = SageMakerEmbeddingModel() - num_tokens = model.get_num_tokens( - model='bge-m3-embedding', - credentials={ - }, - texts=[ - ] - ) + num_tokens = model.get_num_tokens(model="bge-m3-embedding", credentials={}, texts=[]) assert num_tokens == 0 diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_llm.py b/api/tests/integration_tests/model_runtime/siliconflow/test_llm.py index befdd82352780e..f47c9c558808af 100644 --- a/api/tests/integration_tests/model_runtime/siliconflow/test_llm.py +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_llm.py @@ -13,41 +13,22 @@ def test_validate_credentials(): model = SiliconflowLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials={ - 'api_key': 'invalid_key' - } - ) - - model.validate_credentials( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials={ - 'api_key': os.environ.get('API_KEY') - } - ) + model.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials={"api_key": os.environ.get("API_KEY")}) def test_invoke_model(): model = SiliconflowLargeLanguageModel() response = model.invoke( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials={ - 'api_key': os.environ.get('API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 10 - }, - stop=['How'], + model="deepseek-ai/DeepSeek-V2-Chat", + credentials={"api_key": os.environ.get("API_KEY")}, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -58,22 +39,12 @@ def test_invoke_stream_model(): model = SiliconflowLargeLanguageModel() response = model.invoke( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials={ - 'api_key': os.environ.get('API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 100, - 'seed': 1234 - }, + model="deepseek-ai/DeepSeek-V2-Chat", + credentials={"api_key": os.environ.get("API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -89,18 +60,14 @@ def test_get_num_tokens(): model = SiliconflowLargeLanguageModel() num_tokens = model.get_num_tokens( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials={ - 'api_key': os.environ.get('API_KEY') - }, + model="deepseek-ai/DeepSeek-V2-Chat", + credentials={"api_key": os.environ.get("API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 12 diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_provider.py b/api/tests/integration_tests/model_runtime/siliconflow/test_provider.py index 7b9211a5dbe9c7..8f70210b7a2ace 100644 --- a/api/tests/integration_tests/model_runtime/siliconflow/test_provider.py +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_provider.py @@ -10,12 +10,6 @@ def test_validate_provider_credentials(): provider = SiliconflowProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_rerank.py b/api/tests/integration_tests/model_runtime/siliconflow/test_rerank.py new file mode 100644 index 00000000000000..ad794613f91013 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_rerank.py @@ -0,0 +1,47 @@ +import os + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.siliconflow.rerank.rerank import SiliconflowRerankModel + + +def test_validate_credentials(): + model = SiliconflowRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="BAAI/bge-reranker-v2-m3", + credentials={"api_key": "invalid_key"}, + ) + + model.validate_credentials( + model="BAAI/bge-reranker-v2-m3", + credentials={ + "api_key": os.environ.get("API_KEY"), + }, + ) + + +def test_invoke_model(): + model = SiliconflowRerankModel() + + result = model.invoke( + model="BAAI/bge-reranker-v2-m3", + credentials={ + "api_key": os.environ.get("API_KEY"), + }, + query="Who is Kasumi?", + docs=[ + 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', + "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", + "and she leads a team named PopiParty.", + ], + score_threshold=0.8, + ) + + assert isinstance(result, RerankResult) + assert len(result.docs) == 1 + assert result.docs[0].index == 0 + assert result.docs[0].score >= 0.8 diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py b/api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py new file mode 100644 index 00000000000000..0502ba5ab404bc --- /dev/null +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py @@ -0,0 +1,45 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.siliconflow.speech2text.speech2text import SiliconflowSpeech2TextModel + + +def test_validate_credentials(): + model = SiliconflowSpeech2TextModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="iic/SenseVoiceSmall", + credentials={"api_key": "invalid_key"}, + ) + + model.validate_credentials( + model="iic/SenseVoiceSmall", + credentials={"api_key": os.environ.get("API_KEY")}, + ) + + +def test_invoke_model(): + model = SiliconflowSpeech2TextModel() + + # Get the directory of the current file + current_dir = os.path.dirname(os.path.abspath(__file__)) + + # Get assets directory + assets_dir = os.path.join(os.path.dirname(current_dir), "assets") + + # Construct the path to the audio file + audio_file_path = os.path.join(assets_dir, "audio.mp3") + + # Open the file and get the file object + with open(audio_file_path, "rb") as audio_file: + file = audio_file + + result = model.invoke( + model="iic/SenseVoiceSmall", credentials={"api_key": os.environ.get("API_KEY")}, file=file + ) + + assert isinstance(result, str) + assert result == "1,2,3,4,5,6,7,8,9,10." diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_text_embedding.py b/api/tests/integration_tests/model_runtime/siliconflow/test_text_embedding.py new file mode 100644 index 00000000000000..ab143c10613a88 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_text_embedding.py @@ -0,0 +1,60 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.siliconflow.text_embedding.text_embedding import ( + SiliconflowTextEmbeddingModel, +) + + +def test_validate_credentials(): + model = SiliconflowTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="BAAI/bge-large-zh-v1.5", + credentials={"api_key": "invalid_key"}, + ) + + model.validate_credentials( + model="BAAI/bge-large-zh-v1.5", + credentials={ + "api_key": os.environ.get("API_KEY"), + }, + ) + + +def test_invoke_model(): + model = SiliconflowTextEmbeddingModel() + + result = model.invoke( + model="BAAI/bge-large-zh-v1.5", + credentials={ + "api_key": os.environ.get("API_KEY"), + }, + texts=[ + "hello", + "world", + ], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 6 + + +def test_get_num_tokens(): + model = SiliconflowTextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="BAAI/bge-large-zh-v1.5", + credentials={ + "api_key": os.environ.get("API_KEY"), + }, + texts=["hello", "world"], + ) + + assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/spark/test_llm.py b/api/tests/integration_tests/model_runtime/spark/test_llm.py index 706316449d3142..4fe2fd8c0a3eac 100644 --- a/api/tests/integration_tests/model_runtime/spark/test_llm.py +++ b/api/tests/integration_tests/model_runtime/spark/test_llm.py @@ -13,20 +13,15 @@ def test_validate_credentials(): model = SparkLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='spark-1.5', - credentials={ - 'app_id': 'invalid_key' - } - ) + model.validate_credentials(model="spark-1.5", credentials={"app_id": "invalid_key"}) model.validate_credentials( - model='spark-1.5', + model="spark-1.5", credentials={ - 'app_id': os.environ.get('SPARK_APP_ID'), - 'api_secret': os.environ.get('SPARK_API_SECRET'), - 'api_key': os.environ.get('SPARK_API_KEY') - } + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), + "api_key": os.environ.get("SPARK_API_KEY"), + }, ) @@ -34,24 +29,17 @@ def test_invoke_model(): model = SparkLargeLanguageModel() response = model.invoke( - model='spark-1.5', + model="spark-1.5", credentials={ - 'app_id': os.environ.get('SPARK_APP_ID'), - 'api_secret': os.environ.get('SPARK_API_SECRET'), - 'api_key': os.environ.get('SPARK_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 10 + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), + "api_key": os.environ.get("SPARK_API_KEY"), }, - stop=['How'], + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -62,23 +50,16 @@ def test_invoke_stream_model(): model = SparkLargeLanguageModel() response = model.invoke( - model='spark-1.5', + model="spark-1.5", credentials={ - 'app_id': os.environ.get('SPARK_APP_ID'), - 'api_secret': os.environ.get('SPARK_API_SECRET'), - 'api_key': os.environ.get('SPARK_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 100 + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), + "api_key": os.environ.get("SPARK_API_KEY"), }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.5, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -94,20 +75,18 @@ def test_get_num_tokens(): model = SparkLargeLanguageModel() num_tokens = model.get_num_tokens( - model='spark-1.5', + model="spark-1.5", credentials={ - 'app_id': os.environ.get('SPARK_APP_ID'), - 'api_secret': os.environ.get('SPARK_API_SECRET'), - 'api_key': os.environ.get('SPARK_API_KEY') + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), + "api_key": os.environ.get("SPARK_API_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 14 diff --git a/api/tests/integration_tests/model_runtime/spark/test_provider.py b/api/tests/integration_tests/model_runtime/spark/test_provider.py index 8e22815a86fc84..9da0df6bb3d556 100644 --- a/api/tests/integration_tests/model_runtime/spark/test_provider.py +++ b/api/tests/integration_tests/model_runtime/spark/test_provider.py @@ -10,14 +10,12 @@ def test_validate_provider_credentials(): provider = SparkProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) provider.validate_provider_credentials( credentials={ - 'app_id': os.environ.get('SPARK_APP_ID'), - 'api_secret': os.environ.get('SPARK_API_SECRET'), - 'api_key': os.environ.get('SPARK_API_KEY') + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), + "api_key": os.environ.get("SPARK_API_KEY"), } ) diff --git a/api/tests/integration_tests/model_runtime/stepfun/test_llm.py b/api/tests/integration_tests/model_runtime/stepfun/test_llm.py index d703147d638be3..c03b1bae1f1574 100644 --- a/api/tests/integration_tests/model_runtime/stepfun/test_llm.py +++ b/api/tests/integration_tests/model_runtime/stepfun/test_llm.py @@ -21,40 +21,22 @@ def test_validate_credentials(): model = StepfunLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='step-1-8k', - credentials={ - 'api_key': 'invalid_key' - } - ) - - model.validate_credentials( - model='step-1-8k', - credentials={ - 'api_key': os.environ.get('STEPFUN_API_KEY') - } - ) + model.validate_credentials(model="step-1-8k", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="step-1-8k", credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}) + def test_invoke_model(): model = StepfunLargeLanguageModel() response = model.invoke( - model='step-1-8k', - credentials={ - 'api_key': os.environ.get('STEPFUN_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.9, - 'top_p': 0.7 - }, - stop=['Hi'], + model="step-1-8k", + credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.9, "top_p": 0.7}, + stop=["Hi"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -65,24 +47,17 @@ def test_invoke_stream_model(): model = StepfunLargeLanguageModel() response = model.invoke( - model='step-1-8k', - credentials={ - 'api_key': os.environ.get('STEPFUN_API_KEY') - }, + model="step-1-8k", + credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.9, - 'top_p': 0.7 - }, + model_parameters={"temperature": 0.9, "top_p": 0.7}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -98,10 +73,7 @@ def test_get_customizable_model_schema(): model = StepfunLargeLanguageModel() schema = model.get_customizable_model_schema( - model='step-1-8k', - credentials={ - 'api_key': os.environ.get('STEPFUN_API_KEY') - } + model="step-1-8k", credentials={"api_key": os.environ.get("STEPFUN_API_KEY")} ) assert isinstance(schema, AIModelEntity) @@ -110,67 +82,44 @@ def test_invoke_chat_model_with_tools(): model = StepfunLargeLanguageModel() result = model.invoke( - model='step-1-8k', - credentials={ - 'api_key': os.environ.get('STEPFUN_API_KEY') - }, + model="step-1-8k", + credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content="what's the weather today in Shanghai?", - ) + ), ], - model_parameters={ - 'temperature': 0.9, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.9, "max_tokens": 100}, tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), PromptMessageTool( - name='get_stock_price', - description='Get the current stock price', + name="get_stock_price", + description="Get the current stock price", parameters={ "type": "object", - "properties": { - "symbol": { - "type": "string", - "description": "The stock symbol" - } - }, - "required": [ - "symbol" - ] - } - ) + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), ], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert isinstance(result.message, AssistantPromptMessage) - assert len(result.message.tool_calls) > 0 \ No newline at end of file + assert len(result.message.tool_calls) > 0 diff --git a/api/tests/integration_tests/model_runtime/test_model_provider_factory.py b/api/tests/integration_tests/model_runtime/test_model_provider_factory.py index fd8aa3f610648c..0ec4b0b7243176 100644 --- a/api/tests/integration_tests/model_runtime/test_model_provider_factory.py +++ b/api/tests/integration_tests/model_runtime/test_model_provider_factory.py @@ -24,13 +24,8 @@ def test_get_models(): providers = factory.get_models( model_type=ModelType.LLM, provider_configs=[ - ProviderConfig( - provider='openai', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } - ) - ] + ProviderConfig(provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) + ], ) logger.debug(providers) @@ -44,29 +39,21 @@ def test_get_models(): assert provider_model.model_type == ModelType.LLM providers = factory.get_models( - provider='openai', + provider="openai", provider_configs=[ - ProviderConfig( - provider='openai', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } - ) - ] + ProviderConfig(provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) + ], ) assert len(providers) == 1 assert isinstance(providers[0], SimpleProviderEntity) - assert providers[0].provider == 'openai' + assert providers[0].provider == "openai" def test_provider_credentials_validate(): factory = ModelProviderFactory() factory.provider_credentials_validate( - provider='openai', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")} ) @@ -79,4 +66,4 @@ def test__get_model_provider_map(): logger.debug(model_provider.provider_instance) assert len(model_providers) >= 1 - assert isinstance(model_providers['openai'], ModelProviderExtension) + assert isinstance(model_providers["openai"], ModelProviderExtension) diff --git a/api/tests/integration_tests/model_runtime/togetherai/test_llm.py b/api/tests/integration_tests/model_runtime/togetherai/test_llm.py index 698f53451779b8..06ebc2a82dc754 100644 --- a/api/tests/integration_tests/model_runtime/togetherai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/togetherai/test_llm.py @@ -19,76 +19,61 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', - credentials={ - 'api_key': 'invalid_key', - 'mode': 'chat' - } + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={"api_key": "invalid_key", "mode": "chat"} ) model.validate_credentials( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', - credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'mode': 'chat' - } + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"}, ) + def test_invoke_model(): model = TogetherAILargeLanguageModel() response = model.invoke( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', - credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'mode': 'completion' - }, + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "completion"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 + def test_invoke_stream_model(): model = TogetherAILargeLanguageModel() response = model.invoke( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', - credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'mode': 'chat' - }, + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -98,22 +83,21 @@ def test_invoke_stream_model(): assert isinstance(chunk.delta, LLMResultChunkDelta) assert isinstance(chunk.delta.message, AssistantPromptMessage) + def test_get_num_tokens(): model = TogetherAILargeLanguageModel() num_tokens = model.get_num_tokens( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), + "api_key": os.environ.get("TOGETHER_API_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert isinstance(num_tokens, int) diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_llm.py b/api/tests/integration_tests/model_runtime/tongyi/test_llm.py index 81fb676018b992..61650735f2ad3f 100644 --- a/api/tests/integration_tests/model_runtime/tongyi/test_llm.py +++ b/api/tests/integration_tests/model_runtime/tongyi/test_llm.py @@ -13,18 +13,10 @@ def test_validate_credentials(): model = TongyiLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='qwen-turbo', - credentials={ - 'dashscope_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="qwen-turbo", credentials={"dashscope_api_key": "invalid_key"}) model.validate_credentials( - model='qwen-turbo', - credentials={ - 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') - } + model="qwen-turbo", credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")} ) @@ -32,22 +24,13 @@ def test_invoke_model(): model = TongyiLargeLanguageModel() response = model.invoke( - model='qwen-turbo', - credentials={ - 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 10 - }, - stop=['How'], + model="qwen-turbo", + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -58,22 +41,12 @@ def test_invoke_stream_model(): model = TongyiLargeLanguageModel() response = model.invoke( - model='qwen-turbo', - credentials={ - 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 100, - 'seed': 1234 - }, + model="qwen-turbo", + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -89,18 +62,14 @@ def test_get_num_tokens(): model = TongyiLargeLanguageModel() num_tokens = model.get_num_tokens( - model='qwen-turbo', - credentials={ - 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') - }, + model="qwen-turbo", + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 12 diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_provider.py b/api/tests/integration_tests/model_runtime/tongyi/test_provider.py index 6145c1dc37d00b..0bc96c84e73195 100644 --- a/api/tests/integration_tests/model_runtime/tongyi/test_provider.py +++ b/api/tests/integration_tests/model_runtime/tongyi/test_provider.py @@ -10,12 +10,8 @@ def test_validate_provider_credentials(): provider = TongyiProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) provider.validate_provider_credentials( - credentials={ - 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') - } + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")} ) diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_response_format.py b/api/tests/integration_tests/model_runtime/tongyi/test_response_format.py index 1b0a38d5d15a00..905e7907fde5a8 100644 --- a/api/tests/integration_tests/model_runtime/tongyi/test_response_format.py +++ b/api/tests/integration_tests/model_runtime/tongyi/test_response_format.py @@ -39,21 +39,17 @@ def invoke_model_with_json_response(model_name="qwen-max-0403"): response = model.invoke( model=model_name, - credentials={ - 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') - }, + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}, prompt_messages=[ - UserPromptMessage( - content='output json data with format `{"data": "test", "code": 200, "msg": "success"}' - ) + UserPromptMessage(content='output json data with format `{"data": "test", "code": 200, "msg": "success"}') ], model_parameters={ - 'temperature': 0.5, - 'max_tokens': 50, - 'response_format': 'JSON', + "temperature": 0.5, + "max_tokens": 50, + "response_format": "JSON", }, stream=True, - user="abc-123" + user="abc-123", ) print("=====================================") print(response) @@ -81,4 +77,4 @@ def is_json(s): json.loads(s) except ValueError: return False - return True \ No newline at end of file + return True diff --git a/api/tests/integration_tests/model_runtime/upstage/test_llm.py b/api/tests/integration_tests/model_runtime/upstage/test_llm.py index c35580a8b1ec00..bc7517acbe2601 100644 --- a/api/tests/integration_tests/model_runtime/upstage/test_llm.py +++ b/api/tests/integration_tests/model_runtime/upstage/test_llm.py @@ -26,151 +26,113 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_credentials_for_chat_model(setup_openai_mock): model = UpstageLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): # model name to gpt-3.5-turbo because of mocking - model.validate_credentials( - model='gpt-3.5-turbo', - credentials={ - 'upstage_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="gpt-3.5-turbo", credentials={"upstage_api_key": "invalid_key"}) model.validate_credentials( - model='solar-1-mini-chat', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - } + model="solar-1-mini-chat", credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")} ) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model(setup_openai_mock): model = UpstageLargeLanguageModel() result = model.invoke( - model='solar-1-mini-chat', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - }, + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.0, - 'top_p': 1.0, - 'presence_penalty': 0.0, - 'frequency_penalty': 0.0, - 'max_tokens': 10 + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model_with_tools(setup_openai_mock): model = UpstageLargeLanguageModel() result = model.invoke( - model='solar-1-mini-chat', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - }, + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content="what's the weather today in London?", - ) + ), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), PromptMessageTool( - name='get_stock_price', - description='Get the current stock price', + name="get_stock_price", + description="Get the current stock price", parameters={ "type": "object", - "properties": { - "symbol": { - "type": "string", - "description": "The stock symbol" - } - }, - "required": [ - "symbol" - ] - } - ) + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), ], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert isinstance(result.message, AssistantPromptMessage) assert len(result.message.tool_calls) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_stream_chat_model(setup_openai_mock): model = UpstageLargeLanguageModel() result = model.invoke( - model='solar-1-mini-chat', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - }, + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -189,57 +151,36 @@ def test_get_num_tokens(): model = UpstageLargeLanguageModel() num_tokens = model.get_num_tokens( - model='solar-1-mini-chat', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ] + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert num_tokens == 13 num_tokens = model.get_num_tokens( - model='solar-1-mini-chat', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - }, + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), - ] + ], ) assert num_tokens == 106 diff --git a/api/tests/integration_tests/model_runtime/upstage/test_provider.py b/api/tests/integration_tests/model_runtime/upstage/test_provider.py index c33eef49b2a79e..9d83779aa00a49 100644 --- a/api/tests/integration_tests/model_runtime/upstage/test_provider.py +++ b/api/tests/integration_tests/model_runtime/upstage/test_provider.py @@ -7,17 +7,11 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_provider_credentials(setup_openai_mock): provider = UpstageProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/upstage/test_text_embedding.py b/api/tests/integration_tests/model_runtime/upstage/test_text_embedding.py index 54135a0e748d40..8c83172fa3ff7e 100644 --- a/api/tests/integration_tests/model_runtime/upstage/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/upstage/test_text_embedding.py @@ -8,41 +8,31 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) def test_validate_credentials(setup_openai_mock): model = UpstageTextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='solar-embedding-1-large-passage', - credentials={ - 'upstage_api_key': 'invalid_key' - } + model="solar-embedding-1-large-passage", credentials={"upstage_api_key": "invalid_key"} ) model.validate_credentials( - model='solar-embedding-1-large-passage', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - } + model="solar-embedding-1-large-passage", credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")} ) -@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) def test_invoke_model(setup_openai_mock): model = UpstageTextEmbeddingModel() result = model.invoke( - model='solar-embedding-1-large-passage', + model="solar-embedding-1-large-passage", credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY'), + "upstage_api_key": os.environ.get("UPSTAGE_API_KEY"), }, - texts=[ - "hello", - "world", - " ".join(["long_text"] * 100), - " ".join(["another_long_text"] * 100) - ], - user="abc-123" + texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -54,14 +44,11 @@ def test_get_num_tokens(): model = UpstageTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='solar-embedding-1-large-passage', + model="solar-embedding-1-large-passage", credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY'), + "upstage_api_key": os.environ.get("UPSTAGE_API_KEY"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 5 diff --git a/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py b/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py index 3b399d604ec910..f831c063a42630 100644 --- a/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py +++ b/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py @@ -14,26 +14,26 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': 'INVALID', - 'volc_secret_access_key': 'INVALID', - 'endpoint_id': 'INVALID', - 'base_model_name': 'Doubao-embedding', - } + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": "INVALID", + "volc_secret_access_key": "INVALID", + "endpoint_id": "INVALID", + "base_model_name": "Doubao-embedding", + }, ) model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'), - 'base_model_name': 'Doubao-embedding', + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"), + "base_model_name": "Doubao-embedding", }, ) @@ -42,20 +42,17 @@ def test_invoke_model(): model = VolcengineMaaSTextEmbeddingModel() result = model.invoke( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'), - 'base_model_name': 'Doubao-embedding', + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"), + "base_model_name": "Doubao-embedding", }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -67,19 +64,16 @@ def test_get_num_tokens(): model = VolcengineMaaSTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'), - 'base_model_name': 'Doubao-embedding', + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"), + "base_model_name": "Doubao-embedding", }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py b/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py index 63835d0263ead0..8ff9c414046e7d 100644 --- a/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py +++ b/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py @@ -14,25 +14,25 @@ def test_validate_credentials_for_chat_model(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': 'INVALID', - 'volc_secret_access_key': 'INVALID', - 'endpoint_id': 'INVALID', - } + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": "INVALID", + "volc_secret_access_key": "INVALID", + "endpoint_id": "INVALID", + }, ) model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'), - } + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"), + }, ) @@ -40,28 +40,24 @@ def test_invoke_model(): model = VolcengineMaaSLargeLanguageModel() response = model.invoke( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'), - 'base_model_name': 'Skylark2-pro-4k', + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"), + "base_model_name": "Skylark2-pro-4k", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) @@ -73,28 +69,24 @@ def test_invoke_stream_model(): model = VolcengineMaaSLargeLanguageModel() response = model.invoke( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'), - 'base_model_name': 'Skylark2-pro-4k', + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"), + "base_model_name": "Skylark2-pro-4k", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -102,29 +94,24 @@ def test_invoke_stream_model(): assert isinstance(chunk, LLMResultChunk) assert isinstance(chunk.delta, LLMResultChunkDelta) assert isinstance(chunk.delta.message, AssistantPromptMessage) - assert len( - chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True def test_get_num_tokens(): model = VolcengineMaaSLargeLanguageModel() response = model.get_num_tokens( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'), - 'base_model_name': 'Skylark2-pro-4k', + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"), + "base_model_name": "Skylark2-pro-4k", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - tools=[] + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], ) assert isinstance(response, int) diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py b/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py new file mode 100644 index 00000000000000..ac38340aecf7d2 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py @@ -0,0 +1,69 @@ +import os +from time import sleep + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import WenxinTextEmbeddingModel + + +def test_invoke_embedding_v1(): + sleep(3) + model = WenxinTextEmbeddingModel() + + response = model.invoke( + model="embedding-v1", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + texts=["hello", "你好", "xxxxx"], + user="abc-123", + ) + + assert isinstance(response, TextEmbeddingResult) + assert len(response.embeddings) == 3 + assert isinstance(response.embeddings[0], list) + + +def test_invoke_embedding_bge_large_en(): + sleep(3) + model = WenxinTextEmbeddingModel() + + response = model.invoke( + model="bge-large-en", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + texts=["hello", "你好", "xxxxx"], + user="abc-123", + ) + + assert isinstance(response, TextEmbeddingResult) + assert len(response.embeddings) == 3 + assert isinstance(response.embeddings[0], list) + + +def test_invoke_embedding_bge_large_zh(): + sleep(3) + model = WenxinTextEmbeddingModel() + + response = model.invoke( + model="bge-large-zh", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + texts=["hello", "你好", "xxxxx"], + user="abc-123", + ) + + assert isinstance(response, TextEmbeddingResult) + assert len(response.embeddings) == 3 + assert isinstance(response.embeddings[0], list) + + +def test_invoke_embedding_tao_8k(): + sleep(3) + model = WenxinTextEmbeddingModel() + + response = model.invoke( + model="tao-8k", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + texts=["hello", "你好", "xxxxx"], + user="abc-123", + ) + + assert isinstance(response, TextEmbeddingResult) + assert len(response.embeddings) == 3 + assert isinstance(response.embeddings[0], list) diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_llm.py b/api/tests/integration_tests/model_runtime/wenxin/test_llm.py index 164e8253d966ae..e2e58f15e025d8 100644 --- a/api/tests/integration_tests/model_runtime/wenxin/test_llm.py +++ b/api/tests/integration_tests/model_runtime/wenxin/test_llm.py @@ -17,161 +17,125 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) + def test_validate_credentials_for_chat_model(): sleep(3) model = ErnieBotLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='ernie-bot', - credentials={ - 'api_key': 'invalid_key', - 'secret_key': 'invalid_key' - } + model="ernie-bot", credentials={"api_key": "invalid_key", "secret_key": "invalid_key"} ) model.validate_credentials( - model='ernie-bot', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - } + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, ) + def test_invoke_model_ernie_bot(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-bot', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_model_ernie_bot_turbo(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-bot-turbo', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + model="ernie-bot-turbo", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_model_ernie_8k(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-bot-8k', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + model="ernie-bot-8k", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_model_ernie_bot_4(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-bot-4', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + model="ernie-bot-4", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_stream_model(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-3.5-8k', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + model="ernie-3.5-8k", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -181,63 +145,48 @@ def test_invoke_stream_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_invoke_model_with_system(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-bot', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - SystemPromptMessage( - content='你是Kasumi' - ), - UserPromptMessage( - content='你是谁?' - ) - ], + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[SystemPromptMessage(content="你是Kasumi"), UserPromptMessage(content="你是谁?")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) - assert 'kasumi' in response.message.content.lower() + assert "kasumi" in response.message.content.lower() + def test_invoke_with_search(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-bot', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='北京今天的天气怎么样' - ) - ], + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'disable_search': True, + "temperature": 0.7, + "top_p": 1.0, + "disable_search": True, }, stop=[], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) - total_message = '' + total_message = "" for chunk in response: assert isinstance(chunk, LLMResultChunk) assert isinstance(chunk.delta, LLMResultChunkDelta) @@ -247,25 +196,19 @@ def test_invoke_with_search(): assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True # there should be 对不起、我不能、不支持…… - assert ('不' in total_message or '抱歉' in total_message or '无法' in total_message) + assert "不" in total_message or "抱歉" in total_message or "无法" in total_message + def test_get_num_tokens(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.get_num_tokens( - model='ernie-bot', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - tools=[] + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], ) assert isinstance(response, int) - assert response == 10 \ No newline at end of file + assert response == 10 diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_provider.py b/api/tests/integration_tests/model_runtime/wenxin/test_provider.py index 8922aa18681087..337c3d2a8010dd 100644 --- a/api/tests/integration_tests/model_runtime/wenxin/test_provider.py +++ b/api/tests/integration_tests/model_runtime/wenxin/test_provider.py @@ -10,16 +10,8 @@ def test_validate_provider_credentials(): provider = WenxinProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={ - 'api_key': 'hahahaha', - 'secret_key': 'hahahaha' - } - ) + provider.validate_provider_credentials(credentials={"api_key": "hahahaha", "secret_key": "hahahaha"}) provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - } + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")} ) diff --git a/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py b/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py index f0a5151f3dbac0..8e778d005a4bc3 100644 --- a/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py +++ b/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py @@ -8,61 +8,57 @@ from tests.integration_tests.model_runtime.__mock.xinference import MOCK, setup_xinference_mock -@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True) def test_validate_credentials(setup_xinference_mock): model = XinferenceTextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='bge-base-en', + model="bge-base-en", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': 'www ' + os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": "www " + os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"), + }, ) model.validate_credentials( - model='bge-base-en', + model="bge-base-en", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"), + }, ) -@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True) def test_invoke_model(setup_xinference_mock): model = XinferenceTextEmbeddingModel() result = model.invoke( - model='bge-base-en', + model="bge-base-en", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 assert result.usage.total_tokens > 0 + def test_get_num_tokens(): model = XinferenceTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='bge-base-en', + model="bge-base-en", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/xinference/test_llm.py b/api/tests/integration_tests/model_runtime/xinference/test_llm.py index 47730406de94b8..48d1ae323d6ab6 100644 --- a/api/tests/integration_tests/model_runtime/xinference/test_llm.py +++ b/api/tests/integration_tests/model_runtime/xinference/test_llm.py @@ -20,92 +20,84 @@ from tests.integration_tests.model_runtime.__mock.xinference import setup_xinference_mock -@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True) def test_validate_credentials_for_chat_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='ChatGLM3', + model="ChatGLM3", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': 'www ' + os.environ.get('XINFERENCE_CHAT_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": "www " + os.environ.get("XINFERENCE_CHAT_MODEL_UID"), + }, ) with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='aaaaa', - credentials={ - 'server_url': '', - 'model_uid': '' - } - ) + model.validate_credentials(model="aaaaa", credentials={"server_url": "", "model_uid": ""}) model.validate_credentials( - model='ChatGLM3', + model="ChatGLM3", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"), + }, ) -@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True) def test_invoke_chat_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() response = model.invoke( - model='ChatGLM3', + model="ChatGLM3", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 -@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True) def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() response = model.invoke( - model='ChatGLM3', + model="ChatGLM3", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -114,6 +106,8 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock): assert isinstance(chunk.delta, LLMResultChunkDelta) assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + """ Funtion calling of xinference does not support stream mode currently """ @@ -168,7 +162,7 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock): # ) # assert isinstance(response, Generator) - + # call: LLMResultChunk = None # chunks = [] @@ -241,86 +235,75 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock): # assert response.usage.total_tokens > 0 # assert response.message.tool_calls[0].function.name == 'get_current_weather' -@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True) def test_validate_credentials_for_generation_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='alapaca', + model="alapaca", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': 'www ' + os.environ.get('XINFERENCE_GENERATION_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": "www " + os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), + }, ) with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='alapaca', - credentials={ - 'server_url': '', - 'model_uid': '' - } - ) + model.validate_credentials(model="alapaca", credentials={"server_url": "", "model_uid": ""}) model.validate_credentials( - model='alapaca', + model="alapaca", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), + }, ) -@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True) def test_invoke_generation_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() response = model.invoke( - model='alapaca', + model="alapaca", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), }, - prompt_messages=[ - UserPromptMessage( - content='the United States is' - ) - ], + prompt_messages=[UserPromptMessage(content="the United States is")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 -@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True) def test_invoke_stream_generation_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() response = model.invoke( - model='alapaca', + model="alapaca", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), }, - prompt_messages=[ - UserPromptMessage( - content='the United States is' - ) - ], + prompt_messages=[UserPromptMessage(content="the United States is")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -330,68 +313,54 @@ def test_invoke_stream_generation_model(setup_openai_mock, setup_xinference_mock assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_get_num_tokens(): model = XinferenceAILargeLanguageModel() num_tokens = model.get_num_tokens( - model='ChatGLM3', + model="ChatGLM3", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) - ] + ], ) assert isinstance(num_tokens, int) assert num_tokens == 77 num_tokens = model.get_num_tokens( - model='ChatGLM3', + model="ChatGLM3", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], ) assert isinstance(num_tokens, int) - assert num_tokens == 21 \ No newline at end of file + assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/xinference/test_rerank.py b/api/tests/integration_tests/model_runtime/xinference/test_rerank.py index 9012c16a7e6d3b..71ac4eef7c22be 100644 --- a/api/tests/integration_tests/model_runtime/xinference/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/xinference/test_rerank.py @@ -8,44 +8,42 @@ from tests.integration_tests.model_runtime.__mock.xinference import MOCK, setup_xinference_mock -@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True) def test_validate_credentials(setup_xinference_mock): model = XinferenceRerankModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='bge-reranker-base', - credentials={ - 'server_url': 'awdawdaw', - 'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID') - } + model="bge-reranker-base", + credentials={"server_url": "awdawdaw", "model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID")}, ) model.validate_credentials( - model='bge-reranker-base', + model="bge-reranker-base", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID"), + }, ) -@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True) def test_invoke_model(setup_xinference_mock): model = XinferenceRerankModel() result = model.invoke( - model='bge-reranker-base', + model="bge-reranker-base", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID"), }, query="Who is Kasumi?", docs=[ - "Kasumi is a girl's name of Japanese origin meaning \"mist\".", + 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", - "and she leads a team named PopiParty." + "and she leads a team named PopiParty.", ], - score_threshold=0.8 + score_threshold=0.8, ) assert isinstance(result, RerankResult) diff --git a/api/tests/integration_tests/model_runtime/zhinao/__init__.py b/api/tests/integration_tests/model_runtime/zhinao/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/integration_tests/model_runtime/zhinao/test_llm.py b/api/tests/integration_tests/model_runtime/zhinao/test_llm.py new file mode 100644 index 00000000000000..4ca1b864764818 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/zhinao/test_llm.py @@ -0,0 +1,73 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.zhinao.llm.llm import ZhinaoLargeLanguageModel + + +def test_validate_credentials(): + model = ZhinaoLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="360gpt2-pro", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="360gpt2-pro", credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}) + + +def test_invoke_model(): + model = ZhinaoLargeLanguageModel() + + response = model.invoke( + model="360gpt2-pro", + credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = ZhinaoLargeLanguageModel() + + response = model.invoke( + model="360gpt2-pro", + credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_get_num_tokens(): + model = ZhinaoLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="360gpt2-pro", + credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/zhinao/test_provider.py b/api/tests/integration_tests/model_runtime/zhinao/test_provider.py new file mode 100644 index 00000000000000..c22f797919597c --- /dev/null +++ b/api/tests/integration_tests/model_runtime/zhinao/test_provider.py @@ -0,0 +1,15 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.zhinao.zhinao import ZhinaoProvider + + +def test_validate_provider_credentials(): + provider = ZhinaoProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={}) + + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py b/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py index 0f92b50cb0f350..20380513eaa789 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py @@ -18,41 +18,22 @@ def test_validate_credentials(): model = ZhipuAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='chatglm_turbo', - credentials={ - 'api_key': 'invalid_key' - } - ) - - model.validate_credentials( - model='chatglm_turbo', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - } - ) + model.validate_credentials(model="chatglm_turbo", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="chatglm_turbo", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}) def test_invoke_model(): model = ZhipuAILargeLanguageModel() response = model.invoke( - model='chatglm_turbo', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 0.9, - 'top_p': 0.7 - }, - stop=['How'], + model="chatglm_turbo", + credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 0.9, "top_p": 0.7}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -63,21 +44,12 @@ def test_invoke_stream_model(): model = ZhipuAILargeLanguageModel() response = model.invoke( - model='chatglm_turbo', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.9, - 'top_p': 0.7 - }, + model="chatglm_turbo", + credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.9, "top_p": 0.7}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -93,63 +65,45 @@ def test_get_num_tokens(): model = ZhipuAILargeLanguageModel() num_tokens = model.get_num_tokens( - model='chatglm_turbo', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - }, + model="chatglm_turbo", + credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 14 + def test_get_tools_num_tokens(): model = ZhipuAILargeLanguageModel() num_tokens = model.get_num_tokens( - model='tools', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - }, + model="tools", + credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) ], prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 88 diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py b/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py index 51b9cccf2ea752..cb5bc0b20aafc1 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py @@ -10,12 +10,6 @@ def test_validate_provider_credentials(): provider = ZhipuaiProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py index 7308c5729669c8..9c97c91ecbdd94 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py @@ -11,34 +11,19 @@ def test_validate_credentials(): model = ZhipuAITextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='text_embedding', - credentials={ - 'api_key': 'invalid_key' - } - ) - - model.validate_credentials( - model='text_embedding', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - } - ) + model.validate_credentials(model="text_embedding", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="text_embedding", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}) def test_invoke_model(): model = ZhipuAITextEmbeddingModel() result = model.invoke( - model='text_embedding', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - }, - texts=[ - "hello", - "world" - ], - user="abc-123" + model="text_embedding", + credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -50,14 +35,7 @@ def test_get_num_tokens(): model = ZhipuAITextEmbeddingModel() num_tokens = model.get_num_tokens( - model='text_embedding', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - }, - texts=[ - "hello", - "world" - ] + model="text_embedding", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, texts=["hello", "world"] ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/tools/__mock/http.py b/api/tests/integration_tests/tools/__mock/http.py index 41bb3daeb5e531..4dfc530010fa93 100644 --- a/api/tests/integration_tests/tools/__mock/http.py +++ b/api/tests/integration_tests/tools/__mock/http.py @@ -7,20 +7,17 @@ class MockedHttp: - def httpx_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'], - url: str, **kwargs) -> httpx.Response: + def httpx_request( + method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs + ) -> httpx.Response: """ Mocked httpx.request """ request = httpx.Request( - method, - url, - params=kwargs.get('params'), - headers=kwargs.get('headers'), - cookies=kwargs.get('cookies') + method, url, params=kwargs.get("params"), headers=kwargs.get("headers"), cookies=kwargs.get("cookies") ) - data = kwargs.get('data', None) - resp = json.dumps(data).encode('utf-8') if data else b'OK' + data = kwargs.get("data", None) + resp = json.dumps(data).encode("utf-8") if data else b"OK" response = httpx.Response( status_code=200, request=request, diff --git a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py index ba14d365c5d2d6..83f4d70ce9ac2f 100644 --- a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py +++ b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py @@ -10,6 +10,7 @@ "user1": ["Go for a run", "Read a book"], } + class TodosResource(Resource): def get(self, username): todos = todos_data.get(username, []) @@ -32,7 +33,8 @@ def delete(self, username): return {"error": "Invalid todo index"}, 400 -api.add_resource(TodosResource, '/todos/') -if __name__ == '__main__': +api.add_resource(TodosResource, "/todos/") + +if __name__ == "__main__": app.run(port=5003, debug=True) diff --git a/api/tests/integration_tests/tools/api_tool/test_api_tool.py b/api/tests/integration_tests/tools/api_tool/test_api_tool.py index f6e7b153dde7f6..09729a961eff33 100644 --- a/api/tests/integration_tests/tools/api_tool/test_api_tool.py +++ b/api/tests/integration_tests/tools/api_tool/test_api_tool.py @@ -3,37 +3,40 @@ from tests.integration_tests.tools.__mock.http import setup_http_mock tool_bundle = { - 'server_url': 'http://www.example.com/{path_param}', - 'method': 'post', - 'author': '', - 'openapi': {'parameters': [{'in': 'path', 'name': 'path_param'}, - {'in': 'query', 'name': 'query_param'}, - {'in': 'cookie', 'name': 'cookie_param'}, - {'in': 'header', 'name': 'header_param'}, - ], - 'requestBody': { - 'content': {'application/json': {'schema': {'properties': {'body_param': {'type': 'string'}}}}}} - }, - 'parameters': [] + "server_url": "http://www.example.com/{path_param}", + "method": "post", + "author": "", + "openapi": { + "parameters": [ + {"in": "path", "name": "path_param"}, + {"in": "query", "name": "query_param"}, + {"in": "cookie", "name": "cookie_param"}, + {"in": "header", "name": "header_param"}, + ], + "requestBody": { + "content": {"application/json": {"schema": {"properties": {"body_param": {"type": "string"}}}}} + }, + }, + "parameters": [], } parameters = { - 'path_param': 'p_param', - 'query_param': 'q_param', - 'cookie_param': 'c_param', - 'header_param': 'h_param', - 'body_param': 'b_param', + "path_param": "p_param", + "query_param": "q_param", + "cookie_param": "c_param", + "header_param": "h_param", + "body_param": "b_param", } def test_api_tool(setup_http_mock): - tool = ApiTool(api_bundle=tool_bundle, runtime=Tool.Runtime(credentials={'auth_type': 'none'})) + tool = ApiTool(api_bundle=tool_bundle, runtime=Tool.Runtime(credentials={"auth_type": "none"})) headers = tool.assembling_request(parameters) response = tool.do_http_request(tool.api_bundle.server_url, tool.api_bundle.method, headers, parameters) assert response.status_code == 200 - assert '/p_param' == response.request.url.path - assert b'query_param=q_param' == response.request.url.query - assert 'h_param' == response.request.headers.get('header_param') - assert 'application/json' == response.request.headers.get('content-type') - assert 'cookie_param=c_param' == response.request.headers.get('cookie') - assert 'b_param' in response.content.decode() + assert "/p_param" == response.request.url.path + assert b"query_param=q_param" == response.request.url.query + assert "h_param" == response.request.headers.get("header_param") + assert "application/json" == response.request.headers.get("content-type") + assert "cookie_param=c_param" == response.request.headers.get("cookie") + assert "b_param" in response.content.decode() diff --git a/api/tests/integration_tests/tools/test_all_provider.py b/api/tests/integration_tests/tools/test_all_provider.py index 2811bc816dadbe..2dfce749b3e16f 100644 --- a/api/tests/integration_tests/tools/test_all_provider.py +++ b/api/tests/integration_tests/tools/test_all_provider.py @@ -7,16 +7,17 @@ ToolManager.clear_builtin_providers_cache() provider_generator = ToolManager.list_builtin_providers() -@pytest.mark.parametrize('name', provider_names) + +@pytest.mark.parametrize("name", provider_names) def test_tool_providers(benchmark, name): """ Test that all tool providers can be loaded """ - + def test(generator): try: return next(generator) except StopIteration: return None - - benchmark.pedantic(test, args=(provider_generator,), iterations=1, rounds=1) \ No newline at end of file + + benchmark.pedantic(test, args=(provider_generator,), iterations=1, rounds=1) diff --git a/api/tests/integration_tests/utils/parent_class.py b/api/tests/integration_tests/utils/parent_class.py index 39fc95256e512b..6a6de1cc41aaf2 100644 --- a/api/tests/integration_tests/utils/parent_class.py +++ b/api/tests/integration_tests/utils/parent_class.py @@ -3,4 +3,4 @@ def __init__(self, name): self.name = name def get_name(self): - return self.name \ No newline at end of file + return self.name diff --git a/api/tests/integration_tests/utils/test_module_import_helper.py b/api/tests/integration_tests/utils/test_module_import_helper.py index 256c9a911f104f..7d32f5ae66f5df 100644 --- a/api/tests/integration_tests/utils/test_module_import_helper.py +++ b/api/tests/integration_tests/utils/test_module_import_helper.py @@ -7,26 +7,26 @@ def test_loading_subclass_from_source(): current_path = os.getcwd() module = load_single_subclass_from_source( - module_name='ChildClass', - script_path=os.path.join(current_path, 'child_class.py'), - parent_type=ParentClass) - assert module and module.__name__ == 'ChildClass' + module_name="ChildClass", script_path=os.path.join(current_path, "child_class.py"), parent_type=ParentClass + ) + assert module and module.__name__ == "ChildClass" def test_load_import_module_from_source(): current_path = os.getcwd() module = import_module_from_source( - module_name='ChildClass', - py_file_path=os.path.join(current_path, 'child_class.py')) - assert module and module.__name__ == 'ChildClass' + module_name="ChildClass", py_file_path=os.path.join(current_path, "child_class.py") + ) + assert module and module.__name__ == "ChildClass" def test_lazy_loading_subclass_from_source(): current_path = os.getcwd() clz = load_single_subclass_from_source( - module_name='LazyLoadChildClass', - script_path=os.path.join(current_path, 'lazy_load_class.py'), + module_name="LazyLoadChildClass", + script_path=os.path.join(current_path, "lazy_load_class.py"), parent_type=ParentClass, - use_lazy_loader=True) - instance = clz('dify') - assert instance.get_name() == 'dify' + use_lazy_loader=True, + ) + instance = clz("dify") + assert instance.get_name() == "dify" diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index f8165cba9468aa..571c1e3d440508 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -13,11 +13,15 @@ class MockTcvectordbClass: - - def VectorDBClient(self, url=None, username='', key='', - read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY, - timeout=5, - adapter: HTTPAdapter = None): + def VectorDBClient( + self, + url=None, + username="", + key="", + read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY, + timeout=5, + adapter: HTTPAdapter = None, + ): self._conn = None self._read_consistency = read_consistency @@ -26,105 +30,96 @@ def list_databases(self) -> list[Database]: Database( conn=self._conn, read_consistency=self._read_consistency, - name='dify', - )] + name="dify", + ) + ] def list_collections(self, timeout: Optional[float] = None) -> list[Collection]: return [] def drop_collection(self, name: str, timeout: Optional[float] = None): - return { - "code": 0, - "msg": "operation success" - } + return {"code": 0, "msg": "operation success"} def create_collection( - self, - name: str, - shard: int, - replicas: int, - description: str, - index: Index, - embedding: Embedding = None, - timeout: float = None, + self, + name: str, + shard: int, + replicas: int, + description: str, + index: Index, + embedding: Embedding = None, + timeout: float = None, ) -> Collection: - return Collection(self, name, shard, replicas, description, index, embedding=embedding, - read_consistency=self._read_consistency, timeout=timeout) - - def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection: - collection = Collection( + return Collection( self, name, - shard=1, - replicas=2, - description=name, - timeout=timeout + shard, + replicas, + description, + index, + embedding=embedding, + read_consistency=self._read_consistency, + timeout=timeout, ) + + def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection: + collection = Collection(self, name, shard=1, replicas=2, description=name, timeout=timeout) return collection def collection_upsert( - self, - documents: list[Document], - timeout: Optional[float] = None, - build_index: bool = True, - **kwargs + self, documents: list[Document], timeout: Optional[float] = None, build_index: bool = True, **kwargs ): - return { - "code": 0, - "msg": "operation success" - } + return {"code": 0, "msg": "operation success"} def collection_search( - self, - vectors: list[list[float]], - filter: Filter = None, - params=None, - retrieve_vector: bool = False, - limit: int = 10, - output_fields: Optional[list[str]] = None, - timeout: Optional[float] = None, + self, + vectors: list[list[float]], + filter: Filter = None, + params=None, + retrieve_vector: bool = False, + limit: int = 10, + output_fields: Optional[list[str]] = None, + timeout: Optional[float] = None, ) -> list[list[dict]]: - return [[{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}]] + return [[{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]] def collection_query( - self, - document_ids: Optional[list] = None, - retrieve_vector: bool = False, - limit: Optional[int] = None, - offset: Optional[int] = None, - filter: Optional[Filter] = None, - output_fields: Optional[list[str]] = None, - timeout: Optional[float] = None, + self, + document_ids: Optional[list] = None, + retrieve_vector: bool = False, + limit: Optional[int] = None, + offset: Optional[int] = None, + filter: Optional[Filter] = None, + output_fields: Optional[list[str]] = None, + timeout: Optional[float] = None, ) -> list[dict]: - return [{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}] + return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}] def collection_delete( - self, - document_ids: list[str] = None, - filter: Filter = None, - timeout: float = None, + self, + document_ids: list[str] = None, + filter: Filter = None, + timeout: float = None, ): - return { - "code": 0, - "msg": "operation success" - } + return {"code": 0, "msg": "operation success"} + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" -MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' @pytest.fixture def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch): if MOCK: - monkeypatch.setattr(VectorDBClient, '__init__', MockTcvectordbClass.VectorDBClient) - monkeypatch.setattr(VectorDBClient, 'list_databases', MockTcvectordbClass.list_databases) - monkeypatch.setattr(Database, 'collection', MockTcvectordbClass.describe_collection) - monkeypatch.setattr(Database, 'list_collections', MockTcvectordbClass.list_collections) - monkeypatch.setattr(Database, 'drop_collection', MockTcvectordbClass.drop_collection) - monkeypatch.setattr(Database, 'create_collection', MockTcvectordbClass.create_collection) - monkeypatch.setattr(Collection, 'upsert', MockTcvectordbClass.collection_upsert) - monkeypatch.setattr(Collection, 'search', MockTcvectordbClass.collection_search) - monkeypatch.setattr(Collection, 'query', MockTcvectordbClass.collection_query) - monkeypatch.setattr(Collection, 'delete', MockTcvectordbClass.collection_delete) + monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.VectorDBClient) + monkeypatch.setattr(VectorDBClient, "list_databases", MockTcvectordbClass.list_databases) + monkeypatch.setattr(Database, "collection", MockTcvectordbClass.describe_collection) + monkeypatch.setattr(Database, "list_collections", MockTcvectordbClass.list_collections) + monkeypatch.setattr(Database, "drop_collection", MockTcvectordbClass.drop_collection) + monkeypatch.setattr(Database, "create_collection", MockTcvectordbClass.create_collection) + monkeypatch.setattr(Collection, "upsert", MockTcvectordbClass.collection_upsert) + monkeypatch.setattr(Collection, "search", MockTcvectordbClass.collection_search) + monkeypatch.setattr(Collection, "query", MockTcvectordbClass.collection_query) + monkeypatch.setattr(Collection, "delete", MockTcvectordbClass.collection_delete) yield diff --git a/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py b/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py index d6067af73b70bd..970b98edc3d83b 100644 --- a/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py +++ b/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py @@ -26,6 +26,7 @@ def __init__(self): def run_all_tests(self): self.vector.delete() return super().run_all_tests() - + + def test_chroma_vector(setup_mock_redis): - AnalyticdbVectorTest().run_all_tests() \ No newline at end of file + AnalyticdbVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/chroma/test_chroma.py b/api/tests/integration_tests/vdb/chroma/test_chroma.py index 033f9a54da678c..ac7b5cbda45b23 100644 --- a/api/tests/integration_tests/vdb/chroma/test_chroma.py +++ b/api/tests/integration_tests/vdb/chroma/test_chroma.py @@ -14,13 +14,13 @@ def __init__(self): self.vector = ChromaVector( collection_name=self.collection_name, config=ChromaConfig( - host='localhost', + host="localhost", port=8000, tenant=chromadb.DEFAULT_TENANT, database=chromadb.DEFAULT_DATABASE, auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider", auth_credentials="difyai123456", - ) + ), ) def search_by_full_text(self): diff --git a/api/tests/integration_tests/vdb/elasticsearch/__init__.py b/api/tests/integration_tests/vdb/elasticsearch/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py b/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py new file mode 100644 index 00000000000000..2a0c1bb0389187 --- /dev/null +++ b/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py @@ -0,0 +1,20 @@ +from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchConfig, ElasticSearchVector +from tests.integration_tests.vdb.test_vector_store import ( + AbstractVectorTest, + setup_mock_redis, +) + + +class ElasticSearchVectorTest(AbstractVectorTest): + def __init__(self): + super().__init__() + self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"] + self.vector = ElasticSearchVector( + index_name=self.collection_name.lower(), + config=ElasticSearchConfig(host="http://localhost", port="9200", username="elastic", password="elastic"), + attributes=self.attributes, + ) + + +def test_elasticsearch_vector(setup_mock_redis): + ElasticSearchVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/milvus/test_milvus.py b/api/tests/integration_tests/vdb/milvus/test_milvus.py index 9c0917ef307d0f..7b5f19ea629f3c 100644 --- a/api/tests/integration_tests/vdb/milvus/test_milvus.py +++ b/api/tests/integration_tests/vdb/milvus/test_milvus.py @@ -12,11 +12,11 @@ def __init__(self): self.vector = MilvusVector( collection_name=self.collection_name, config=MilvusConfig( - host='localhost', + host="localhost", port=19530, - user='root', - password='Milvus', - ) + user="root", + password="Milvus", + ), ) def search_by_full_text(self): @@ -25,7 +25,7 @@ def search_by_full_text(self): assert len(hits_by_full_text) == 0 def get_ids_by_metadata_field(self): - ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 diff --git a/api/tests/integration_tests/vdb/myscale/test_myscale.py b/api/tests/integration_tests/vdb/myscale/test_myscale.py index b6260d549ac99a..55b2fde4276105 100644 --- a/api/tests/integration_tests/vdb/myscale/test_myscale.py +++ b/api/tests/integration_tests/vdb/myscale/test_myscale.py @@ -21,7 +21,7 @@ def __init__(self): ) def get_ids_by_metadata_field(self): - ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 diff --git a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py index ea1e05da9007f4..a99b81d41eba78 100644 --- a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py +++ b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py @@ -29,54 +29,55 @@ def setup_method(self): self.example_doc_id = "example_doc_id" self.vector = OpenSearchVector( collection_name=self.collection_name, - config=OpenSearchConfig( - host='localhost', - port=9200, - user='admin', - password='password', - secure=False - ) + config=OpenSearchConfig(host="localhost", port=9200, user="admin", password="password", secure=False), ) self.vector._client = MagicMock() - @pytest.mark.parametrize("search_response, expected_length, expected_doc_id", [ - ({ - 'hits': { - 'total': {'value': 1}, - 'hits': [ - {'_source': {'page_content': get_example_text(), 'metadata': {"document_id": "example_doc_id"}}} - ] - } - }, 1, "example_doc_id"), - ({ - 'hits': { - 'total': {'value': 0}, - 'hits': [] - } - }, 0, None) - ]) + @pytest.mark.parametrize( + "search_response, expected_length, expected_doc_id", + [ + ( + { + "hits": { + "total": {"value": 1}, + "hits": [ + { + "_source": { + "page_content": get_example_text(), + "metadata": {"document_id": "example_doc_id"}, + } + } + ], + } + }, + 1, + "example_doc_id", + ), + ({"hits": {"total": {"value": 0}, "hits": []}}, 0, None), + ], + ) def test_search_by_full_text(self, search_response, expected_length, expected_doc_id): self.vector._client.search.return_value = search_response hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) assert len(hits_by_full_text) == expected_length if expected_length > 0: - assert hits_by_full_text[0].metadata['document_id'] == expected_doc_id + assert hits_by_full_text[0].metadata["document_id"] == expected_doc_id def test_search_by_vector(self): vector = [0.1] * 128 mock_response = { - 'hits': { - 'total': {'value': 1}, - 'hits': [ + "hits": { + "total": {"value": 1}, + "hits": [ { - '_source': { + "_source": { Field.CONTENT_KEY.value: get_example_text(), - Field.METADATA_KEY.value: {"document_id": self.example_doc_id} + Field.METADATA_KEY.value: {"document_id": self.example_doc_id}, }, - '_score': 1.0 + "_score": 1.0, } - ] + ], } } self.vector._client.search.return_value = mock_response @@ -85,53 +86,45 @@ def test_search_by_vector(self): print("Hits by vector:", hits_by_vector) print("Expected document ID:", self.example_doc_id) - print("Actual document ID:", hits_by_vector[0].metadata['document_id'] if hits_by_vector else "No hits") + print("Actual document ID:", hits_by_vector[0].metadata["document_id"] if hits_by_vector else "No hits") assert len(hits_by_vector) > 0, f"Expected at least one hit, got {len(hits_by_vector)}" - assert hits_by_vector[0].metadata['document_id'] == self.example_doc_id, \ - f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}" + assert ( + hits_by_vector[0].metadata["document_id"] == self.example_doc_id + ), f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}" def test_get_ids_by_metadata_field(self): - mock_response = { - 'hits': { - 'total': {'value': 1}, - 'hits': [{'_id': 'mock_id'}] - } - } + mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}} self.vector._client.search.return_value = mock_response doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id}) embedding = [0.1] * 128 - with patch('opensearchpy.helpers.bulk') as mock_bulk: + with patch("opensearchpy.helpers.bulk") as mock_bulk: mock_bulk.return_value = ([], []) self.vector.add_texts([doc], [embedding]) - ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 - assert ids[0] == 'mock_id' + assert ids[0] == "mock_id" def test_add_texts(self): - self.vector._client.index.return_value = {'result': 'created'} + self.vector._client.index.return_value = {"result": "created"} doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id}) embedding = [0.1] * 128 - with patch('opensearchpy.helpers.bulk') as mock_bulk: + with patch("opensearchpy.helpers.bulk") as mock_bulk: mock_bulk.return_value = ([], []) self.vector.add_texts([doc], [embedding]) - mock_response = { - 'hits': { - 'total': {'value': 1}, - 'hits': [{'_id': 'mock_id'}] - } - } + mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}} self.vector._client.search.return_value = mock_response - ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 - assert ids[0] == 'mock_id' + assert ids[0] == "mock_id" + @pytest.mark.usefixtures("setup_mock_redis") class TestOpenSearchVectorWithRedis: @@ -141,11 +134,11 @@ def setup_method(self): def test_search_by_full_text(self): self.tester.setup_method() search_response = { - 'hits': { - 'total': {'value': 1}, - 'hits': [ - {'_source': {'page_content': get_example_text(), 'metadata': {"document_id": "example_doc_id"}}} - ] + "hits": { + "total": {"value": 1}, + "hits": [ + {"_source": {"page_content": get_example_text(), "metadata": {"document_id": "example_doc_id"}}} + ], } } expected_length = 1 diff --git a/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py b/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py index e6ce8aab3db173..6b33217d157ea7 100644 --- a/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py +++ b/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py @@ -12,13 +12,13 @@ def __init__(self): self.vector = PGVectoRS( collection_name=self.collection_name.lower(), config=PgvectoRSConfig( - host='localhost', + host="localhost", port=5431, - user='postgres', - password='difyai123456', - database='dify', + user="postgres", + password="difyai123456", + database="dify", ), - dim=128 + dim=128, ) def search_by_full_text(self): @@ -27,8 +27,9 @@ def search_by_full_text(self): assert len(hits_by_full_text) == 0 def get_ids_by_metadata_field(self): - ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 + def test_pgvecot_rs(setup_mock_redis): PGVectoRSVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/pgvector/test_pgvector.py b/api/tests/integration_tests/vdb/pgvector/test_pgvector.py index 851599c7cec6b9..c5a986b7479259 100644 --- a/api/tests/integration_tests/vdb/pgvector/test_pgvector.py +++ b/api/tests/integration_tests/vdb/pgvector/test_pgvector.py @@ -21,10 +21,6 @@ def __init__(self): ), ) - def search_by_full_text(self): - hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text()) - assert len(hits_by_full_text) == 0 - def test_pgvector(setup_mock_redis): PGVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py index 34beb25d450618..61d9a9e712aade 100644 --- a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py +++ b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py @@ -8,14 +8,14 @@ class QdrantVectorTest(AbstractVectorTest): def __init__(self): super().__init__() - self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] + self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"] self.vector = QdrantVector( collection_name=self.collection_name, group_id=self.dataset_id, config=QdrantConfig( - endpoint='http://localhost:6333', - api_key='difyai123456', - ) + endpoint="http://localhost:6333", + api_key="difyai123456", + ), ) diff --git a/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py b/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py index 8937fe0ea16c50..1b9466e27f4c8f 100644 --- a/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py +++ b/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py @@ -7,18 +7,22 @@ mock_client = MagicMock() mock_client.list_databases.return_value = [{"name": "test"}] + class TencentVectorTest(AbstractVectorTest): def __init__(self): super().__init__() - self.vector = TencentVector("dify", TencentConfig( - url="http://127.0.0.1", - api_key="dify", - timeout=30, - username="dify", - database="dify", - shard=1, - replicas=2, - )) + self.vector = TencentVector( + "dify", + TencentConfig( + url="http://127.0.0.1", + api_key="dify", + timeout=30, + username="dify", + database="dify", + shard=1, + replicas=2, + ), + ) def search_by_vector(self): hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding) @@ -28,8 +32,6 @@ def search_by_full_text(self): hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) assert len(hits_by_full_text) == 0 -def test_tencent_vector(setup_mock_redis,setup_tcvectordb_mock): - TencentVectorTest().run_all_tests() - - +def test_tencent_vector(setup_mock_redis, setup_tcvectordb_mock): + TencentVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/test_vector_store.py b/api/tests/integration_tests/vdb/test_vector_store.py index cb35822709fa1a..a11cd225b3ba37 100644 --- a/api/tests/integration_tests/vdb/test_vector_store.py +++ b/api/tests/integration_tests/vdb/test_vector_store.py @@ -10,7 +10,7 @@ def get_example_text() -> str: - return 'test_text' + return "test_text" def get_example_document(doc_id: str) -> Document: @@ -21,7 +21,7 @@ def get_example_document(doc_id: str) -> Document: "doc_hash": doc_id, "document_id": doc_id, "dataset_id": doc_id, - } + }, ) return doc @@ -45,7 +45,7 @@ class AbstractVectorTest: def __init__(self): self.vector = None self.dataset_id = str(uuid.uuid4()) - self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + '_test' + self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + "_test" self.example_doc_id = str(uuid.uuid4()) self.example_embedding = [1.001 * i for i in range(128)] @@ -58,12 +58,12 @@ def create_vector(self) -> None: def search_by_vector(self): hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding) assert len(hits_by_vector) == 1 - assert hits_by_vector[0].metadata['doc_id'] == self.example_doc_id + assert hits_by_vector[0].metadata["doc_id"] == self.example_doc_id def search_by_full_text(self): hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text()) assert len(hits_by_full_text) == 1 - assert hits_by_full_text[0].metadata['doc_id'] == self.example_doc_id + assert hits_by_full_text[0].metadata["doc_id"] == self.example_doc_id def delete_vector(self): self.vector.delete() @@ -76,14 +76,14 @@ def add_texts(self) -> list[str]: documents = [get_example_document(doc_id=str(uuid.uuid4())) for _ in range(batch_size)] embeddings = [self.example_embedding] * batch_size self.vector.add_texts(documents=documents, embeddings=embeddings) - return [doc.metadata['doc_id'] for doc in documents] + return [doc.metadata["doc_id"] for doc in documents] def text_exists(self): assert self.vector.text_exists(self.example_doc_id) def get_ids_by_metadata_field(self): with pytest.raises(NotImplementedError): - self.vector.get_ids_by_metadata_field(key='key', value='value') + self.vector.get_ids_by_metadata_field(key="key", value="value") def run_all_tests(self): self.create_vector() diff --git a/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py b/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py index 18e00dbeddbd77..2a5320c7d5e752 100644 --- a/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py +++ b/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py @@ -10,15 +10,15 @@ @pytest.fixture def tidb_vector(): return TiDBVector( - collection_name='test_collection', + collection_name="test_collection", config=TiDBVectorConfig( host="xxx.eu-central-1.xxx.aws.tidbcloud.com", port="4000", user="xxx.root", password="xxxxxx", database="dify", - program_name="langgenius/dify" - ) + program_name="langgenius/dify", + ), ) @@ -40,7 +40,7 @@ def search_by_full_text(self): assert len(hits_by_full_text) == 0 def get_ids_by_metadata_field(self): - ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 0 @@ -50,12 +50,12 @@ def test_tidb_vector(setup_mock_redis, setup_tidbvector_mock, tidb_vector, mock_ @pytest.fixture def mock_session(): - with patch('core.rag.datasource.vdb.tidb_vector.tidb_vector.Session', new_callable=MagicMock) as mock_session: + with patch("core.rag.datasource.vdb.tidb_vector.tidb_vector.Session", new_callable=MagicMock) as mock_session: yield mock_session @pytest.fixture def setup_tidbvector_mock(tidb_vector, mock_session): - with patch('core.rag.datasource.vdb.tidb_vector.tidb_vector.create_engine'): - with patch.object(tidb_vector._engine, 'connect'): + with patch("core.rag.datasource.vdb.tidb_vector.tidb_vector.create_engine"): + with patch.object(tidb_vector._engine, "connect"): yield tidb_vector diff --git a/api/tests/integration_tests/vdb/weaviate/test_weaviate.py b/api/tests/integration_tests/vdb/weaviate/test_weaviate.py index 3d540cee32bfe2..a6f55420d312ee 100644 --- a/api/tests/integration_tests/vdb/weaviate/test_weaviate.py +++ b/api/tests/integration_tests/vdb/weaviate/test_weaviate.py @@ -8,14 +8,14 @@ class WeaviateVectorTest(AbstractVectorTest): def __init__(self): super().__init__() - self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] + self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"] self.vector = WeaviateVector( collection_name=self.collection_name, config=WeaviateConfig( - endpoint='http://localhost:8080', - api_key='WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih', + endpoint="http://localhost:8080", + api_key="WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih", ), - attributes=self.attributes + attributes=self.attributes, ) diff --git a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py index 13f992136e8fa8..6fb8c86b82a132 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py @@ -6,24 +6,22 @@ from jinja2 import Template from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage -from core.helper.code_executor.entities import CodeDependency -MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false") == "true" + class MockedCodeExecutor: @classmethod - def invoke(cls, language: Literal['python3', 'javascript', 'jinja2'], - code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict: + def invoke(cls, language: Literal["python3", "javascript", "jinja2"], code: str, inputs: dict) -> dict: # invoke directly match language: case CodeLanguage.PYTHON3: - return { - "result": 3 - } + return {"result": 3} case CodeLanguage.JINJA2: - return { - "result": Template(code).render(inputs) - } + return {"result": Template(code).render(inputs)} + case _: + raise Exception("Language not supported") + @pytest.fixture def setup_code_executor_mock(request, monkeypatch: MonkeyPatch): diff --git a/api/tests/integration_tests/workflow/nodes/__mock/http.py b/api/tests/integration_tests/workflow/nodes/__mock/http.py index beb5c040097b65..cfc47bcad4675e 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/http.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/http.py @@ -6,38 +6,32 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false") == "true" class MockedHttp: - def httpx_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'], - url: str, **kwargs) -> httpx.Response: + def httpx_request( + method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs + ) -> httpx.Response: """ Mocked httpx.request """ - if url == 'http://404.com': - response = httpx.Response( - status_code=404, - request=httpx.Request(method, url), - content=b'Not Found' - ) + if url == "http://404.com": + response = httpx.Response(status_code=404, request=httpx.Request(method, url), content=b"Not Found") return response # get data, files - data = kwargs.get('data', None) - files = kwargs.get('files', None) + data = kwargs.get("data", None) + files = kwargs.get("files", None) if data is not None: - resp = dumps(data).encode('utf-8') + resp = dumps(data).encode("utf-8") elif files is not None: - resp = dumps(files).encode('utf-8') + resp = dumps(files).encode("utf-8") else: - resp = b'OK' + resp = b"OK" response = httpx.Response( - status_code=200, - request=httpx.Request(method, url), - headers=kwargs.get('headers', {}), - content=resp + status_code=200, request=httpx.Request(method, url), headers=kwargs.get("headers", {}), content=resp ) return response diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py index ae6e7ceaa71295..44dcf9a10fa2d7 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py @@ -2,10 +2,10 @@ from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor -CODE_LANGUAGE = 'unsupported_language' +CODE_LANGUAGE = "unsupported_language" def test_unsupported_with_code_template(): with pytest.raises(CodeExecutionException) as e: - CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code='', inputs={}) - assert str(e.value) == f'Unsupported language {CODE_LANGUAGE}' + CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={}) + assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}" diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py index 2d798eb9c2b361..09fcb68cf032d4 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py @@ -9,8 +9,8 @@ def test_javascript_plain(): code = 'console.log("Hello World")' - result_message = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload='', code=code) - assert result_message == 'Hello World\n' + result_message = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) + assert result_message == "Hello World\n" def test_javascript_json(): @@ -18,22 +18,17 @@ def test_javascript_json(): obj = {'Hello': 'World'} console.log(JSON.stringify(obj)) """) - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload='', code=code) + result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) assert result == '{"Hello":"World"}\n' def test_javascript_with_code_template(): result = CodeExecutor.execute_workflow_code_template( - language=CODE_LANGUAGE, code=JavascriptCodeProvider.get_default_code(), - inputs={'arg1': 'Hello', 'arg2': 'World'}) - assert result == {'result': 'HelloWorld'} - - -def test_javascript_list_default_available_packages(): - packages = JavascriptCodeProvider.get_default_available_packages() - - # no default packages available for javascript - assert len(packages) == 0 + language=CODE_LANGUAGE, + code=JavascriptCodeProvider.get_default_code(), + inputs={"arg1": "Hello", "arg2": "World"}, + ) + assert result == {"result": "HelloWorld"} def test_javascript_get_runner_script(): diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py index 425f4cbdd4b9be..94903cf79688e5 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py @@ -7,21 +7,24 @@ def test_jinja2(): - template = 'Hello {{template}}' - inputs = base64.b64encode(b'{"template": "World"}').decode('utf-8') - code = (Jinja2TemplateTransformer.get_runner_script() - .replace(Jinja2TemplateTransformer._code_placeholder, template) - .replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)) - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, - preload=Jinja2TemplateTransformer.get_preload_script(), - code=code) - assert result == '<>Hello World<>\n' + template = "Hello {{template}}" + inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8") + code = ( + Jinja2TemplateTransformer.get_runner_script() + .replace(Jinja2TemplateTransformer._code_placeholder, template) + .replace(Jinja2TemplateTransformer._inputs_placeholder, inputs) + ) + result = CodeExecutor.execute_code( + language=CODE_LANGUAGE, preload=Jinja2TemplateTransformer.get_preload_script(), code=code + ) + assert result == "<>Hello World<>\n" def test_jinja2_with_code_template(): result = CodeExecutor.execute_workflow_code_template( - language=CODE_LANGUAGE, code='Hello {{template}}', inputs={'template': 'World'}) - assert result == {'result': 'Hello World'} + language=CODE_LANGUAGE, code="Hello {{template}}", inputs={"template": "World"} + ) + assert result == {"result": "Hello World"} def test_jinja2_get_runner_script(): diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py index d265011d4c29f6..cbe4a5d335a7c8 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py @@ -10,8 +10,8 @@ def test_python3_plain(): code = 'print("Hello World")' - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload='', code=code) - assert result == 'Hello World\n' + result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) + assert result == "Hello World\n" def test_python3_json(): @@ -19,23 +19,15 @@ def test_python3_json(): import json print(json.dumps({'Hello': 'World'})) """) - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload='', code=code) + result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) assert result == '{"Hello": "World"}\n' def test_python3_with_code_template(): result = CodeExecutor.execute_workflow_code_template( - language=CODE_LANGUAGE, code=Python3CodeProvider.get_default_code(), inputs={'arg1': 'Hello', 'arg2': 'World'}) - assert result == {'result': 'HelloWorld'} - - -def test_python3_list_default_available_packages(): - packages = Python3CodeProvider.get_default_available_packages() - assert len(packages) > 0 - assert {'requests', 'httpx'}.issubset(p['name'] for p in packages) - - # check JSON serializable - assert len(str(json.dumps(packages))) > 0 + language=CODE_LANGUAGE, code=Python3CodeProvider.get_default_code(), inputs={"arg1": "Hello", "arg2": "World"} + ) + assert result == {"result": "HelloWorld"} def test_python3_get_runner_script(): diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 5c952585208d7b..6f5421e108f0c7 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -9,137 +9,134 @@ from models.workflow import WorkflowNodeExecutionStatus from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock -CODE_MAX_STRING_LENGTH = int(getenv('CODE_MAX_STRING_LENGTH', '10000')) +CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000")) -@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) def test_execute_code(setup_code_executor_mock): - code = ''' + code = """ def main(args1: int, args2: int) -> dict: return { "result": args1 + args2, } - ''' + """ # trim first 4 spaces at the beginning of each line - code = '\n'.join([line[4:] for line in code.split('\n')]) + code = "\n".join([line[4:] for line in code.split("\n")]) node = CodeNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, config={ - 'id': '1', - 'data': { - 'outputs': { - 'result': { - 'type': 'number', + "id": "1", + "data": { + "outputs": { + "result": { + "type": "number", }, }, - 'title': '123', - 'variables': [ + "title": "123", + "variables": [ { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], + "variable": "args1", + "value_selector": ["1", "123", "args1"], }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] - } + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, ], - 'answer': '123', - 'code_language': 'python3', - 'code': code - } - } + "answer": "123", + "code_language": "python3", + "code": code, + }, + }, ) # construct variable pool pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(['1', '123', 'args1'], 1) - pool.add(['1', '123', 'args2'], 2) - + pool.add(["1", "123", "args1"], 1) + pool.add(["1", "123", "args2"], 2) + # execute node result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs['result'] == 3 + assert result.outputs["result"] == 3 assert result.error is None -@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) def test_execute_code_output_validator(setup_code_executor_mock): - code = ''' + code = """ def main(args1: int, args2: int) -> dict: return { "result": args1 + args2, } - ''' + """ # trim first 4 spaces at the beginning of each line - code = '\n'.join([line[4:] for line in code.split('\n')]) + code = "\n".join([line[4:] for line in code.split("\n")]) node = CodeNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, config={ - 'id': '1', - 'data': { + "id": "1", + "data": { "outputs": { "result": { "type": "string", }, }, - 'title': '123', - 'variables': [ + "title": "123", + "variables": [ { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], + "variable": "args1", + "value_selector": ["1", "123", "args1"], }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] - } + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, ], - 'answer': '123', - 'code_language': 'python3', - 'code': code - } - } + "answer": "123", + "code_language": "python3", + "code": code, + }, + }, ) # construct variable pool pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(['1', '123', 'args1'], 1) - pool.add(['1', '123', 'args2'], 2) - + pool.add(["1", "123", "args1"], 1) + pool.add(["1", "123", "args2"], 2) + # execute node result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.FAILED - assert result.error == 'Output variable `result` must be a string' + assert result.error == "Output variable `result` must be a string" + def test_execute_code_output_validator_depth(): - code = ''' + code = """ def main(args1: int, args2: int) -> dict: return { "result": { "result": args1 + args2, } } - ''' + """ # trim first 4 spaces at the beginning of each line - code = '\n'.join([line[4:] for line in code.split('\n')]) + code = "\n".join([line[4:] for line in code.split("\n")]) node = CodeNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, config={ - 'id': '1', - 'data': { + "id": "1", + "data": { "outputs": { "string_validator": { "type": "string", @@ -168,29 +165,26 @@ def main(args1: int, args2: int) -> dict: "depth": { "type": "number", } - } + }, } - } - } - } + }, + }, + }, }, }, - 'title': '123', - 'variables': [ + "title": "123", + "variables": [ { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], + "variable": "args1", + "value_selector": ["1", "123", "args1"], }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] - } + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, ], - 'answer': '123', - 'code_language': 'python3', - 'code': code - } - } + "answer": "123", + "code_language": "python3", + "code": code, + }, + }, ) # construct result @@ -199,14 +193,7 @@ def main(args1: int, args2: int) -> dict: "string_validator": "1", "number_array_validator": [1, 2, 3, 3.333], "string_array_validator": ["1", "2", "3"], - "object_validator": { - "result": 1, - "depth": { - "depth": { - "depth": 1 - } - } - } + "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, } # validate @@ -218,14 +205,7 @@ def main(args1: int, args2: int) -> dict: "string_validator": 1, "number_array_validator": ["1", "2", "3", "3.333"], "string_array_validator": [1, 2, 3], - "object_validator": { - "result": "1", - "depth": { - "depth": { - "depth": "1" - } - } - } + "object_validator": {"result": "1", "depth": {"depth": {"depth": "1"}}}, } # validate @@ -238,34 +218,20 @@ def main(args1: int, args2: int) -> dict: "string_validator": (CODE_MAX_STRING_LENGTH + 1) * "1", "number_array_validator": [1, 2, 3, 3.333], "string_array_validator": ["1", "2", "3"], - "object_validator": { - "result": 1, - "depth": { - "depth": { - "depth": 1 - } - } - } + "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, } # validate with pytest.raises(ValueError): node._transform_result(result, node.node_data.outputs) - + # construct result result = { "number_validator": 1, "string_validator": "1", "number_array_validator": [1, 2, 3, 3.333] * 2000, "string_array_validator": ["1", "2", "3"], - "object_validator": { - "result": 1, - "depth": { - "depth": { - "depth": 1 - } - } - } + "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, } # validate @@ -274,58 +240,59 @@ def main(args1: int, args2: int) -> dict: def test_execute_code_output_object_list(): - code = ''' + code = """ def main(args1: int, args2: int) -> dict: return { "result": { "result": args1 + args2, } } - ''' + """ # trim first 4 spaces at the beginning of each line - code = '\n'.join([line[4:] for line in code.split('\n')]) + code = "\n".join([line[4:] for line in code.split("\n")]) node = CodeNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': '1', - 'data': { + "id": "1", + "data": { "outputs": { "object_list": { "type": "array[object]", }, }, - 'title': '123', - 'variables': [ + "title": "123", + "variables": [ { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], + "variable": "args1", + "value_selector": ["1", "123", "args1"], }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] - } + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, ], - 'answer': '123', - 'code_language': 'python3', - 'code': code - } - } + "answer": "123", + "code_language": "python3", + "code": code, + }, + }, ) # construct result result = { - "object_list": [{ - "result": 1, - }, { - "result": 2, - }, { - "result": [1, 2, 3], - }] + "object_list": [ + { + "result": 1, + }, + { + "result": 2, + }, + { + "result": [1, 2, 3], + }, + ] } # validate @@ -333,13 +300,18 @@ def main(args1: int, args2: int) -> dict: # construct result result = { - "object_list": [{ - "result": 1, - }, { - "result": 2, - }, { - "result": [1, 2, 3], - }, 1] + "object_list": [ + { + "result": 1, + }, + { + "result": 2, + }, + { + "result": [1, 2, 3], + }, + 1, + ] } # validate diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index a1354bd6a5f220..acb616b3256ac3 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -9,322 +9,337 @@ from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock BASIC_NODE_DATA = { - 'tenant_id': '1', - 'app_id': '1', - 'workflow_id': '1', - 'user_id': '1', - 'user_from': UserFrom.ACCOUNT, - 'invoke_from': InvokeFrom.WEB_APP, + "tenant_id": "1", + "app_id": "1", + "workflow_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.WEB_APP, } # construct variable pool pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) -pool.add(['a', 'b123', 'args1'], 1) -pool.add(['a', 'b123', 'args2'], 2) +pool.add(["a", "b123", "args1"], 1) +pool.add(["a", "b123", "args2"], 2) -@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_get(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'get', - 'url': 'http://example.com', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'basic', - 'api_key': 'ak-xxx', - 'header': 'api-key', - } + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": None, }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': None, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") - assert '?A=b' in data - assert 'X-Header: 123' in data + assert "?A=b" in data + assert "X-Header: 123" in data -@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_no_auth(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'get', - 'url': 'http://example.com', - 'authorization': { - 'type': 'no-auth', - 'config': None, + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": { + "type": "no-auth", + "config": None, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": None, }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': None, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") - assert '?A=b' in data - assert 'X-Header: 123' in data + assert "?A=b" in data + assert "X-Header: 123" in data -@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_authorization_header(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'get', - 'url': 'http://example.com', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'custom', - 'api_key': 'Auth', - 'header': 'X-Auth', + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "custom", + "api_key": "Auth", + "header": "X-Auth", + }, }, + "headers": "X-Header:123", + "params": "A:b", + "body": None, }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': None, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") - assert '?A=b' in data - assert 'X-Header: 123' in data + assert "?A=b" in data + assert "X-Header: 123" in data -@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_template(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'get', - 'url': 'http://example.com/{{#a.b123.args2#}}', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'basic', - 'api_key': 'ak-xxx', - 'header': 'api-key', - } + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com/{{#a.b123.args2#}}", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123\nX-Header2:{{#a.b123.args2#}}", + "params": "A:b\nTemplate:{{#a.b123.args2#}}", + "body": None, }, - 'headers': 'X-Header:123\nX-Header2:{{#a.b123.args2#}}', - 'params': 'A:b\nTemplate:{{#a.b123.args2#}}', - 'body': None, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") - assert '?A=b' in data - assert 'Template=2' in data - assert 'X-Header: 123' in data - assert 'X-Header2: 2' in data + assert "?A=b" in data + assert "Template=2" in data + assert "X-Header: 123" in data + assert "X-Header2: 2" in data -@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_json(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'post', - 'url': 'http://example.com', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'basic', - 'api_key': 'ak-xxx', - 'header': 'api-key', - } - }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': { - 'type': 'json', - 'data': '{"a": "{{#a.b123.args1#}}"}' + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "post", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": {"type": "json", "data": '{"a": "{{#a.b123.args1#}}"}'}, }, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") assert '{"a": "1"}' in data - assert 'X-Header: 123' in data + assert "X-Header: 123" in data def test_x_www_form_urlencoded(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'post', - 'url': 'http://example.com', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'basic', - 'api_key': 'ak-xxx', - 'header': 'api-key', - } - }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': { - 'type': 'x-www-form-urlencoded', - 'data': 'a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}' + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "post", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": {"type": "x-www-form-urlencoded", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"}, }, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") - assert 'a=1&b=2' in data - assert 'X-Header: 123' in data + assert "a=1&b=2" in data + assert "X-Header: 123" in data def test_form_data(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'post', - 'url': 'http://example.com', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'basic', - 'api_key': 'ak-xxx', - 'header': 'api-key', - } - }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': { - 'type': 'form-data', - 'data': 'a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}' + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "post", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": {"type": "form-data", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"}, }, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") assert 'form-data; name="a"' in data - assert '1' in data + assert "1" in data assert 'form-data; name="b"' in data - assert '2' in data - assert 'X-Header: 123' in data + assert "2" in data + assert "X-Header: 123" in data def test_none_data(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'post', - 'url': 'http://example.com', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'basic', - 'api_key': 'ak-xxx', - 'header': 'api-key', - } - }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': { - 'type': 'none', - 'data': '123123123' + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "post", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": {"type": "none", "data": "123123123"}, }, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") - assert 'X-Header: 123' in data - assert '123123123' not in data + assert "X-Header: 123" in data + assert "123123123" not in data def test_mock_404(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'get', - 'url': 'http://404.com', - 'authorization': { - 'type': 'no-auth', - 'config': None, + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://404.com", + "authorization": { + "type": "no-auth", + "config": None, + }, + "body": None, + "params": "", + "headers": "X-Header:123", }, - 'body': None, - 'params': '', - 'headers': 'X-Header:123', - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) resp = result.outputs - assert 404 == resp.get('status_code') - assert 'Not Found' in resp.get('body') + assert 404 == resp.get("status_code") + assert "Not Found" in resp.get("body") def test_multi_colons_parse(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'get', - 'url': 'http://example.com', - 'authorization': { - 'type': 'no-auth', - 'config': None, - }, - 'params': 'Referer:http://example1.com\nRedirect:http://example2.com', - 'headers': 'Referer:http://example3.com\nRedirect:http://example4.com', - 'body': { - 'type': 'form-data', - 'data': 'Referer:http://example5.com\nRedirect:http://example6.com' + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": { + "type": "no-auth", + "config": None, + }, + "params": "Referer:http://example1.com\nRedirect:http://example2.com", + "headers": "Referer:http://example3.com\nRedirect:http://example4.com", + "body": {"type": "form-data", "data": "Referer:http://example5.com\nRedirect:http://example6.com"}, }, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) resp = result.outputs - assert urlencode({'Redirect': 'http://example2.com'}) in result.process_data.get('request') - assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get('request') - assert 'http://example3.com' == resp.get('headers').get('referer') + assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request") + assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get("request") + assert "http://example3.com" == resp.get("headers").get("referer") diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index ac704e4eaf54df..6bab83a0191bca 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -10,8 +10,8 @@ from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers import ModelProviderFactory -from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.llm.llm_node import LLMNode from extensions.ext_database import db @@ -23,90 +23,71 @@ from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_execute_llm(setup_openai_mock): node = LLMNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'llm', - 'model': { - 'provider': 'openai', - 'name': 'gpt-3.5-turbo', - 'mode': 'chat', - 'completion_params': {} - }, - 'prompt_template': [ - { - 'role': 'system', - 'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}.' - }, - { - 'role': 'user', - 'text': '{{#sys.query#}}' - } + "id": "llm", + "data": { + "title": "123", + "type": "llm", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "prompt_template": [ + {"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."}, + {"role": "user", "text": "{{#sys.query#}}"}, ], - 'memory': None, - 'context': { - 'enabled': False - }, - 'vision': { - 'enabled': False - } - } - } + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, + }, + }, ) # construct variable pool - pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather today?', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) - pool.add(['abc', 'output'], 'sunny') - - credentials = { - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } - - provider_instance = ModelProviderFactory().get_provider_instance('openai') + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather today?", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["abc", "output"], "sunny") + + credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} + + provider_instance = ModelProviderFactory().get_provider_instance("openai") model_type_instance = provider_instance.get_model_instance(ModelType.LLM) provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( - tenant_id='1', + tenant_id="1", provider=provider_instance.get_provider_schema(), preferred_provider_type=ProviderType.CUSTOM, using_provider_type=ProviderType.CUSTOM, - system_configuration=SystemConfiguration( - enabled=False - ), - custom_configuration=CustomConfiguration( - provider=CustomProviderConfiguration( - credentials=credentials - ) - ), - model_settings=[] + system_configuration=SystemConfiguration(enabled=False), + custom_configuration=CustomConfiguration(provider=CustomProviderConfiguration(credentials=credentials)), + model_settings=[], ), provider_instance=provider_instance, - model_type_instance=model_type_instance + model_type_instance=model_type_instance, ) - model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo') + model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo") model_config = ModelConfigWithCredentialsEntity( - model='gpt-3.5-turbo', - provider='openai', - mode='chat', + model="gpt-3.5-turbo", + provider="openai", + mode="chat", credentials=credentials, parameters={}, - model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'), - provider_model_bundle=provider_model_bundle + model_schema=model_type_instance.get_model_schema("gpt-3.5-turbo"), + provider_model_bundle=provider_model_bundle, ) # Mock db.session.close() @@ -118,112 +99,97 @@ def test_execute_llm(setup_openai_mock): result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs['text'] is not None - assert result.outputs['usage']['total_tokens'] > 0 + assert result.outputs["text"] is not None + assert result.outputs["usage"]["total_tokens"] > 0 + -@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): """ Test execute LLM node with jinja2 """ node = LLMNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'llm', - 'model': { - 'provider': 'openai', - 'name': 'gpt-3.5-turbo', - 'mode': 'chat', - 'completion_params': {} + "id": "llm", + "data": { + "title": "123", + "type": "llm", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "prompt_config": { + "jinja2_variables": [ + {"variable": "sys_query", "value_selector": ["sys", "query"]}, + {"variable": "output", "value_selector": ["abc", "output"]}, + ] }, - 'prompt_config': { - 'jinja2_variables': [{ - 'variable': 'sys_query', - 'value_selector': ['sys', 'query'] - }, { - 'variable': 'output', - 'value_selector': ['abc', 'output'] - }] - }, - 'prompt_template': [ + "prompt_template": [ { - 'role': 'system', - 'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}', - 'jinja2_text': 'you are a helpful assistant.\ntoday\'s weather is {{output}}.', - 'edition_type': 'jinja2' + "role": "system", + "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}", + "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.", + "edition_type": "jinja2", }, { - 'role': 'user', - 'text': '{{#sys.query#}}', - 'jinja2_text': '{{sys_query}}', - 'edition_type': 'basic' - } + "role": "user", + "text": "{{#sys.query#}}", + "jinja2_text": "{{sys_query}}", + "edition_type": "basic", + }, ], - 'memory': None, - 'context': { - 'enabled': False - }, - 'vision': { - 'enabled': False - } - } - } + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, + }, + }, ) # construct variable pool - pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather today?', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) - pool.add(['abc', 'output'], 'sunny') - - credentials = { - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } - - provider_instance = ModelProviderFactory().get_provider_instance('openai') + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather today?", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["abc", "output"], "sunny") + + credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} + + provider_instance = ModelProviderFactory().get_provider_instance("openai") model_type_instance = provider_instance.get_model_instance(ModelType.LLM) provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( - tenant_id='1', + tenant_id="1", provider=provider_instance.get_provider_schema(), preferred_provider_type=ProviderType.CUSTOM, using_provider_type=ProviderType.CUSTOM, - system_configuration=SystemConfiguration( - enabled=False - ), - custom_configuration=CustomConfiguration( - provider=CustomProviderConfiguration( - credentials=credentials - ) - ), - model_settings=[] + system_configuration=SystemConfiguration(enabled=False), + custom_configuration=CustomConfiguration(provider=CustomProviderConfiguration(credentials=credentials)), + model_settings=[], ), provider_instance=provider_instance, model_type_instance=model_type_instance, ) - model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo') + model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo") model_config = ModelConfigWithCredentialsEntity( - model='gpt-3.5-turbo', - provider='openai', - mode='chat', + model="gpt-3.5-turbo", + provider="openai", + mode="chat", credentials=credentials, parameters={}, - model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'), - provider_model_bundle=provider_model_bundle + model_schema=model_type_instance.get_model_schema("gpt-3.5-turbo"), + provider_model_bundle=provider_model_bundle, ) # Mock db.session.close() @@ -235,5 +201,5 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert 'sunny' in json.dumps(result.process_data) - assert 'what\'s the weather today?' in json.dumps(result.process_data) \ No newline at end of file + assert "sunny" in json.dumps(result.process_data) + assert "what's the weather today?" in json.dumps(result.process_data) diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 312ad47026beb5..ca2bae5c536090 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -12,8 +12,8 @@ from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from extensions.ext_database import db @@ -26,29 +26,25 @@ def get_mocked_fetch_model_config( - provider: str, model: str, mode: str, + provider: str, + model: str, + mode: str, credentials: dict, ): provider_instance = ModelProviderFactory().get_provider_instance(provider) model_type_instance = provider_instance.get_model_instance(ModelType.LLM) provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( - tenant_id='1', + tenant_id="1", provider=provider_instance.get_provider_schema(), preferred_provider_type=ProviderType.CUSTOM, using_provider_type=ProviderType.CUSTOM, - system_configuration=SystemConfiguration( - enabled=False - ), - custom_configuration=CustomConfiguration( - provider=CustomProviderConfiguration( - credentials=credentials - ) - ), - model_settings=[] + system_configuration=SystemConfiguration(enabled=False), + custom_configuration=CustomConfiguration(provider=CustomProviderConfiguration(credentials=credentials)), + model_settings=[], ), provider_instance=provider_instance, - model_type_instance=model_type_instance + model_type_instance=model_type_instance, ) model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model=model) model_config = ModelConfigWithCredentialsEntity( @@ -58,268 +54,268 @@ def get_mocked_fetch_model_config( credentials=credentials, parameters={}, model_schema=model_type_instance.get_model_schema(model), - provider_model_bundle=provider_model_bundle + provider_model_bundle=provider_model_bundle, ) return MagicMock(return_value=(model_instance, model_config)) + def get_mocked_fetch_memory(memory_text: str): class MemoryMock: - def get_history_prompt_text(self, human_prefix: str = "Human", - ai_prefix: str = "Assistant", - max_token_limit: int = 2000, - message_limit: Optional[int] = None): + def get_history_prompt_text( + self, + human_prefix: str = "Human", + ai_prefix: str = "Assistant", + max_token_limit: int = 2000, + message_limit: Optional[int] = None, + ): return memory_text return MagicMock(return_value=MemoryMock()) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_function_calling_parameter_extractor(setup_openai_mock): """ Test function calling for parameter extractor. """ node = ParameterExtractorNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'parameter-extractor', - 'model': { - 'provider': 'openai', - 'name': 'gpt-3.5-turbo', - 'mode': 'chat', - 'completion_params': {} - }, - 'query': ['sys', 'query'], - 'parameters': [{ - 'name': 'location', - 'type': 'string', - 'description': 'location', - 'required': True - }], - 'instruction': '', - 'reasoning_mode': 'function_call', - 'memory': None, - } - } + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "instruction": "", + "reasoning_mode": "function_call", + "memory": None, + }, + }, ) node._fetch_model_config = get_mocked_fetch_model_config( - provider='openai', model='gpt-3.5-turbo', mode='chat', credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + provider="openai", + model="gpt-3.5-turbo", + mode="chat", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, ) db.session.close = MagicMock() # construct variable pool - pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather in SF', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs.get('location') == 'kawaii' - assert result.outputs.get('__reason') == None + assert result.outputs.get("location") == "kawaii" + assert result.outputs.get("__reason") == None -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_instructions(setup_openai_mock): """ Test chat parameter extractor. """ node = ParameterExtractorNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'parameter-extractor', - 'model': { - 'provider': 'openai', - 'name': 'gpt-3.5-turbo', - 'mode': 'chat', - 'completion_params': {} - }, - 'query': ['sys', 'query'], - 'parameters': [{ - 'name': 'location', - 'type': 'string', - 'description': 'location', - 'required': True - }], - 'reasoning_mode': 'function_call', - 'instruction': '{{#sys.query#}}', - 'memory': None, - } - } + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "function_call", + "instruction": "{{#sys.query#}}", + "memory": None, + }, + }, ) node._fetch_model_config = get_mocked_fetch_model_config( - provider='openai', model='gpt-3.5-turbo', mode='chat', credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + provider="openai", + model="gpt-3.5-turbo", + mode="chat", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, ) db.session.close = MagicMock() # construct variable pool - pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather in SF', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs.get('location') == 'kawaii' - assert result.outputs.get('__reason') == None + assert result.outputs.get("location") == "kawaii" + assert result.outputs.get("__reason") == None process_data = result.process_data - process_data.get('prompts') + process_data.get("prompts") + + for prompt in process_data.get("prompts"): + if prompt.get("role") == "system": + assert "what's the weather in SF" in prompt.get("text") - for prompt in process_data.get('prompts'): - if prompt.get('role') == 'system': - assert 'what\'s the weather in SF' in prompt.get('text') -@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) def test_chat_parameter_extractor(setup_anthropic_mock): """ Test chat parameter extractor. """ node = ParameterExtractorNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'parameter-extractor', - 'model': { - 'provider': 'anthropic', - 'name': 'claude-2', - 'mode': 'chat', - 'completion_params': {} - }, - 'query': ['sys', 'query'], - 'parameters': [{ - 'name': 'location', - 'type': 'string', - 'description': 'location', - 'required': True - }], - 'reasoning_mode': 'prompt', - 'instruction': '', - 'memory': None, - } - } + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": {"provider": "anthropic", "name": "claude-2", "mode": "chat", "completion_params": {}}, + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "prompt", + "instruction": "", + "memory": None, + }, + }, ) node._fetch_model_config = get_mocked_fetch_model_config( - provider='anthropic', model='claude-2', mode='chat', credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') - } + provider="anthropic", + model="claude-2", + mode="chat", + credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}, ) db.session.close = MagicMock() # construct variable pool - pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather in SF', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs.get('location') == '' - assert result.outputs.get('__reason') == 'Failed to extract result from function call or text response, using empty result.' - prompts = result.process_data.get('prompts') + assert result.outputs.get("location") == "" + assert ( + result.outputs.get("__reason") + == "Failed to extract result from function call or text response, using empty result." + ) + prompts = result.process_data.get("prompts") for prompt in prompts: - if prompt.get('role') == 'user': - if '' in prompt.get('text'): - assert '\n{"type": "object"' in prompt.get('text') + if prompt.get("role") == "user": + if "" in prompt.get("text"): + assert '\n{"type": "object"' in prompt.get("text") + -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_completion_parameter_extractor(setup_openai_mock): """ Test completion parameter extractor. """ node = ParameterExtractorNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'parameter-extractor', - 'model': { - 'provider': 'openai', - 'name': 'gpt-3.5-turbo-instruct', - 'mode': 'completion', - 'completion_params': {} + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": { + "provider": "openai", + "name": "gpt-3.5-turbo-instruct", + "mode": "completion", + "completion_params": {}, }, - 'query': ['sys', 'query'], - 'parameters': [{ - 'name': 'location', - 'type': 'string', - 'description': 'location', - 'required': True - }], - 'reasoning_mode': 'prompt', - 'instruction': '{{#sys.query#}}', - 'memory': None, - } - } + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "prompt", + "instruction": "{{#sys.query#}}", + "memory": None, + }, + }, ) node._fetch_model_config = get_mocked_fetch_model_config( - provider='openai', model='gpt-3.5-turbo-instruct', mode='completion', credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + provider="openai", + model="gpt-3.5-turbo-instruct", + mode="completion", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, ) db.session.close = MagicMock() # construct variable pool - pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather in SF', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs.get('location') == '' - assert result.outputs.get('__reason') == 'Failed to extract result from function call or text response, using empty result.' - assert len(result.process_data.get('prompts')) == 1 - assert 'SF' in result.process_data.get('prompts')[0].get('text') + assert result.outputs.get("location") == "" + assert ( + result.outputs.get("__reason") + == "Failed to extract result from function call or text response, using empty result." + ) + assert len(result.process_data.get("prompts")) == 1 + assert "SF" in result.process_data.get("prompts")[0].get("text") + def test_extract_json_response(): """ @@ -327,35 +323,30 @@ def test_extract_json_response(): """ node = ParameterExtractorNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'parameter-extractor', - 'model': { - 'provider': 'openai', - 'name': 'gpt-3.5-turbo-instruct', - 'mode': 'completion', - 'completion_params': {} + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": { + "provider": "openai", + "name": "gpt-3.5-turbo-instruct", + "mode": "completion", + "completion_params": {}, }, - 'query': ['sys', 'query'], - 'parameters': [{ - 'name': 'location', - 'type': 'string', - 'description': 'location', - 'required': True - }], - 'reasoning_mode': 'prompt', - 'instruction': '{{#sys.query#}}', - 'memory': None, - } - } + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "prompt", + "instruction": "{{#sys.query#}}", + "memory": None, + }, + }, ) result = node._extract_complete_json_response(""" @@ -363,86 +354,80 @@ def test_extract_json_response(): { "location": "kawaii" } - hello world. + hello world. """) - assert result['location'] == 'kawaii' + assert result["location"] == "kawaii" -@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) def test_chat_parameter_extractor_with_memory(setup_anthropic_mock): """ Test chat parameter extractor with memory. """ node = ParameterExtractorNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'parameter-extractor', - 'model': { - 'provider': 'anthropic', - 'name': 'claude-2', - 'mode': 'chat', - 'completion_params': {} - }, - 'query': ['sys', 'query'], - 'parameters': [{ - 'name': 'location', - 'type': 'string', - 'description': 'location', - 'required': True - }], - 'reasoning_mode': 'prompt', - 'instruction': '', - 'memory': { - 'window': { - 'enabled': True, - 'size': 50 - } - }, - } - } + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": {"provider": "anthropic", "name": "claude-2", "mode": "chat", "completion_params": {}}, + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "prompt", + "instruction": "", + "memory": {"window": {"enabled": True, "size": 50}}, + }, + }, ) node._fetch_model_config = get_mocked_fetch_model_config( - provider='anthropic', model='claude-2', mode='chat', credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') - } + provider="anthropic", + model="claude-2", + mode="chat", + credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}, ) - node._fetch_memory = get_mocked_fetch_memory('customized memory') + node._fetch_memory = get_mocked_fetch_memory("customized memory") db.session.close = MagicMock() # construct variable pool - pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather in SF', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs.get('location') == '' - assert result.outputs.get('__reason') == 'Failed to extract result from function call or text response, using empty result.' - prompts = result.process_data.get('prompts') + assert result.outputs.get("location") == "" + assert ( + result.outputs.get("__reason") + == "Failed to extract result from function call or text response, using empty result." + ) + prompts = result.process_data.get("prompts") latest_role = None for prompt in prompts: - if prompt.get('role') == 'user': - if '' in prompt.get('text'): - assert '\n{"type": "object"' in prompt.get('text') - elif prompt.get('role') == 'system': - assert 'customized memory' in prompt.get('text') + if prompt.get("role") == "user": + if "" in prompt.get("text"): + assert '\n{"type": "object"' in prompt.get("text") + elif prompt.get("role") == "system": + assert "customized memory" in prompt.get("text") if latest_role is not None: - assert latest_role != prompt.get('role') + assert latest_role != prompt.get("role") - if prompt.get('role') in ['user', 'assistant']: - latest_role = prompt.get('role') \ No newline at end of file + if prompt.get("role") in ["user", "assistant"]: + latest_role = prompt.get("role") diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 781dfbc50fdba3..617b6370c9f410 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -8,42 +8,39 @@ from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock -@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) def test_execute_code(setup_code_executor_mock): - code = '''{{args2}}''' + code = """{{args2}}""" node = TemplateTransformNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.END_USER, config={ - 'id': '1', - 'data': { - 'title': '123', - 'variables': [ + "id": "1", + "data": { + "title": "123", + "variables": [ { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], + "variable": "args1", + "value_selector": ["1", "123", "args1"], }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] - } + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, ], - 'template': code, - } - } + "template": code, + }, + }, ) # construct variable pool pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(['1', '123', 'args1'], 1) - pool.add(['1', '123', 'args2'], 3) - + pool.add(["1", "123", "args1"], 1) + pool.add(["1", "123", "args2"], 3) + # execute node result = node.run(pool) - + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs['output'] == '3' + assert result.outputs["output"] == "3" diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 01d62280e837b3..29c1efa8e78b2a 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -7,78 +7,79 @@ def test_tool_variable_invoke(): pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(['1', '123', 'args1'], '1+1') + pool.add(["1", "123", "args1"], "1+1") node = ToolNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': '1', - 'data': { - 'title': 'a', - 'desc': 'a', - 'provider_id': 'maths', - 'provider_type': 'builtin', - 'provider_name': 'maths', - 'tool_name': 'eval_expression', - 'tool_label': 'eval_expression', - 'tool_configurations': {}, - 'tool_parameters': { - 'expression': { - 'type': 'variable', - 'value': ['1', '123', 'args1'], + "id": "1", + "data": { + "title": "a", + "desc": "a", + "provider_id": "maths", + "provider_type": "builtin", + "provider_name": "maths", + "tool_name": "eval_expression", + "tool_label": "eval_expression", + "tool_configurations": {}, + "tool_parameters": { + "expression": { + "type": "variable", + "value": ["1", "123", "args1"], } - } - } - } + }, + }, + }, ) # execute node result = node.run(pool) - + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert '2' in result.outputs['text'] - assert result.outputs['files'] == [] + assert "2" in result.outputs["text"] + assert result.outputs["files"] == [] + def test_tool_mixed_invoke(): pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(['1', 'args1'], '1+1') + pool.add(["1", "args1"], "1+1") node = ToolNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': '1', - 'data': { - 'title': 'a', - 'desc': 'a', - 'provider_id': 'maths', - 'provider_type': 'builtin', - 'provider_name': 'maths', - 'tool_name': 'eval_expression', - 'tool_label': 'eval_expression', - 'tool_configurations': {}, - 'tool_parameters': { - 'expression': { - 'type': 'mixed', - 'value': '{{#1.args1#}}', + "id": "1", + "data": { + "title": "a", + "desc": "a", + "provider_id": "maths", + "provider_type": "builtin", + "provider_name": "maths", + "tool_name": "eval_expression", + "tool_label": "eval_expression", + "tool_configurations": {}, + "tool_parameters": { + "expression": { + "type": "mixed", + "value": "{{#1.args1#}}", } - } - } - } + }, + }, + }, ) # execute node result = node.run(pool) - + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert '2' in result.outputs['text'] - assert result.outputs['files'] == [] \ No newline at end of file + assert "2" in result.outputs["text"] + assert result.outputs["files"] == [] diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index 949a5a17693457..3f639ccacc48f5 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -3,21 +3,26 @@ import pytest from flask import Flask +from yarl import URL from configs.app_config import DifyConfig -EXAMPLE_ENV_FILENAME = '.env' +EXAMPLE_ENV_FILENAME = ".env" @pytest.fixture def example_env_file(tmp_path, monkeypatch) -> str: monkeypatch.chdir(tmp_path) file_path = tmp_path.joinpath(EXAMPLE_ENV_FILENAME) - file_path.write_text(dedent( - """ + file_path.write_text( + dedent( + """ CONSOLE_API_URL=https://example.com CONSOLE_WEB_URL=https://example.com - """)) + HTTP_REQUEST_MAX_WRITE_TIMEOUT=30 + """ + ) + ) return str(file_path) @@ -29,7 +34,7 @@ def test_dify_config_undefined_entry(example_env_file): # entries not defined in app settings with pytest.raises(TypeError): # TypeError: 'AppSettings' object is not subscriptable - assert config['LOG_LEVEL'] == 'INFO' + assert config["LOG_LEVEL"] == "INFO" def test_dify_config(example_env_file): @@ -37,47 +42,56 @@ def test_dify_config(example_env_file): config = DifyConfig(_env_file=example_env_file) # constant values - assert config.COMMIT_SHA == '' + assert config.COMMIT_SHA == "" # default values - assert config.EDITION == 'SELF_HOSTED' + assert config.EDITION == "SELF_HOSTED" assert config.API_COMPRESSION_ENABLED is False assert config.SENTRY_TRACES_SAMPLE_RATE == 1.0 + # annotated field with default value + assert config.HTTP_REQUEST_MAX_READ_TIMEOUT == 60 + + # annotated field with configured value + assert config.HTTP_REQUEST_MAX_WRITE_TIMEOUT == 30 + # NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected. # This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`. def test_flask_configs(example_env_file): - flask_app = Flask('app') + flask_app = Flask("app") # clear system environment variables os.environ.clear() flask_app.config.from_mapping(DifyConfig(_env_file=example_env_file).model_dump()) # pyright: ignore config = flask_app.config # configs read from pydantic-settings - assert config['LOG_LEVEL'] == 'INFO' - assert config['COMMIT_SHA'] == '' - assert config['EDITION'] == 'SELF_HOSTED' - assert config['API_COMPRESSION_ENABLED'] is False - assert config['SENTRY_TRACES_SAMPLE_RATE'] == 1.0 - assert config['TESTING'] == False + assert config["LOG_LEVEL"] == "INFO" + assert config["COMMIT_SHA"] == "" + assert config["EDITION"] == "SELF_HOSTED" + assert config["API_COMPRESSION_ENABLED"] is False + assert config["SENTRY_TRACES_SAMPLE_RATE"] == 1.0 + assert config["TESTING"] == False # value from env file - assert config['CONSOLE_API_URL'] == 'https://example.com' + assert config["CONSOLE_API_URL"] == "https://example.com" # fallback to alias choices value as CONSOLE_API_URL - assert config['FILES_URL'] == 'https://example.com' + assert config["FILES_URL"] == "https://example.com" - assert config['SQLALCHEMY_DATABASE_URI'] == 'postgresql://postgres:@localhost:5432/dify' - assert config['SQLALCHEMY_ENGINE_OPTIONS'] == { - 'connect_args': { - 'options': '-c timezone=UTC', + assert config["SQLALCHEMY_DATABASE_URI"] == "postgresql://postgres:@localhost:5432/dify" + assert config["SQLALCHEMY_ENGINE_OPTIONS"] == { + "connect_args": { + "options": "-c timezone=UTC", }, - 'max_overflow': 10, - 'pool_pre_ping': False, - 'pool_recycle': 3600, - 'pool_size': 30, + "max_overflow": 10, + "pool_pre_ping": False, + "pool_recycle": 3600, + "pool_size": 30, } - assert config['CONSOLE_WEB_URL']=='https://example.com' - assert config['CONSOLE_CORS_ALLOW_ORIGINS']==['https://example.com'] - assert config['WEB_API_CORS_ALLOW_ORIGINS'] == ['*'] + assert config["CONSOLE_WEB_URL"] == "https://example.com" + assert config["CONSOLE_CORS_ALLOW_ORIGINS"] == ["https://example.com"] + assert config["WEB_API_CORS_ALLOW_ORIGINS"] == ["*"] + + assert str(config["CODE_EXECUTION_ENDPOINT"]) == "http://sandbox:8194/" + assert str(URL(str(config["CODE_EXECUTION_ENDPOINT"])) / "v1") == "http://sandbox:8194/v1" diff --git a/api/tests/unit_tests/core/app/segments/test_factory.py b/api/tests/unit_tests/core/app/segments/test_factory.py index 85321ee3741dc2..0824c8e9e978e2 100644 --- a/api/tests/unit_tests/core/app/segments/test_factory.py +++ b/api/tests/unit_tests/core/app/segments/test_factory.py @@ -3,305 +3,156 @@ import pytest from core.app.segments import ( - ArrayFileVariable, ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, - FileVariable, FloatVariable, IntegerVariable, - NoneSegment, ObjectSegment, SecretVariable, StringVariable, factory, ) +from core.app.segments.exc import VariableError def test_string_variable(): - test_data = {'value_type': 'string', 'name': 'test_text', 'value': 'Hello, World!'} + test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"} result = factory.build_variable_from_mapping(test_data) assert isinstance(result, StringVariable) def test_integer_variable(): - test_data = {'value_type': 'number', 'name': 'test_int', 'value': 42} + test_data = {"value_type": "number", "name": "test_int", "value": 42} result = factory.build_variable_from_mapping(test_data) assert isinstance(result, IntegerVariable) def test_float_variable(): - test_data = {'value_type': 'number', 'name': 'test_float', 'value': 3.14} + test_data = {"value_type": "number", "name": "test_float", "value": 3.14} result = factory.build_variable_from_mapping(test_data) assert isinstance(result, FloatVariable) def test_secret_variable(): - test_data = {'value_type': 'secret', 'name': 'test_secret', 'value': 'secret_value'} + test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"} result = factory.build_variable_from_mapping(test_data) assert isinstance(result, SecretVariable) def test_invalid_value_type(): - test_data = {'value_type': 'unknown', 'name': 'test_invalid', 'value': 'value'} - with pytest.raises(ValueError): + test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"} + with pytest.raises(VariableError): factory.build_variable_from_mapping(test_data) def test_build_a_blank_string(): result = factory.build_variable_from_mapping( { - 'value_type': 'string', - 'name': 'blank', - 'value': '', + "value_type": "string", + "name": "blank", + "value": "", } ) assert isinstance(result, StringVariable) - assert result.value == '' + assert result.value == "" def test_build_a_object_variable_with_none_value(): var = factory.build_segment( { - 'key1': None, + "key1": None, } ) assert isinstance(var, ObjectSegment) - assert isinstance(var.value['key1'], NoneSegment) + assert var.value["key1"] is None def test_object_variable(): mapping = { - 'id': str(uuid4()), - 'value_type': 'object', - 'name': 'test_object', - 'description': 'Description of the variable.', - 'value': { - 'key1': { - 'id': str(uuid4()), - 'value_type': 'string', - 'name': 'text', - 'value': 'text', - 'description': 'Description of the variable.', - }, - 'key2': { - 'id': str(uuid4()), - 'value_type': 'number', - 'name': 'number', - 'value': 1, - 'description': 'Description of the variable.', - }, + "id": str(uuid4()), + "value_type": "object", + "name": "test_object", + "description": "Description of the variable.", + "value": { + "key1": "text", + "key2": 2, }, } variable = factory.build_variable_from_mapping(mapping) assert isinstance(variable, ObjectSegment) - assert isinstance(variable.value['key1'], StringVariable) - assert isinstance(variable.value['key2'], IntegerVariable) + assert isinstance(variable.value["key1"], str) + assert isinstance(variable.value["key2"], int) def test_array_string_variable(): mapping = { - 'id': str(uuid4()), - 'value_type': 'array[string]', - 'name': 'test_array', - 'description': 'Description of the variable.', - 'value': [ - { - 'id': str(uuid4()), - 'value_type': 'string', - 'name': 'text', - 'value': 'text', - 'description': 'Description of the variable.', - }, - { - 'id': str(uuid4()), - 'value_type': 'string', - 'name': 'text', - 'value': 'text', - 'description': 'Description of the variable.', - }, + "id": str(uuid4()), + "value_type": "array[string]", + "name": "test_array", + "description": "Description of the variable.", + "value": [ + "text", + "text", ], } variable = factory.build_variable_from_mapping(mapping) assert isinstance(variable, ArrayStringVariable) - assert isinstance(variable.value[0], StringVariable) - assert isinstance(variable.value[1], StringVariable) + assert isinstance(variable.value[0], str) + assert isinstance(variable.value[1], str) def test_array_number_variable(): mapping = { - 'id': str(uuid4()), - 'value_type': 'array[number]', - 'name': 'test_array', - 'description': 'Description of the variable.', - 'value': [ - { - 'id': str(uuid4()), - 'value_type': 'number', - 'name': 'number', - 'value': 1, - 'description': 'Description of the variable.', - }, - { - 'id': str(uuid4()), - 'value_type': 'number', - 'name': 'number', - 'value': 2.0, - 'description': 'Description of the variable.', - }, + "id": str(uuid4()), + "value_type": "array[number]", + "name": "test_array", + "description": "Description of the variable.", + "value": [ + 1, + 2.0, ], } variable = factory.build_variable_from_mapping(mapping) assert isinstance(variable, ArrayNumberVariable) - assert isinstance(variable.value[0], IntegerVariable) - assert isinstance(variable.value[1], FloatVariable) + assert isinstance(variable.value[0], int) + assert isinstance(variable.value[1], float) def test_array_object_variable(): mapping = { - 'id': str(uuid4()), - 'value_type': 'array[object]', - 'name': 'test_array', - 'description': 'Description of the variable.', - 'value': [ + "id": str(uuid4()), + "value_type": "array[object]", + "name": "test_array", + "description": "Description of the variable.", + "value": [ { - 'id': str(uuid4()), - 'value_type': 'object', - 'name': 'object', - 'description': 'Description of the variable.', - 'value': { - 'key1': { - 'id': str(uuid4()), - 'value_type': 'string', - 'name': 'text', - 'value': 'text', - 'description': 'Description of the variable.', - }, - 'key2': { - 'id': str(uuid4()), - 'value_type': 'number', - 'name': 'number', - 'value': 1, - 'description': 'Description of the variable.', - }, - }, + "key1": "text", + "key2": 1, }, { - 'id': str(uuid4()), - 'value_type': 'object', - 'name': 'object', - 'description': 'Description of the variable.', - 'value': { - 'key1': { - 'id': str(uuid4()), - 'value_type': 'string', - 'name': 'text', - 'value': 'text', - 'description': 'Description of the variable.', - }, - 'key2': { - 'id': str(uuid4()), - 'value_type': 'number', - 'name': 'number', - 'value': 1, - 'description': 'Description of the variable.', - }, - }, + "key1": "text", + "key2": 1, }, ], } variable = factory.build_variable_from_mapping(mapping) assert isinstance(variable, ArrayObjectVariable) - assert isinstance(variable.value[0], ObjectSegment) - assert isinstance(variable.value[1], ObjectSegment) - assert isinstance(variable.value[0].value['key1'], StringVariable) - assert isinstance(variable.value[0].value['key2'], IntegerVariable) - assert isinstance(variable.value[1].value['key1'], StringVariable) - assert isinstance(variable.value[1].value['key2'], IntegerVariable) - - -def test_file_variable(): - mapping = { - 'id': str(uuid4()), - 'value_type': 'file', - 'name': 'test_file', - 'description': 'Description of the variable.', - 'value': { - 'id': str(uuid4()), - 'tenant_id': 'tenant_id', - 'type': 'image', - 'transfer_method': 'local_file', - 'url': 'url', - 'related_id': 'related_id', - 'extra_config': { - 'image_config': { - 'width': 100, - 'height': 100, - }, - }, - 'filename': 'filename', - 'extension': 'extension', - 'mime_type': 'mime_type', - }, - } - variable = factory.build_variable_from_mapping(mapping) - assert isinstance(variable, FileVariable) + assert isinstance(variable.value[0], dict) + assert isinstance(variable.value[1], dict) + assert isinstance(variable.value[0]["key1"], str) + assert isinstance(variable.value[0]["key2"], int) + assert isinstance(variable.value[1]["key1"], str) + assert isinstance(variable.value[1]["key2"], int) -def test_array_file_variable(): - mapping = { - 'id': str(uuid4()), - 'value_type': 'array[file]', - 'name': 'test_array_file', - 'description': 'Description of the variable.', - 'value': [ +def test_variable_cannot_large_than_5_kb(): + with pytest.raises(VariableError): + factory.build_variable_from_mapping( { - 'id': str(uuid4()), - 'name': 'file', - 'value_type': 'file', - 'value': { - 'id': str(uuid4()), - 'tenant_id': 'tenant_id', - 'type': 'image', - 'transfer_method': 'local_file', - 'url': 'url', - 'related_id': 'related_id', - 'extra_config': { - 'image_config': { - 'width': 100, - 'height': 100, - }, - }, - 'filename': 'filename', - 'extension': 'extension', - 'mime_type': 'mime_type', - }, - }, - { - 'id': str(uuid4()), - 'name': 'file', - 'value_type': 'file', - 'value': { - 'id': str(uuid4()), - 'tenant_id': 'tenant_id', - 'type': 'image', - 'transfer_method': 'local_file', - 'url': 'url', - 'related_id': 'related_id', - 'extra_config': { - 'image_config': { - 'width': 100, - 'height': 100, - }, - }, - 'filename': 'filename', - 'extension': 'extension', - 'mime_type': 'mime_type', - }, - }, - ], - } - variable = factory.build_variable_from_mapping(mapping) - assert isinstance(variable, ArrayFileVariable) - assert isinstance(variable.value[0], FileVariable) - assert isinstance(variable.value[1], FileVariable) + "id": str(uuid4()), + "value_type": "string", + "name": "test_text", + "value": "a" * 1024 * 6, + } + ) diff --git a/api/tests/unit_tests/core/app/segments/test_segment.py b/api/tests/unit_tests/core/app/segments/test_segment.py index 414404b7d0362a..7cc339d21200a1 100644 --- a/api/tests/unit_tests/core/app/segments/test_segment.py +++ b/api/tests/unit_tests/core/app/segments/test_segment.py @@ -1,26 +1,26 @@ from core.app.segments import SecretVariable, StringSegment, parser from core.helper import encrypter -from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey def test_segment_group_to_text(): variable_pool = VariablePool( system_variables={ - SystemVariable('user_id'): 'fake-user-id', + SystemVariableKey("user_id"): "fake-user-id", }, user_inputs={}, environment_variables=[ - SecretVariable(name='secret_key', value='fake-secret-key'), + SecretVariable(name="secret_key", value="fake-secret-key"), ], ) - variable_pool.add(('node_id', 'custom_query'), 'fake-user-query') + variable_pool.add(("node_id", "custom_query"), "fake-user-query") template = ( - 'Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}.' + "Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}." ) segments_group = parser.convert_template(template=template, variable_pool=variable_pool) - assert segments_group.text == 'Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key.' + assert segments_group.text == "Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key." assert ( segments_group.log == f"Hello, fake-user-id! Your query is fake-user-query. And your key is {encrypter.obfuscated_token('fake-secret-key')}." @@ -33,22 +33,22 @@ def test_convert_constant_to_segment_group(): user_inputs={}, environment_variables=[], ) - template = 'Hello, world!' + template = "Hello, world!" segments_group = parser.convert_template(template=template, variable_pool=variable_pool) - assert segments_group.text == 'Hello, world!' - assert segments_group.log == 'Hello, world!' + assert segments_group.text == "Hello, world!" + assert segments_group.log == "Hello, world!" def test_convert_variable_to_segment_group(): variable_pool = VariablePool( system_variables={ - SystemVariable('user_id'): 'fake-user-id', + SystemVariableKey("user_id"): "fake-user-id", }, user_inputs={}, environment_variables=[], ) - template = '{{#sys.user_id#}}' + template = "{{#sys.user_id#}}" segments_group = parser.convert_template(template=template, variable_pool=variable_pool) - assert segments_group.text == 'fake-user-id' - assert segments_group.log == 'fake-user-id' - assert segments_group.value == [StringSegment(value='fake-user-id')] + assert segments_group.text == "fake-user-id" + assert segments_group.log == "fake-user-id" + assert segments_group.value == [StringSegment(value="fake-user-id")] diff --git a/api/tests/unit_tests/core/app/segments/test_variables.py b/api/tests/unit_tests/core/app/segments/test_variables.py index e3f513971a4810..b3f0ae626cf24b 100644 --- a/api/tests/unit_tests/core/app/segments/test_variables.py +++ b/api/tests/unit_tests/core/app/segments/test_variables.py @@ -13,70 +13,60 @@ def test_frozen_variables(): - var = StringVariable(name='text', value='text') + var = StringVariable(name="text", value="text") with pytest.raises(ValidationError): - var.value = 'new value' + var.value = "new value" - int_var = IntegerVariable(name='integer', value=42) + int_var = IntegerVariable(name="integer", value=42) with pytest.raises(ValidationError): int_var.value = 100 - float_var = FloatVariable(name='float', value=3.14) + float_var = FloatVariable(name="float", value=3.14) with pytest.raises(ValidationError): float_var.value = 2.718 - secret_var = SecretVariable(name='secret', value='secret_value') + secret_var = SecretVariable(name="secret", value="secret_value") with pytest.raises(ValidationError): - secret_var.value = 'new_secret_value' + secret_var.value = "new_secret_value" def test_variable_value_type_immutable(): with pytest.raises(ValidationError): - StringVariable(value_type=SegmentType.ARRAY_ANY, name='text', value='text') + StringVariable(value_type=SegmentType.ARRAY_ANY, name="text", value="text") with pytest.raises(ValidationError): - StringVariable.model_validate({'value_type': 'not text', 'name': 'text', 'value': 'text'}) + StringVariable.model_validate({"value_type": "not text", "name": "text", "value": "text"}) - var = IntegerVariable(name='integer', value=42) + var = IntegerVariable(name="integer", value=42) with pytest.raises(ValidationError): IntegerVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value) - var = FloatVariable(name='float', value=3.14) + var = FloatVariable(name="float", value=3.14) with pytest.raises(ValidationError): FloatVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value) - var = SecretVariable(name='secret', value='secret_value') + var = SecretVariable(name="secret", value="secret_value") with pytest.raises(ValidationError): SecretVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value) def test_object_variable_to_object(): var = ObjectVariable( - name='object', + name="object", value={ - 'key1': ObjectVariable( - name='object', - value={ - 'key2': StringVariable(name='key2', value='value2'), - }, - ), - 'key2': ArrayAnyVariable( - name='array', - value=[ - StringVariable(name='key5_1', value='value5_1'), - IntegerVariable(name='key5_2', value=42), - ObjectVariable(name='key5_3', value={}), - ], - ), + "key1": { + "key2": "value2", + }, + "key2": ["value5_1", 42, {}], }, ) assert var.to_object() == { - 'key1': { - 'key2': 'value2', + "key1": { + "key2": "value2", }, - 'key2': [ - 'value5_1', + "key2": [ + "value5_1", 42, {}, ], @@ -84,11 +74,11 @@ def test_object_variable_to_object(): def test_variable_to_object(): - var = StringVariable(name='text', value='text') - assert var.to_object() == 'text' - var = IntegerVariable(name='integer', value=42) + var = StringVariable(name="text", value="text") + assert var.to_object() == "text" + var = IntegerVariable(name="integer", value=42) assert var.to_object() == 42 - var = FloatVariable(name='float', value=3.14) + var = FloatVariable(name="float", value=3.14) assert var.to_object() == 3.14 - var = SecretVariable(name='secret', value='secret_value') - assert var.to_object() == 'secret_value' + var = SecretVariable(name="secret", value="secret_value") + assert var.to_object() == "secret_value" diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py index d917bb10036faa..7a0bc70c63eeab 100644 --- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -4,17 +4,17 @@ from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request -@patch('httpx.request') +@patch("httpx.request") def test_successful_request(mock_request): mock_response = MagicMock() mock_response.status_code = 200 mock_request.return_value = mock_response - response = make_request('GET', 'http://example.com') + response = make_request("GET", "http://example.com") assert response.status_code == 200 -@patch('httpx.request') +@patch("httpx.request") def test_retry_exceed_max_retries(mock_request): mock_response = MagicMock() mock_response.status_code = 500 @@ -23,13 +23,13 @@ def test_retry_exceed_max_retries(mock_request): mock_request.side_effect = side_effects try: - make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES - 1) + make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES - 1) raise AssertionError("Expected Exception not raised") except Exception as e: assert str(e) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com" -@patch('httpx.request') +@patch("httpx.request") def test_retry_logic_success(mock_request): side_effects = [] @@ -45,8 +45,8 @@ def test_retry_logic_success(mock_request): mock_request.side_effect = side_effects - response = make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES) + response = make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES) assert response.status_code == 200 assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1 - assert mock_request.call_args_list[0][1].get('method') == 'GET' + assert mock_request.call_args_list[0][1].get("method") == "GET" diff --git a/api/tests/unit_tests/core/model_runtime/__init__.py b/api/tests/unit_tests/core/model_runtime/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/unit_tests/core/model_runtime/model_providers/__init__.py b/api/tests/unit_tests/core/model_runtime/model_providers/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/__init__.py b/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py b/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py new file mode 100644 index 00000000000000..5b159b49b61f37 --- /dev/null +++ b/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py @@ -0,0 +1,75 @@ +import numpy as np + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer +from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import ( + TextEmbedding, + WenxinTextEmbeddingModel, +) + + +def test_max_chunks(): + class _MockTextEmbedding(TextEmbedding): + def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): + embeddings = [[1.0, 2.0, 3.0] for i in range(len(texts))] + tokens = 0 + for text in texts: + tokens += len(text) + + return embeddings, tokens, tokens + + def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding: + return _MockTextEmbedding() + + model = "embedding-v1" + credentials = { + "api_key": "xxxx", + "secret_key": "yyyy", + } + embedding_model = WenxinTextEmbeddingModel() + context_size = embedding_model._get_context_size(model, credentials) + max_chunks = embedding_model._get_max_chunks(model, credentials) + embedding_model._create_text_embedding = _create_text_embedding + + texts = ["0123456789" for i in range(0, max_chunks * 2)] + result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, "test") + assert len(result.embeddings) == max_chunks * 2 + + +def test_context_size(): + def get_num_tokens_by_gpt2(text: str) -> int: + return GPT2Tokenizer.get_num_tokens(text) + + def mock_text(token_size: int) -> str: + _text = "".join(["0" for i in range(token_size)]) + num_tokens = get_num_tokens_by_gpt2(_text) + ratio = int(np.floor(len(_text) / num_tokens)) + m_text = "".join([_text for i in range(ratio)]) + return m_text + + model = "embedding-v1" + credentials = { + "api_key": "xxxx", + "secret_key": "yyyy", + } + embedding_model = WenxinTextEmbeddingModel() + context_size = embedding_model._get_context_size(model, credentials) + + class _MockTextEmbedding(TextEmbedding): + def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): + embeddings = [[1.0, 2.0, 3.0] for i in range(len(texts))] + tokens = 0 + for text in texts: + tokens += get_num_tokens_by_gpt2(text) + return embeddings, tokens, tokens + + def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding: + return _MockTextEmbedding() + + embedding_model._create_text_embedding = _create_text_embedding + text = mock_text(context_size * 2) + assert get_num_tokens_by_gpt2(text) == context_size * 2 + + texts = [text] + result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, "test") + assert result.usage.tokens == context_size diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index fd284488b548fe..24bbde6d4ebb66 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -2,8 +2,8 @@ import pytest -from core.app.app_config.entities import FileExtraConfig, ModelConfigEntity -from core.file.file_obj import FileTransferMethod, FileType, FileVar +from core.app.app_config.entities import ModelConfigEntity +from core.file.file_obj import FileExtraConfig, FileTransferMethod, FileType, FileVar from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessageRole, UserPromptMessage from core.prompt.advanced_prompt_transform import AdvancedPromptTransform @@ -14,39 +14,24 @@ def test__get_completion_model_prompt_messages(): model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = 'openai' - model_config_mock.model = 'gpt-3.5-turbo-instruct' + model_config_mock.provider = "openai" + model_config_mock.model = "gpt-3.5-turbo-instruct" prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}." - prompt_template_config = CompletionModelPromptTemplate( - text=prompt_template - ) + prompt_template_config = CompletionModelPromptTemplate(text=prompt_template) memory_config = MemoryConfig( - role_prefix=MemoryConfig.RolePrefix( - user="Human", - assistant="Assistant" - ), - window=MemoryConfig.WindowConfig( - enabled=False - ) + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=False), ) - inputs = { - "name": "John" - } + inputs = {"name": "John"} files = [] context = "I am superman." - memory = TokenBufferMemory( - conversation=Conversation(), - model_instance=model_config_mock - ) + memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock) - history_prompt_messages = [ - UserPromptMessage(content="Hi"), - AssistantPromptMessage(content="Hello") - ] + history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")] memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages) prompt_transform = AdvancedPromptTransform() @@ -59,16 +44,22 @@ def test__get_completion_model_prompt_messages(): context=context, memory_config=memory_config, memory=memory, - model_config=model_config_mock + model_config=model_config_mock, ) assert len(prompt_messages) == 1 - assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format({ - "#context#": context, - "#histories#": "\n".join([f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: " - f"{prompt.content}" for prompt in history_prompt_messages]), - **inputs, - }) + assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format( + { + "#context#": context, + "#histories#": "\n".join( + [ + f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: " f"{prompt.content}" + for prompt in history_prompt_messages + ] + ), + **inputs, + } + ) def test__get_chat_model_prompt_messages(get_chat_model_args): @@ -77,15 +68,9 @@ def test__get_chat_model_prompt_messages(get_chat_model_args): files = [] query = "Hi2." - memory = TokenBufferMemory( - conversation=Conversation(), - model_instance=model_config_mock - ) + memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock) - history_prompt_messages = [ - UserPromptMessage(content="Hi1."), - AssistantPromptMessage(content="Hello1!") - ] + history_prompt_messages = [UserPromptMessage(content="Hi1."), AssistantPromptMessage(content="Hello1!")] memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages) prompt_transform = AdvancedPromptTransform() @@ -98,14 +83,14 @@ def test__get_chat_model_prompt_messages(get_chat_model_args): context=context, memory_config=memory_config, memory=memory, - model_config=model_config_mock + model_config=model_config_mock, ) assert len(prompt_messages) == 6 assert prompt_messages[0].role == PromptMessageRole.SYSTEM - assert prompt_messages[0].content == PromptTemplateParser( - template=messages[0].text - ).format({**inputs, "#context#": context}) + assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format( + {**inputs, "#context#": context} + ) assert prompt_messages[5].content == query @@ -124,14 +109,14 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args): context=context, memory_config=None, memory=None, - model_config=model_config_mock + model_config=model_config_mock, ) assert len(prompt_messages) == 3 assert prompt_messages[0].role == PromptMessageRole.SYSTEM - assert prompt_messages[0].content == PromptTemplateParser( - template=messages[0].text - ).format({**inputs, "#context#": context}) + assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format( + {**inputs, "#context#": context} + ) def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args): @@ -148,7 +133,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg image_config={ "detail": "high", } - ) + ), ) ] @@ -162,14 +147,14 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg context=context, memory_config=None, memory=None, - model_config=model_config_mock + model_config=model_config_mock, ) assert len(prompt_messages) == 4 assert prompt_messages[0].role == PromptMessageRole.SYSTEM - assert prompt_messages[0].content == PromptTemplateParser( - template=messages[0].text - ).format({**inputs, "#context#": context}) + assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format( + {**inputs, "#context#": context} + ) assert isinstance(prompt_messages[3].content, list) assert len(prompt_messages[3].content) == 2 assert prompt_messages[3].content[1].data == files[0].url @@ -178,33 +163,20 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg @pytest.fixture def get_chat_model_args(): model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = 'openai' - model_config_mock.model = 'gpt-4' + model_config_mock.provider = "openai" + model_config_mock.model = "gpt-4" - memory_config = MemoryConfig( - window=MemoryConfig.WindowConfig( - enabled=False - ) - ) + memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) prompt_messages = [ ChatModelMessage( - text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", - role=PromptMessageRole.SYSTEM + text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", role=PromptMessageRole.SYSTEM ), - ChatModelMessage( - text="Hi.", - role=PromptMessageRole.USER - ), - ChatModelMessage( - text="Hello!", - role=PromptMessageRole.ASSISTANT - ) + ChatModelMessage(text="Hi.", role=PromptMessageRole.USER), + ChatModelMessage(text="Hello!", role=PromptMessageRole.ASSISTANT), ] - inputs = { - "name": "John" - } + inputs = {"name": "John"} context = "I am superman." diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py index 9de268d7624744..0fd176e65d02f8 100644 --- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py @@ -18,27 +18,28 @@ def test_get_prompt(): prompt_messages = [ - SystemPromptMessage(content='System Template'), - UserPromptMessage(content='User Query'), + SystemPromptMessage(content="System Template"), + UserPromptMessage(content="User Query"), ] history_messages = [ - SystemPromptMessage(content='System Prompt 1'), - UserPromptMessage(content='User Prompt 1'), - AssistantPromptMessage(content='Assistant Thought 1'), - ToolPromptMessage(content='Tool 1-1', name='Tool 1-1', tool_call_id='1'), - ToolPromptMessage(content='Tool 1-2', name='Tool 1-2', tool_call_id='2'), - SystemPromptMessage(content='System Prompt 2'), - UserPromptMessage(content='User Prompt 2'), - AssistantPromptMessage(content='Assistant Thought 2'), - ToolPromptMessage(content='Tool 2-1', name='Tool 2-1', tool_call_id='3'), - ToolPromptMessage(content='Tool 2-2', name='Tool 2-2', tool_call_id='4'), - UserPromptMessage(content='User Prompt 3'), - AssistantPromptMessage(content='Assistant Thought 3'), + SystemPromptMessage(content="System Prompt 1"), + UserPromptMessage(content="User Prompt 1"), + AssistantPromptMessage(content="Assistant Thought 1"), + ToolPromptMessage(content="Tool 1-1", name="Tool 1-1", tool_call_id="1"), + ToolPromptMessage(content="Tool 1-2", name="Tool 1-2", tool_call_id="2"), + SystemPromptMessage(content="System Prompt 2"), + UserPromptMessage(content="User Prompt 2"), + AssistantPromptMessage(content="Assistant Thought 2"), + ToolPromptMessage(content="Tool 2-1", name="Tool 2-1", tool_call_id="3"), + ToolPromptMessage(content="Tool 2-2", name="Tool 2-2", tool_call_id="4"), + UserPromptMessage(content="User Prompt 3"), + AssistantPromptMessage(content="Assistant Thought 3"), ] # use message number instead of token for testing def side_effect_get_num_tokens(*args): return len(args[2]) + large_language_model_mock = MagicMock(spec=LargeLanguageModel) large_language_model_mock.get_num_tokens = MagicMock(side_effect=side_effect_get_num_tokens) @@ -46,20 +47,17 @@ def side_effect_get_num_tokens(*args): provider_model_bundle_mock.model_type_instance = large_language_model_mock model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity) - model_config_mock.model = 'openai' + model_config_mock.model = "openai" model_config_mock.credentials = {} model_config_mock.provider_model_bundle = provider_model_bundle_mock - memory = TokenBufferMemory( - conversation=Conversation(), - model_instance=model_config_mock - ) + memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock) transform = AgentHistoryPromptTransform( model_config=model_config_mock, prompt_messages=prompt_messages, history_messages=history_messages, - memory=memory + memory=memory, ) max_token_limit = 5 diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index 2bcc6f42927edc..89c14463bbfb94 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -12,19 +12,15 @@ def test__calculate_rest_token(): model_schema_mock = MagicMock(spec=AIModelEntity) parameter_rule_mock = MagicMock(spec=ParameterRule) - parameter_rule_mock.name = 'max_tokens' - model_schema_mock.parameter_rules = [ - parameter_rule_mock - ] - model_schema_mock.model_properties = { - ModelPropertyKey.CONTEXT_SIZE: 62 - } + parameter_rule_mock.name = "max_tokens" + model_schema_mock.parameter_rules = [parameter_rule_mock] + model_schema_mock.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 62} large_language_model_mock = MagicMock(spec=LargeLanguageModel) large_language_model_mock.get_num_tokens.return_value = 6 provider_mock = MagicMock(spec=ProviderEntity) - provider_mock.provider = 'openai' + provider_mock.provider = "openai" provider_configuration_mock = MagicMock(spec=ProviderConfiguration) provider_configuration_mock.provider = provider_mock @@ -35,11 +31,9 @@ def test__calculate_rest_token(): provider_model_bundle_mock.configuration = provider_configuration_mock model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.model = 'gpt-4' + model_config_mock.model = "gpt-4" model_config_mock.credentials = {} - model_config_mock.parameters = { - 'max_tokens': 50 - } + model_config_mock.parameters = {"max_tokens": 50} model_config_mock.model_schema = model_schema_mock model_config_mock.provider_model_bundle = provider_model_bundle_mock @@ -49,8 +43,10 @@ def test__calculate_rest_token(): rest_tokens = prompt_transform._calculate_rest_token(prompt_messages, model_config_mock) # Validate based on the mock configuration and expected logic - expected_rest_tokens = (model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE] - - model_config_mock.parameters['max_tokens'] - - large_language_model_mock.get_num_tokens.return_value) + expected_rest_tokens = ( + model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE] + - model_config_mock.parameters["max_tokens"] + - large_language_model_mock.get_num_tokens.return_value + ) assert rest_tokens == expected_rest_tokens assert rest_tokens == 6 diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index 6d6363610bb16a..c32fc2bc34813d 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -19,12 +19,15 @@ def test_get_common_chat_app_prompt_template_with_pcqm(): query_in_prompt=True, with_memory_prompt=True, ) - prompt_rules = prompt_template['prompt_rules'] - assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] - + pre_prompt + '\n' - + prompt_rules['histories_prompt'] - + prompt_rules['query_prompt']) - assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#'] + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + + pre_prompt + + "\n" + + prompt_rules["histories_prompt"] + + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#histories#", "#query#"] def test_get_baichuan_chat_app_prompt_template_with_pcqm(): @@ -39,12 +42,15 @@ def test_get_baichuan_chat_app_prompt_template_with_pcqm(): query_in_prompt=True, with_memory_prompt=True, ) - prompt_rules = prompt_template['prompt_rules'] - assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] - + pre_prompt + '\n' - + prompt_rules['histories_prompt'] - + prompt_rules['query_prompt']) - assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#'] + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + + pre_prompt + + "\n" + + prompt_rules["histories_prompt"] + + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#histories#", "#query#"] def test_get_common_completion_app_prompt_template_with_pcq(): @@ -59,11 +65,11 @@ def test_get_common_completion_app_prompt_template_with_pcq(): query_in_prompt=True, with_memory_prompt=False, ) - prompt_rules = prompt_template['prompt_rules'] - assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] - + pre_prompt + '\n' - + prompt_rules['query_prompt']) - assert prompt_template['special_variable_keys'] == ['#context#', '#query#'] + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + pre_prompt + "\n" + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#query#"] def test_get_baichuan_completion_app_prompt_template_with_pcq(): @@ -78,12 +84,12 @@ def test_get_baichuan_completion_app_prompt_template_with_pcq(): query_in_prompt=True, with_memory_prompt=False, ) - print(prompt_template['prompt_template'].template) - prompt_rules = prompt_template['prompt_rules'] - assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] - + pre_prompt + '\n' - + prompt_rules['query_prompt']) - assert prompt_template['special_variable_keys'] == ['#context#', '#query#'] + print(prompt_template["prompt_template"].template) + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + pre_prompt + "\n" + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#query#"] def test_get_common_chat_app_prompt_template_with_q(): @@ -98,9 +104,9 @@ def test_get_common_chat_app_prompt_template_with_q(): query_in_prompt=True, with_memory_prompt=False, ) - prompt_rules = prompt_template['prompt_rules'] - assert prompt_template['prompt_template'].template == prompt_rules['query_prompt'] - assert prompt_template['special_variable_keys'] == ['#query#'] + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == prompt_rules["query_prompt"] + assert prompt_template["special_variable_keys"] == ["#query#"] def test_get_common_chat_app_prompt_template_with_cq(): @@ -115,10 +121,11 @@ def test_get_common_chat_app_prompt_template_with_cq(): query_in_prompt=True, with_memory_prompt=False, ) - prompt_rules = prompt_template['prompt_rules'] - assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] - + prompt_rules['query_prompt']) - assert prompt_template['special_variable_keys'] == ['#context#', '#query#'] + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#query#"] def test_get_common_chat_app_prompt_template_with_p(): @@ -133,30 +140,25 @@ def test_get_common_chat_app_prompt_template_with_p(): query_in_prompt=False, with_memory_prompt=False, ) - assert prompt_template['prompt_template'].template == pre_prompt + '\n' - assert prompt_template['custom_variable_keys'] == ['name'] - assert prompt_template['special_variable_keys'] == [] + assert prompt_template["prompt_template"].template == pre_prompt + "\n" + assert prompt_template["custom_variable_keys"] == ["name"] + assert prompt_template["special_variable_keys"] == [] def test__get_chat_model_prompt_messages(): model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity) - model_config_mock.provider = 'openai' - model_config_mock.model = 'gpt-4' + model_config_mock.provider = "openai" + model_config_mock.model = "gpt-4" memory_mock = MagicMock(spec=TokenBufferMemory) - history_prompt_messages = [ - UserPromptMessage(content="Hi"), - AssistantPromptMessage(content="Hello") - ] + history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")] memory_mock.get_history_prompt_messages.return_value = history_prompt_messages prompt_transform = SimplePromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) pre_prompt = "You are a helpful assistant {{name}}." - inputs = { - "name": "John" - } + inputs = {"name": "John"} context = "yes or no." query = "How are you?" prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages( @@ -167,7 +169,7 @@ def test__get_chat_model_prompt_messages(): files=[], context=context, memory=memory_mock, - model_config=model_config_mock + model_config=model_config_mock, ) prompt_template = prompt_transform.get_prompt_template( @@ -180,8 +182,8 @@ def test__get_chat_model_prompt_messages(): with_memory_prompt=False, ) - full_inputs = {**inputs, '#context#': context} - real_system_prompt = prompt_template['prompt_template'].format(full_inputs) + full_inputs = {**inputs, "#context#": context} + real_system_prompt = prompt_template["prompt_template"].format(full_inputs) assert len(prompt_messages) == 4 assert prompt_messages[0].content == real_system_prompt @@ -192,26 +194,18 @@ def test__get_chat_model_prompt_messages(): def test__get_completion_model_prompt_messages(): model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity) - model_config_mock.provider = 'openai' - model_config_mock.model = 'gpt-3.5-turbo-instruct' + model_config_mock.provider = "openai" + model_config_mock.model = "gpt-3.5-turbo-instruct" - memory = TokenBufferMemory( - conversation=Conversation(), - model_instance=model_config_mock - ) + memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock) - history_prompt_messages = [ - UserPromptMessage(content="Hi"), - AssistantPromptMessage(content="Hello") - ] + history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")] memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages) prompt_transform = SimplePromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) pre_prompt = "You are a helpful assistant {{name}}." - inputs = { - "name": "John" - } + inputs = {"name": "John"} context = "yes or no." query = "How are you?" prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages( @@ -222,7 +216,7 @@ def test__get_completion_model_prompt_messages(): files=[], context=context, memory=memory, - model_config=model_config_mock + model_config=model_config_mock, ) prompt_template = prompt_transform.get_prompt_template( @@ -235,14 +229,19 @@ def test__get_completion_model_prompt_messages(): with_memory_prompt=True, ) - prompt_rules = prompt_template['prompt_rules'] - full_inputs = {**inputs, '#context#': context, '#query#': query, '#histories#': memory.get_history_prompt_text( - max_token_limit=2000, - human_prefix=prompt_rules.get("human_prefix", "Human"), - ai_prefix=prompt_rules.get("assistant_prefix", "Assistant") - )} - real_prompt = prompt_template['prompt_template'].format(full_inputs) + prompt_rules = prompt_template["prompt_rules"] + full_inputs = { + **inputs, + "#context#": context, + "#query#": query, + "#histories#": memory.get_history_prompt_text( + max_token_limit=2000, + human_prefix=prompt_rules.get("human_prefix", "Human"), + ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"), + ), + } + real_prompt = prompt_template["prompt_template"].format(full_inputs) assert len(prompt_messages) == 1 - assert stops == prompt_rules.get('stops') + assert stops == prompt_rules.get("stops") assert prompt_messages[0].content == real_prompt diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py index 9e43b23658f41a..8d735cae86d947 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py @@ -5,20 +5,15 @@ def test_default_value(): - valid_config = { - 'host': 'localhost', - 'port': 19530, - 'user': 'root', - 'password': 'Milvus' - } + valid_config = {"host": "localhost", "port": 19530, "user": "root", "password": "Milvus"} for key in valid_config: config = valid_config.copy() del config[key] with pytest.raises(ValidationError) as e: MilvusConfig(**config) - assert e.value.errors()[0]['msg'] == f'Value error, config MILVUS_{key.upper()} is required' + assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required" config = MilvusConfig(**valid_config) assert config.secure is False - assert config.database == 'default' + assert config.database == "default" diff --git a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py index a8bba11e16db1f..d5a1d8f436c75a 100644 --- a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py +++ b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py @@ -9,19 +9,17 @@ def test_firecrawl_web_extractor_crawl_mode(mocker): url = "https://firecrawl.dev" - api_key = os.getenv('FIRECRAWL_API_KEY') or 'fc-' - base_url = 'https://api.firecrawl.dev' - firecrawl_app = FirecrawlApp(api_key=api_key, - base_url=base_url) + api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-" + base_url = "https://api.firecrawl.dev" + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=base_url) params = { - 'crawlerOptions': { + "crawlerOptions": { "includes": [], "excludes": [], "generateImgAltText": True, "maxDepth": 1, "limit": 1, - 'returnOnlyUrls': False, - + "returnOnlyUrls": False, } } mocked_firecrawl = { diff --git a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py index b231fe479b829b..eea584a2f8edc6 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py @@ -8,11 +8,8 @@ extractor = notion_extractor.NotionExtractor( - notion_workspace_id='x', - notion_obj_id='x', - notion_page_type='page', - tenant_id='x', - notion_access_token='x') + notion_workspace_id="x", notion_obj_id="x", notion_page_type="page", tenant_id="x", notion_access_token="x" +) def _generate_page(page_title: str): @@ -21,16 +18,10 @@ def _generate_page(page_title: str): "id": page_id, "properties": { "Page": { - "type": "title", - "title": [ - { - "type": "text", - "text": {"content": page_title}, - "plain_text": page_title - } - ] + "type": "title", + "title": [{"type": "text", "text": {"content": page_title}, "plain_text": page_title}], } - } + }, } @@ -38,10 +29,7 @@ def _generate_block(block_id: str, block_type: str, block_text: str): return { "object": "block", "id": block_id, - "parent": { - "type": "page_id", - "page_id": page_id - }, + "parent": {"type": "page_id", "page_id": page_id}, "type": block_type, "has_children": False, block_type: { @@ -49,10 +37,11 @@ def _generate_block(block_id: str, block_type: str, block_text: str): { "type": "text", "text": {"content": block_text}, - "plain_text": block_text, - }] - } - } + "plain_text": block_text, + } + ] + }, + } def _mock_response(data): @@ -63,7 +52,7 @@ def _mock_response(data): def _remove_multiple_new_lines(text): - while '\n\n' in text: + while "\n\n" in text: text = text.replace("\n\n", "\n") return text.strip() @@ -71,21 +60,21 @@ def _remove_multiple_new_lines(text): def test_notion_page(mocker): texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"] mocked_notion_page = { - "object": "list", - "results": [ - _generate_block("b1", "heading_1", texts[0]), - _generate_block("b2", "heading_2", texts[1]), - _generate_block("b3", "paragraph", texts[2]), - _generate_block("b4", "heading_3", texts[3]) - ], - "next_cursor": None + "object": "list", + "results": [ + _generate_block("b1", "heading_1", texts[0]), + _generate_block("b2", "heading_2", texts[1]), + _generate_block("b3", "paragraph", texts[2]), + _generate_block("b4", "heading_3", texts[3]), + ], + "next_cursor": None, } mocker.patch("requests.request", return_value=_mock_response(mocked_notion_page)) page_docs = extractor._load_data_as_documents(page_id, "page") assert len(page_docs) == 1 content = _remove_multiple_new_lines(page_docs[0].page_content) - assert content == '# Head 1\n## 1.1\nparagraph 1\n### 1.1.1' + assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1" def test_notion_database(mocker): @@ -93,10 +82,10 @@ def test_notion_database(mocker): mocked_notion_database = { "object": "list", "results": [_generate_page(i) for i in page_title_list], - "next_cursor": None + "next_cursor": None, } mocker.patch("requests.post", return_value=_mock_response(mocked_notion_database)) database_docs = extractor._load_data_as_documents(database_id, "database") assert len(database_docs) == 1 content = _remove_multiple_new_lines(database_docs[0].page_content) - assert content == '\n'.join([f'Page:{i}' for i in page_title_list]) + assert content == "\n".join([f"Page:{i}" for i in page_title_list]) diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py index 3024a54a4d325c..2808b5b0fad1a7 100644 --- a/api/tests/unit_tests/core/test_model_manager.py +++ b/api/tests/unit_tests/core/test_model_manager.py @@ -10,36 +10,24 @@ @pytest.fixture def lb_model_manager(): load_balancing_configs = [ - ModelLoadBalancingConfiguration( - id='id1', - name='__inherit__', - credentials={} - ), - ModelLoadBalancingConfiguration( - id='id2', - name='first', - credentials={"openai_api_key": "fake_key"} - ), - ModelLoadBalancingConfiguration( - id='id3', - name='second', - credentials={"openai_api_key": "fake_key"} - ) + ModelLoadBalancingConfiguration(id="id1", name="__inherit__", credentials={}), + ModelLoadBalancingConfiguration(id="id2", name="first", credentials={"openai_api_key": "fake_key"}), + ModelLoadBalancingConfiguration(id="id3", name="second", credentials={"openai_api_key": "fake_key"}), ] lb_model_manager = LBModelManager( - tenant_id='tenant_id', - provider='openai', + tenant_id="tenant_id", + provider="openai", model_type=ModelType.LLM, - model='gpt-4', + model="gpt-4", load_balancing_configs=load_balancing_configs, - managed_credentials={"openai_api_key": "fake_key"} + managed_credentials={"openai_api_key": "fake_key"}, ) lb_model_manager.cooldown = MagicMock(return_value=None) def is_cooldown(config: ModelLoadBalancingConfiguration): - if config.id == 'id1': + if config.id == "id1": return True return False @@ -61,14 +49,15 @@ def test_lb_model_manager_fetch_next(mocker, lb_model_manager): assert lb_model_manager.in_cooldown(config3) is False start_index = 0 + def incr(key): nonlocal start_index start_index += 1 return start_index - mocker.patch('redis.Redis.incr', side_effect=incr) - mocker.patch('redis.Redis.set', return_value=None) - mocker.patch('redis.Redis.expire', return_value=None) + mocker.patch("redis.Redis.incr", side_effect=incr) + mocker.patch("redis.Redis.set", return_value=None) + mocker.patch("redis.Redis.expire", return_value=None) config = lb_model_manager.fetch_next() assert config == config2 diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index 072b6f100f3a16..2f4214a5801de1 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -11,62 +11,62 @@ def test__to_model_settings(mocker): provider_entity = None for provider in provider_entities: - if provider.provider == 'openai': + if provider.provider == "openai": provider_entity = provider # Mocking the inputs - provider_model_settings = [ProviderModelSetting( - id='id', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - enabled=True, - load_balancing_enabled=True - )] + provider_model_settings = [ + ProviderModelSetting( + id="id", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + enabled=True, + load_balancing_enabled=True, + ) + ] load_balancing_model_configs = [ LoadBalancingModelConfig( - id='id1', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - name='__inherit__', + id="id1", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="__inherit__", encrypted_config=None, - enabled=True + enabled=True, ), LoadBalancingModelConfig( - id='id2', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - name='first', + id="id2", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="first", encrypted_config='{"openai_api_key": "fake_key"}', - enabled=True - ) + enabled=True, + ), ] - mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"}) + mocker.patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} + ) provider_manager = ProviderManager() # Running the method - result = provider_manager._to_model_settings( - provider_entity, - provider_model_settings, - load_balancing_model_configs - ) + result = provider_manager._to_model_settings(provider_entity, provider_model_settings, load_balancing_model_configs) # Asserting that the result is as expected assert len(result) == 1 assert isinstance(result[0], ModelSettings) - assert result[0].model == 'gpt-4' + assert result[0].model == "gpt-4" assert result[0].model_type == ModelType.LLM assert result[0].enabled is True assert len(result[0].load_balancing_configs) == 2 - assert result[0].load_balancing_configs[0].name == '__inherit__' - assert result[0].load_balancing_configs[1].name == 'first' + assert result[0].load_balancing_configs[0].name == "__inherit__" + assert result[0].load_balancing_configs[1].name == "first" def test__to_model_settings_only_one_lb(mocker): @@ -75,47 +75,47 @@ def test__to_model_settings_only_one_lb(mocker): provider_entity = None for provider in provider_entities: - if provider.provider == 'openai': + if provider.provider == "openai": provider_entity = provider # Mocking the inputs - provider_model_settings = [ProviderModelSetting( - id='id', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - enabled=True, - load_balancing_enabled=True - )] + provider_model_settings = [ + ProviderModelSetting( + id="id", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + enabled=True, + load_balancing_enabled=True, + ) + ] load_balancing_model_configs = [ LoadBalancingModelConfig( - id='id1', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - name='__inherit__', + id="id1", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="__inherit__", encrypted_config=None, - enabled=True + enabled=True, ) ] - mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"}) + mocker.patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} + ) provider_manager = ProviderManager() # Running the method - result = provider_manager._to_model_settings( - provider_entity, - provider_model_settings, - load_balancing_model_configs - ) + result = provider_manager._to_model_settings(provider_entity, provider_model_settings, load_balancing_model_configs) # Asserting that the result is as expected assert len(result) == 1 assert isinstance(result[0], ModelSettings) - assert result[0].model == 'gpt-4' + assert result[0].model == "gpt-4" assert result[0].model_type == ModelType.LLM assert result[0].enabled is True assert len(result[0].load_balancing_configs) == 0 @@ -127,57 +127,57 @@ def test__to_model_settings_lb_disabled(mocker): provider_entity = None for provider in provider_entities: - if provider.provider == 'openai': + if provider.provider == "openai": provider_entity = provider # Mocking the inputs - provider_model_settings = [ProviderModelSetting( - id='id', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - enabled=True, - load_balancing_enabled=False - )] + provider_model_settings = [ + ProviderModelSetting( + id="id", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + enabled=True, + load_balancing_enabled=False, + ) + ] load_balancing_model_configs = [ LoadBalancingModelConfig( - id='id1', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - name='__inherit__', + id="id1", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="__inherit__", encrypted_config=None, - enabled=True + enabled=True, ), LoadBalancingModelConfig( - id='id2', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - name='first', + id="id2", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="first", encrypted_config='{"openai_api_key": "fake_key"}', - enabled=True - ) + enabled=True, + ), ] - mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"}) + mocker.patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} + ) provider_manager = ProviderManager() # Running the method - result = provider_manager._to_model_settings( - provider_entity, - provider_model_settings, - load_balancing_model_configs - ) + result = provider_manager._to_model_settings(provider_entity, provider_model_settings, load_balancing_model_configs) # Asserting that the result is as expected assert len(result) == 1 assert isinstance(result[0], ModelSettings) - assert result[0].model == 'gpt-4' + assert result[0].model == "gpt-4" assert result[0].model_type == ModelType.LLM assert result[0].enabled is True assert len(result[0].load_balancing_configs) == 0 diff --git a/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py b/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py index 9addeeadca7bf9..279a6cdbc328e6 100644 --- a/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py +++ b/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py @@ -5,52 +5,52 @@ def test_get_parameter_type(): - assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == 'string' - assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == 'string' - assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == 'boolean' - assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == 'number' + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == "string" + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == "string" + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == "boolean" + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == "number" with pytest.raises(ValueError): - ToolParameterConverter.get_parameter_type('unsupported_type') + ToolParameterConverter.get_parameter_type("unsupported_type") def test_cast_parameter_by_type(): # string - assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.STRING) == 'test' - assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == '1' - assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == '1.0' - assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == '' + assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.STRING) == "test" + assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == "1" + assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == "1.0" + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == "" # secret input - assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.SECRET_INPUT) == 'test' - assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == '1' - assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == '1.0' - assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == '' + assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SECRET_INPUT) == "test" + assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == "1" + assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == "1.0" + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == "" # select - assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.SELECT) == 'test' - assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == '1' - assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == '1.0' - assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == '' + assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SELECT) == "test" + assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == "1" + assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == "1.0" + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == "" # boolean - true_values = [True, 'True', 'true', '1', 'YES', 'Yes', 'yes', 'y', 'something'] + true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"] for value in true_values: assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is True - false_values = [False, 'False', 'false', '0', 'NO', 'No', 'no', 'n', None, ''] + false_values = [False, "False", "false", "0", "NO", "No", "no", "n", None, ""] for value in false_values: assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is False # number - assert ToolParameterConverter.cast_parameter_by_type('1', ToolParameter.ToolParameterType.NUMBER) == 1 - assert ToolParameterConverter.cast_parameter_by_type('1.0', ToolParameter.ToolParameterType.NUMBER) == 1.0 - assert ToolParameterConverter.cast_parameter_by_type('-1.0', ToolParameter.ToolParameterType.NUMBER) == -1.0 + assert ToolParameterConverter.cast_parameter_by_type("1", ToolParameter.ToolParameterType.NUMBER) == 1 + assert ToolParameterConverter.cast_parameter_by_type("1.0", ToolParameter.ToolParameterType.NUMBER) == 1.0 + assert ToolParameterConverter.cast_parameter_by_type("-1.0", ToolParameter.ToolParameterType.NUMBER) == -1.0 assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.NUMBER) == 1 assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.NUMBER) == 1.0 assert ToolParameterConverter.cast_parameter_by_type(-1.0, ToolParameter.ToolParameterType.NUMBER) == -1.0 assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None # unknown - assert ToolParameterConverter.cast_parameter_by_type('1', 'unknown_type') == '1' - assert ToolParameterConverter.cast_parameter_by_type(1, 'unknown_type') == '1' + assert ToolParameterConverter.cast_parameter_by_type("1", "unknown_type") == "1" + assert ToolParameterConverter.cast_parameter_by_type(1, "unknown_type") == "1" assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py index 3a32829e373c28..8020674ee6e3d9 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -1,8 +1,8 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.base_node import UserFrom from extensions.ext_database import db @@ -11,29 +11,30 @@ def test_execute_answer(): node = AnswerNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, config={ - 'id': 'answer', - 'data': { - 'title': '123', - 'type': 'answer', - 'answer': 'Today\'s weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.' - } - } + "id": "answer", + "data": { + "title": "123", + "type": "answer", + "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", + }, + }, ) # construct variable pool - pool = VariablePool(system_variables={ - SystemVariable.FILES: [], - SystemVariable.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) - pool.add(['start', 'weather'], 'sunny') - pool.add(['llm', 'text'], 'You are a helpful AI.') + pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + ) + pool.add(["start", "weather"], "sunny") + pool.add(["llm", "text"], "You are a helpful AI.") # Mock db.session.close() db.session.close = MagicMock() @@ -42,4 +43,4 @@ def test_execute_answer(): result = node._run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs['answer'] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin." + assert result.outputs["answer"] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin." diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 4662c5ff2b26d8..9535bc2186af72 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -1,8 +1,8 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.if_else.if_else_node import IfElseNode from extensions.ext_database import db @@ -11,134 +11,81 @@ def test_execute_if_else_result_true(): node = IfElseNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, config={ - 'id': 'if-else', - 'data': { - 'title': '123', - 'type': 'if-else', - 'logical_operator': 'and', - 'conditions': [ - { - 'comparison_operator': 'contains', - 'variable_selector': ['start', 'array_contains'], - 'value': 'ab' - }, - { - 'comparison_operator': 'not contains', - 'variable_selector': ['start', 'array_not_contains'], - 'value': 'ab' - }, - { - 'comparison_operator': 'contains', - 'variable_selector': ['start', 'contains'], - 'value': 'ab' - }, - { - 'comparison_operator': 'not contains', - 'variable_selector': ['start', 'not_contains'], - 'value': 'ab' - }, - { - 'comparison_operator': 'start with', - 'variable_selector': ['start', 'start_with'], - 'value': 'ab' - }, - { - 'comparison_operator': 'end with', - 'variable_selector': ['start', 'end_with'], - 'value': 'ab' - }, - { - 'comparison_operator': 'is', - 'variable_selector': ['start', 'is'], - 'value': 'ab' - }, - { - 'comparison_operator': 'is not', - 'variable_selector': ['start', 'is_not'], - 'value': 'ab' - }, - { - 'comparison_operator': 'empty', - 'variable_selector': ['start', 'empty'], - 'value': 'ab' - }, - { - 'comparison_operator': 'not empty', - 'variable_selector': ['start', 'not_empty'], - 'value': 'ab' - }, - { - 'comparison_operator': '=', - 'variable_selector': ['start', 'equals'], - 'value': '22' - }, - { - 'comparison_operator': '≠', - 'variable_selector': ['start', 'not_equals'], - 'value': '22' - }, - { - 'comparison_operator': '>', - 'variable_selector': ['start', 'greater_than'], - 'value': '22' - }, - { - 'comparison_operator': '<', - 'variable_selector': ['start', 'less_than'], - 'value': '22' - }, - { - 'comparison_operator': '≥', - 'variable_selector': ['start', 'greater_than_or_equal'], - 'value': '22' - }, - { - 'comparison_operator': '≤', - 'variable_selector': ['start', 'less_than_or_equal'], - 'value': '22' - }, - { - 'comparison_operator': 'null', - 'variable_selector': ['start', 'null'] - }, - { - 'comparison_operator': 'not null', - 'variable_selector': ['start', 'not_null'] - }, - ] - } - } + "id": "if-else", + "data": { + "title": "123", + "type": "if-else", + "logical_operator": "and", + "conditions": [ + { + "comparison_operator": "contains", + "variable_selector": ["start", "array_contains"], + "value": "ab", + }, + { + "comparison_operator": "not contains", + "variable_selector": ["start", "array_not_contains"], + "value": "ab", + }, + {"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"}, + { + "comparison_operator": "not contains", + "variable_selector": ["start", "not_contains"], + "value": "ab", + }, + {"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"}, + {"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"}, + {"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"}, + {"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"}, + {"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"}, + {"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"}, + {"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"}, + {"comparison_operator": "≠", "variable_selector": ["start", "not_equals"], "value": "22"}, + {"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"}, + {"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"}, + { + "comparison_operator": "≥", + "variable_selector": ["start", "greater_than_or_equal"], + "value": "22", + }, + {"comparison_operator": "≤", "variable_selector": ["start", "less_than_or_equal"], "value": "22"}, + {"comparison_operator": "null", "variable_selector": ["start", "null"]}, + {"comparison_operator": "not null", "variable_selector": ["start", "not_null"]}, + ], + }, + }, ) # construct variable pool - pool = VariablePool(system_variables={ - SystemVariable.FILES: [], - SystemVariable.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) - pool.add(['start', 'array_contains'], ['ab', 'def']) - pool.add(['start', 'array_not_contains'], ['ac', 'def']) - pool.add(['start', 'contains'], 'cabcde') - pool.add(['start', 'not_contains'], 'zacde') - pool.add(['start', 'start_with'], 'abc') - pool.add(['start', 'end_with'], 'zzab') - pool.add(['start', 'is'], 'ab') - pool.add(['start', 'is_not'], 'aab') - pool.add(['start', 'empty'], '') - pool.add(['start', 'not_empty'], 'aaa') - pool.add(['start', 'equals'], 22) - pool.add(['start', 'not_equals'], 23) - pool.add(['start', 'greater_than'], 23) - pool.add(['start', 'less_than'], 21) - pool.add(['start', 'greater_than_or_equal'], 22) - pool.add(['start', 'less_than_or_equal'], 21) - pool.add(['start', 'not_null'], '1212') + pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + ) + pool.add(["start", "array_contains"], ["ab", "def"]) + pool.add(["start", "array_not_contains"], ["ac", "def"]) + pool.add(["start", "contains"], "cabcde") + pool.add(["start", "not_contains"], "zacde") + pool.add(["start", "start_with"], "abc") + pool.add(["start", "end_with"], "zzab") + pool.add(["start", "is"], "ab") + pool.add(["start", "is_not"], "aab") + pool.add(["start", "empty"], "") + pool.add(["start", "not_empty"], "aaa") + pool.add(["start", "equals"], 22) + pool.add(["start", "not_equals"], 23) + pool.add(["start", "greater_than"], 23) + pool.add(["start", "less_than"], 21) + pool.add(["start", "greater_than_or_equal"], 22) + pool.add(["start", "less_than_or_equal"], 21) + pool.add(["start", "not_null"], "1212") # Mock db.session.close() db.session.close = MagicMock() @@ -147,46 +94,47 @@ def test_execute_if_else_result_true(): result = node._run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs['result'] is True + assert result.outputs["result"] is True def test_execute_if_else_result_false(): node = IfElseNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, config={ - 'id': 'if-else', - 'data': { - 'title': '123', - 'type': 'if-else', - 'logical_operator': 'or', - 'conditions': [ + "id": "if-else", + "data": { + "title": "123", + "type": "if-else", + "logical_operator": "or", + "conditions": [ { - 'comparison_operator': 'contains', - 'variable_selector': ['start', 'array_contains'], - 'value': 'ab' + "comparison_operator": "contains", + "variable_selector": ["start", "array_contains"], + "value": "ab", }, { - 'comparison_operator': 'not contains', - 'variable_selector': ['start', 'array_not_contains'], - 'value': 'ab' - } - ] - } - } + "comparison_operator": "not contains", + "variable_selector": ["start", "array_not_contains"], + "value": "ab", + }, + ], + }, + }, ) # construct variable pool - pool = VariablePool(system_variables={ - SystemVariable.FILES: [], - SystemVariable.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) - pool.add(['start', 'array_contains'], ['1ab', 'def']) - pool.add(['start', 'array_not_contains'], ['ab', 'def']) + pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + ) + pool.add(["start", "array_contains"], ["1ab", "def"]) + pool.add(["start", "array_not_contains"], ["ab", "def"]) # Mock db.session.close() db.session.close = MagicMock() @@ -195,4 +143,4 @@ def test_execute_if_else_result_false(): result = node._run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs['result'] is False + assert result.outputs["result"] is False diff --git a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py new file mode 100644 index 00000000000000..e26c7df642776b --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py @@ -0,0 +1,150 @@ +from unittest import mock +from uuid import uuid4 + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.segments import ArrayStringVariable, StringVariable +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.nodes.base_node import UserFrom +from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode + +DEFAULT_NODE_ID = "node_id" + + +def test_overwrite_string_variable(): + conversation_variable = StringVariable( + id=str(uuid4()), + name="test_conversation_variable", + value="the first value", + ) + + input_variable = StringVariable( + id=str(uuid4()), + name="test_string_variable", + value="the second value", + ) + + node = VariableAssignerNode( + tenant_id="tenant_id", + app_id="app_id", + workflow_id="workflow_id", + user_id="user_id", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + config={ + "id": "node_id", + "data": { + "assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.OVER_WRITE.value, + "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], + }, + }, + ) + + variable_pool = VariablePool( + system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[conversation_variable], + ) + variable_pool.add( + [DEFAULT_NODE_ID, input_variable.name], + input_variable, + ) + + with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run: + node.run(variable_pool) + mock_run.assert_called_once() + + got = variable_pool.get(["conversation", conversation_variable.name]) + assert got is not None + assert got.value == "the second value" + assert got.to_object() == "the second value" + + +def test_append_variable_to_array(): + conversation_variable = ArrayStringVariable( + id=str(uuid4()), + name="test_conversation_variable", + value=["the first value"], + ) + + input_variable = StringVariable( + id=str(uuid4()), + name="test_string_variable", + value="the second value", + ) + + node = VariableAssignerNode( + tenant_id="tenant_id", + app_id="app_id", + workflow_id="workflow_id", + user_id="user_id", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + config={ + "id": "node_id", + "data": { + "assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.APPEND.value, + "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], + }, + }, + ) + + variable_pool = VariablePool( + system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[conversation_variable], + ) + variable_pool.add( + [DEFAULT_NODE_ID, input_variable.name], + input_variable, + ) + + with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run: + node.run(variable_pool) + mock_run.assert_called_once() + + got = variable_pool.get(["conversation", conversation_variable.name]) + assert got is not None + assert got.to_object() == ["the first value", "the second value"] + + +def test_clear_array(): + conversation_variable = ArrayStringVariable( + id=str(uuid4()), + name="test_conversation_variable", + value=["the first value"], + ) + + node = VariableAssignerNode( + tenant_id="tenant_id", + app_id="app_id", + workflow_id="workflow_id", + user_id="user_id", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + config={ + "id": "node_id", + "data": { + "assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.CLEAR.value, + "input_variable_selector": [], + }, + }, + ) + + variable_pool = VariablePool( + system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[conversation_variable], + ) + + node.run(variable_pool) + + got = variable_pool.get(["conversation", conversation_variable.name]) + assert got is not None + assert got.to_object() == [] diff --git a/api/tests/unit_tests/libs/test_pandas.py b/api/tests/unit_tests/libs/test_pandas.py index bbc372ed61b65a..21c2f0781d85f9 100644 --- a/api/tests/unit_tests/libs/test_pandas.py +++ b/api/tests/unit_tests/libs/test_pandas.py @@ -3,50 +3,46 @@ def test_pandas_csv(tmp_path, monkeypatch): monkeypatch.chdir(tmp_path) - data = {'col1': [1, 2.2, -3.3, 4.0, 5], - 'col2': ['A', 'B', 'C', 'D', 'E']} + data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]} df1 = pd.DataFrame(data) # write to csv file - csv_file_path = tmp_path.joinpath('example.csv') + csv_file_path = tmp_path.joinpath("example.csv") df1.to_csv(csv_file_path, index=False) # read from csv file - df2 = pd.read_csv(csv_file_path, on_bad_lines='skip') - assert df2[df2.columns[0]].to_list() == data['col1'] - assert df2[df2.columns[1]].to_list() == data['col2'] + df2 = pd.read_csv(csv_file_path, on_bad_lines="skip") + assert df2[df2.columns[0]].to_list() == data["col1"] + assert df2[df2.columns[1]].to_list() == data["col2"] def test_pandas_xlsx(tmp_path, monkeypatch): monkeypatch.chdir(tmp_path) - data = {'col1': [1, 2.2, -3.3, 4.0, 5], - 'col2': ['A', 'B', 'C', 'D', 'E']} + data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]} df1 = pd.DataFrame(data) # write to xlsx file - xlsx_file_path = tmp_path.joinpath('example.xlsx') + xlsx_file_path = tmp_path.joinpath("example.xlsx") df1.to_excel(xlsx_file_path, index=False) # read from xlsx file df2 = pd.read_excel(xlsx_file_path) - assert df2[df2.columns[0]].to_list() == data['col1'] - assert df2[df2.columns[1]].to_list() == data['col2'] + assert df2[df2.columns[0]].to_list() == data["col1"] + assert df2[df2.columns[1]].to_list() == data["col2"] def test_pandas_xlsx_with_sheets(tmp_path, monkeypatch): monkeypatch.chdir(tmp_path) - data1 = {'col1': [1, 2, 3, 4, 5], - 'col2': ['A', 'B', 'C', 'D', 'E']} + data1 = {"col1": [1, 2, 3, 4, 5], "col2": ["A", "B", "C", "D", "E"]} df1 = pd.DataFrame(data1) - data2 = {'col1': [6, 7, 8, 9, 10], - 'col2': ['F', 'G', 'H', 'I', 'J']} + data2 = {"col1": [6, 7, 8, 9, 10], "col2": ["F", "G", "H", "I", "J"]} df2 = pd.DataFrame(data2) # write to xlsx file with sheets - xlsx_file_path = tmp_path.joinpath('example_with_sheets.xlsx') - sheet1 = 'Sheet1' - sheet2 = 'Sheet2' + xlsx_file_path = tmp_path.joinpath("example_with_sheets.xlsx") + sheet1 = "Sheet1" + sheet2 = "Sheet2" with pd.ExcelWriter(xlsx_file_path) as excel_writer: df1.to_excel(excel_writer, sheet_name=sheet1, index=False) df2.to_excel(excel_writer, sheet_name=sheet2, index=False) @@ -54,9 +50,9 @@ def test_pandas_xlsx_with_sheets(tmp_path, monkeypatch): # read from xlsx file with sheets with pd.ExcelFile(xlsx_file_path) as excel_file: df1 = pd.read_excel(excel_file, sheet_name=sheet1) - assert df1[df1.columns[0]].to_list() == data1['col1'] - assert df1[df1.columns[1]].to_list() == data1['col2'] + assert df1[df1.columns[0]].to_list() == data1["col1"] + assert df1[df1.columns[1]].to_list() == data1["col2"] df2 = pd.read_excel(excel_file, sheet_name=sheet2) - assert df2[df2.columns[0]].to_list() == data2['col1'] - assert df2[df2.columns[1]].to_list() == data2['col2'] + assert df2[df2.columns[0]].to_list() == data2["col1"] + assert df2[df2.columns[1]].to_list() == data2["col2"] diff --git a/api/tests/unit_tests/libs/test_rsa.py b/api/tests/unit_tests/libs/test_rsa.py index a979b77d70a285..2dc51252f00e72 100644 --- a/api/tests/unit_tests/libs/test_rsa.py +++ b/api/tests/unit_tests/libs/test_rsa.py @@ -15,7 +15,7 @@ def test_gmpy2_pkcs10aep_cipher() -> None: private_rsa_key = RSA.import_key(private_key) private_cipher_rsa = gmpy2_pkcs10aep_cipher.new(private_rsa_key) - raw_text = 'raw_text' + raw_text = "raw_text" raw_text_bytes = raw_text.encode() # RSA encryption by public key and decryption by private key diff --git a/api/tests/unit_tests/libs/test_yarl.py b/api/tests/unit_tests/libs/test_yarl.py index 75a534412673b1..b9aee4af5f31c7 100644 --- a/api/tests/unit_tests/libs/test_yarl.py +++ b/api/tests/unit_tests/libs/test_yarl.py @@ -3,21 +3,21 @@ def test_yarl_urls(): - expected_1 = 'https://dify.ai/api' - assert str(URL('https://dify.ai') / 'api') == expected_1 - assert str(URL('https://dify.ai/') / 'api') == expected_1 + expected_1 = "https://dify.ai/api" + assert str(URL("https://dify.ai") / "api") == expected_1 + assert str(URL("https://dify.ai/") / "api") == expected_1 - expected_2 = 'http://dify.ai:12345/api' - assert str(URL('http://dify.ai:12345') / 'api') == expected_2 - assert str(URL('http://dify.ai:12345/') / 'api') == expected_2 + expected_2 = "http://dify.ai:12345/api" + assert str(URL("http://dify.ai:12345") / "api") == expected_2 + assert str(URL("http://dify.ai:12345/") / "api") == expected_2 - expected_3 = 'https://dify.ai/api/v1' - assert str(URL('https://dify.ai') / 'api' / 'v1') == expected_3 - assert str(URL('https://dify.ai') / 'api/v1') == expected_3 - assert str(URL('https://dify.ai/') / 'api/v1') == expected_3 - assert str(URL('https://dify.ai/api') / 'v1') == expected_3 - assert str(URL('https://dify.ai/api/') / 'v1') == expected_3 + expected_3 = "https://dify.ai/api/v1" + assert str(URL("https://dify.ai") / "api" / "v1") == expected_3 + assert str(URL("https://dify.ai") / "api/v1") == expected_3 + assert str(URL("https://dify.ai/") / "api/v1") == expected_3 + assert str(URL("https://dify.ai/api") / "v1") == expected_3 + assert str(URL("https://dify.ai/api/") / "v1") == expected_3 with pytest.raises(ValueError) as e1: - str(URL('https://dify.ai') / '/api') + str(URL("https://dify.ai") / "/api") assert str(e1.value) == "Appending path '/api' starting from slash is forbidden" diff --git a/api/tests/unit_tests/models/test_account.py b/api/tests/unit_tests/models/test_account.py index 006b99fb7d0935..026912ffbed300 100644 --- a/api/tests/unit_tests/models/test_account.py +++ b/api/tests/unit_tests/models/test_account.py @@ -2,13 +2,13 @@ def test_account_is_privileged_role() -> None: - assert TenantAccountRole.ADMIN == 'admin' - assert TenantAccountRole.OWNER == 'owner' - assert TenantAccountRole.EDITOR == 'editor' - assert TenantAccountRole.NORMAL == 'normal' + assert TenantAccountRole.ADMIN == "admin" + assert TenantAccountRole.OWNER == "owner" + assert TenantAccountRole.EDITOR == "editor" + assert TenantAccountRole.NORMAL == "normal" assert TenantAccountRole.is_privileged_role(TenantAccountRole.ADMIN) assert TenantAccountRole.is_privileged_role(TenantAccountRole.OWNER) assert not TenantAccountRole.is_privileged_role(TenantAccountRole.NORMAL) assert not TenantAccountRole.is_privileged_role(TenantAccountRole.EDITOR) - assert not TenantAccountRole.is_privileged_role('') + assert not TenantAccountRole.is_privileged_role("") diff --git a/api/tests/unit_tests/models/test_conversation_variable.py b/api/tests/unit_tests/models/test_conversation_variable.py new file mode 100644 index 00000000000000..7968347decbdda --- /dev/null +++ b/api/tests/unit_tests/models/test_conversation_variable.py @@ -0,0 +1,25 @@ +from uuid import uuid4 + +from core.app.segments import SegmentType, factory +from models import ConversationVariable + + +def test_from_variable_and_to_variable(): + variable = factory.build_variable_from_mapping( + { + "id": str(uuid4()), + "name": "name", + "value_type": SegmentType.OBJECT, + "value": { + "key": { + "key": "value", + } + }, + } + ) + + conversation_variable = ConversationVariable.from_variable( + app_id="app_id", conversation_id="conversation_id", variable=variable + ) + + assert conversation_variable.to_variable() == variable diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index facea34b5b4c9c..40483d7e3a3baa 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -8,20 +8,30 @@ def test_environment_variables(): - contexts.tenant_id.set('tenant_id') + contexts.tenant_id.set("tenant_id") # Create a Workflow instance - workflow = Workflow() + workflow = Workflow( + tenant_id="tenant_id", + app_id="app_id", + type="workflow", + version="draft", + graph="{}", + features="{}", + created_by="account_id", + environment_variables=[], + conversation_variables=[], + ) # Create some EnvironmentVariable instances - variable1 = StringVariable.model_validate({'name': 'var1', 'value': 'value1', 'id': str(uuid4())}) - variable2 = IntegerVariable.model_validate({'name': 'var2', 'value': 123, 'id': str(uuid4())}) - variable3 = SecretVariable.model_validate({'name': 'var3', 'value': 'secret', 'id': str(uuid4())}) - variable4 = FloatVariable.model_validate({'name': 'var4', 'value': 3.14, 'id': str(uuid4())}) + variable1 = StringVariable.model_validate({"name": "var1", "value": "value1", "id": str(uuid4())}) + variable2 = IntegerVariable.model_validate({"name": "var2", "value": 123, "id": str(uuid4())}) + variable3 = SecretVariable.model_validate({"name": "var3", "value": "secret", "id": str(uuid4())}) + variable4 = FloatVariable.model_validate({"name": "var4", "value": 3.14, "id": str(uuid4())}) with ( - mock.patch('core.helper.encrypter.encrypt_token', return_value='encrypted_token'), - mock.patch('core.helper.encrypter.decrypt_token', return_value='secret'), + mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), + mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), ): # Set the environment_variables property of the Workflow instance variables = [variable1, variable2, variable3, variable4] @@ -32,20 +42,30 @@ def test_environment_variables(): def test_update_environment_variables(): - contexts.tenant_id.set('tenant_id') + contexts.tenant_id.set("tenant_id") # Create a Workflow instance - workflow = Workflow() + workflow = Workflow( + tenant_id="tenant_id", + app_id="app_id", + type="workflow", + version="draft", + graph="{}", + features="{}", + created_by="account_id", + environment_variables=[], + conversation_variables=[], + ) # Create some EnvironmentVariable instances - variable1 = StringVariable.model_validate({'name': 'var1', 'value': 'value1', 'id': str(uuid4())}) - variable2 = IntegerVariable.model_validate({'name': 'var2', 'value': 123, 'id': str(uuid4())}) - variable3 = SecretVariable.model_validate({'name': 'var3', 'value': 'secret', 'id': str(uuid4())}) - variable4 = FloatVariable.model_validate({'name': 'var4', 'value': 3.14, 'id': str(uuid4())}) + variable1 = StringVariable.model_validate({"name": "var1", "value": "value1", "id": str(uuid4())}) + variable2 = IntegerVariable.model_validate({"name": "var2", "value": 123, "id": str(uuid4())}) + variable3 = SecretVariable.model_validate({"name": "var3", "value": "secret", "id": str(uuid4())}) + variable4 = FloatVariable.model_validate({"name": "var4", "value": 3.14, "id": str(uuid4())}) with ( - mock.patch('core.helper.encrypter.encrypt_token', return_value='encrypted_token'), - mock.patch('core.helper.encrypter.decrypt_token', return_value='secret'), + mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), + mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), ): variables = [variable1, variable2, variable3, variable4] @@ -56,40 +76,48 @@ def test_update_environment_variables(): # Update the name of variable3 and keep the value as it is variables[2] = variable3.model_copy( update={ - 'name': 'new name', - 'value': HIDDEN_VALUE, + "name": "new name", + "value": HIDDEN_VALUE, } ) workflow.environment_variables = variables - assert workflow.environment_variables[2].name == 'new name' + assert workflow.environment_variables[2].name == "new name" assert workflow.environment_variables[2].value == variable3.value def test_to_dict(): - contexts.tenant_id.set('tenant_id') + contexts.tenant_id.set("tenant_id") # Create a Workflow instance - workflow = Workflow() - workflow.graph = '{}' - workflow.features = '{}' + workflow = Workflow( + tenant_id="tenant_id", + app_id="app_id", + type="workflow", + version="draft", + graph="{}", + features="{}", + created_by="account_id", + environment_variables=[], + conversation_variables=[], + ) # Create some EnvironmentVariable instances with ( - mock.patch('core.helper.encrypter.encrypt_token', return_value='encrypted_token'), - mock.patch('core.helper.encrypter.decrypt_token', return_value='secret'), + mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), + mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), ): # Set the environment_variables property of the Workflow instance workflow.environment_variables = [ - SecretVariable.model_validate({'name': 'secret', 'value': 'secret', 'id': str(uuid4())}), - StringVariable.model_validate({'name': 'text', 'value': 'text', 'id': str(uuid4())}), + SecretVariable.model_validate({"name": "secret", "value": "secret", "id": str(uuid4())}), + StringVariable.model_validate({"name": "text", "value": "text", "id": str(uuid4())}), ] workflow_dict = workflow.to_dict() - assert workflow_dict['environment_variables'][0]['value'] == '' - assert workflow_dict['environment_variables'][1]['value'] == 'text' + assert workflow_dict["environment_variables"][0]["value"] == "" + assert workflow_dict["environment_variables"][1]["value"] == "text" workflow_dict = workflow.to_dict(include_secret=True) - assert workflow_dict['environment_variables'][0]['value'] == 'secret' - assert workflow_dict['environment_variables'][1]['value'] == 'text' + assert workflow_dict["environment_variables"][0]["value"] == "secret" + assert workflow_dict["environment_variables"][1]["value"] == "text" diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index f589cd20971650..805d92dfc93c1c 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -14,6 +14,7 @@ ModelConfigEntity, PromptTemplateEntity, VariableEntity, + VariableEntityType, ) from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode @@ -25,23 +26,24 @@ @pytest.fixture def default_variables(): - return [ + value = [ VariableEntity( variable="text_input", label="text-input", - type=VariableEntity.Type.TEXT_INPUT + type=VariableEntityType.TEXT_INPUT, ), VariableEntity( variable="paragraph", label="paragraph", - type=VariableEntity.Type.PARAGRAPH + type=VariableEntityType.PARAGRAPH, ), VariableEntity( variable="select", label="select", - type=VariableEntity.Type.SELECT - ) + type=VariableEntityType.SELECT, + ), ] + return value def test__convert_to_start_node(default_variables): @@ -81,18 +83,12 @@ def test__convert_to_http_request_node_for_chatbot(default_variables): external_data_variables = [ ExternalDataVariableEntity( - variable="external_variable", - type="api", - config={ - "api_based_extension_id": api_based_extension_id - } + variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id} ) ] nodes, _ = workflow_converter._convert_to_http_request_node( - app_model=app_model, - variables=default_variables, - external_data_variables=external_data_variables + app_model=app_model, variables=default_variables, external_data_variables=external_data_variables ) assert len(nodes) == 2 @@ -103,10 +99,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables): assert http_request_node["data"]["method"] == "post" assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint assert http_request_node["data"]["authorization"]["type"] == "api-key" - assert http_request_node["data"]["authorization"]["config"] == { - "type": "bearer", - "api_key": "api_key" - } + assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"} assert http_request_node["data"]["body"]["type"] == "json" body_data = http_request_node["data"]["body"]["data"] @@ -151,18 +144,12 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables): external_data_variables = [ ExternalDataVariableEntity( - variable="external_variable", - type="api", - config={ - "api_based_extension_id": api_based_extension_id - } + variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id} ) ] nodes, _ = workflow_converter._convert_to_http_request_node( - app_model=app_model, - variables=default_variables, - external_data_variables=external_data_variables + app_model=app_model, variables=default_variables, external_data_variables=external_data_variables ) assert len(nodes) == 2 @@ -173,10 +160,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables): assert http_request_node["data"]["method"] == "post" assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint assert http_request_node["data"]["authorization"]["type"] == "api-key" - assert http_request_node["data"]["authorization"]["config"] == { - "type": "bearer", - "api_key": "api_key" - } + assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"} assert http_request_node["data"]["body"]["type"] == "json" body_data = http_request_node["data"]["body"]["data"] @@ -205,37 +189,25 @@ def test__convert_to_knowledge_retrieval_node_for_chatbot(): retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, top_k=5, score_threshold=0.8, - reranking_model={ - 'reranking_provider_name': 'cohere', - 'reranking_model_name': 'rerank-english-v2.0' - }, - reranking_enabled=True - ) + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"}, + reranking_enabled=True, + ), ) - model_config = ModelConfigEntity( - provider='openai', - model='gpt-4', - mode='chat', - parameters={}, - stop=[] - ) + model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[]) node = WorkflowConverter()._convert_to_knowledge_retrieval_node( - new_app_mode=new_app_mode, - dataset_config=dataset_config, - model_config=model_config + new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config ) assert node["data"]["type"] == "knowledge-retrieval" assert node["data"]["query_variable_selector"] == ["sys", "query"] assert node["data"]["dataset_ids"] == dataset_config.dataset_ids - assert (node["data"]["retrieval_mode"] - == dataset_config.retrieve_config.retrieve_strategy.value) + assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value assert node["data"]["multiple_retrieval_config"] == { "top_k": dataset_config.retrieve_config.top_k, "score_threshold": dataset_config.retrieve_config.score_threshold, - "reranking_model": dataset_config.retrieve_config.reranking_model + "reranking_model": dataset_config.retrieve_config.reranking_model, } @@ -249,37 +221,25 @@ def test__convert_to_knowledge_retrieval_node_for_workflow_app(): retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, top_k=5, score_threshold=0.8, - reranking_model={ - 'reranking_provider_name': 'cohere', - 'reranking_model_name': 'rerank-english-v2.0' - }, - reranking_enabled=True - ) + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"}, + reranking_enabled=True, + ), ) - model_config = ModelConfigEntity( - provider='openai', - model='gpt-4', - mode='chat', - parameters={}, - stop=[] - ) + model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[]) node = WorkflowConverter()._convert_to_knowledge_retrieval_node( - new_app_mode=new_app_mode, - dataset_config=dataset_config, - model_config=model_config + new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config ) assert node["data"]["type"] == "knowledge-retrieval" assert node["data"]["query_variable_selector"] == ["start", dataset_config.retrieve_config.query_variable] assert node["data"]["dataset_ids"] == dataset_config.dataset_ids - assert (node["data"]["retrieval_mode"] - == dataset_config.retrieve_config.retrieve_strategy.value) + assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value assert node["data"]["multiple_retrieval_config"] == { "top_k": dataset_config.retrieve_config.top_k, "score_threshold": dataset_config.retrieve_config.score_threshold, - "reranking_model": dataset_config.retrieve_config.reranking_model + "reranking_model": dataset_config.retrieve_config.reranking_model, } @@ -291,14 +251,12 @@ def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables): workflow_converter = WorkflowConverter() start_node = workflow_converter._convert_to_start_node(default_variables) graph = { - "nodes": [ - start_node - ], - "edges": [] # no need + "nodes": [start_node], + "edges": [], # no need } model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = 'openai' + model_config_mock.provider = "openai" model_config_mock.model = model model_config_mock.mode = model_mode.value model_config_mock.parameters = {} @@ -306,7 +264,7 @@ def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables): prompt_template = PromptTemplateEntity( prompt_type=PromptTemplateEntity.PromptType.SIMPLE, - simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}." + simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.", ) llm_node = workflow_converter._convert_to_llm_node( @@ -314,17 +272,17 @@ def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables): new_app_mode=new_app_mode, model_config=model_config_mock, graph=graph, - prompt_template=prompt_template + prompt_template=prompt_template, ) assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]['name'] == model - assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["model"]["name"] == model + assert llm_node["data"]["model"]["mode"] == model_mode.value template = prompt_template.simple_prompt_template for v in default_variables: - template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}') - assert llm_node["data"]["prompt_template"][0]['text'] == template + '\n' - assert llm_node["data"]['context']['enabled'] is False + template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") + assert llm_node["data"]["prompt_template"][0]["text"] == template + "\n" + assert llm_node["data"]["context"]["enabled"] is False def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variables): @@ -335,14 +293,12 @@ def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variab workflow_converter = WorkflowConverter() start_node = workflow_converter._convert_to_start_node(default_variables) graph = { - "nodes": [ - start_node - ], - "edges": [] # no need + "nodes": [start_node], + "edges": [], # no need } model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = 'openai' + model_config_mock.provider = "openai" model_config_mock.model = model model_config_mock.mode = model_mode.value model_config_mock.parameters = {} @@ -350,7 +306,7 @@ def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variab prompt_template = PromptTemplateEntity( prompt_type=PromptTemplateEntity.PromptType.SIMPLE, - simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}." + simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.", ) llm_node = workflow_converter._convert_to_llm_node( @@ -358,17 +314,17 @@ def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variab new_app_mode=new_app_mode, model_config=model_config_mock, graph=graph, - prompt_template=prompt_template + prompt_template=prompt_template, ) assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]['name'] == model - assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["model"]["name"] == model + assert llm_node["data"]["model"]["mode"] == model_mode.value template = prompt_template.simple_prompt_template for v in default_variables: - template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}') - assert llm_node["data"]["prompt_template"]['text'] == template + '\n' - assert llm_node["data"]['context']['enabled'] is False + template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") + assert llm_node["data"]["prompt_template"]["text"] == template + "\n" + assert llm_node["data"]["context"]["enabled"] is False def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables): @@ -379,14 +335,12 @@ def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables) workflow_converter = WorkflowConverter() start_node = workflow_converter._convert_to_start_node(default_variables) graph = { - "nodes": [ - start_node - ], - "edges": [] # no need + "nodes": [start_node], + "edges": [], # no need } model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = 'openai' + model_config_mock.provider = "openai" model_config_mock.model = model model_config_mock.mode = model_mode.value model_config_mock.parameters = {} @@ -394,12 +348,16 @@ def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables) prompt_template = PromptTemplateEntity( prompt_type=PromptTemplateEntity.PromptType.ADVANCED, - advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(messages=[ - AdvancedChatMessageEntity(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", - role=PromptMessageRole.SYSTEM), - AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER), - AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT), - ]) + advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity( + messages=[ + AdvancedChatMessageEntity( + text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", + role=PromptMessageRole.SYSTEM, + ), + AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER), + AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT), + ] + ), ) llm_node = workflow_converter._convert_to_llm_node( @@ -407,18 +365,18 @@ def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables) new_app_mode=new_app_mode, model_config=model_config_mock, graph=graph, - prompt_template=prompt_template + prompt_template=prompt_template, ) assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]['name'] == model - assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["model"]["name"] == model + assert llm_node["data"]["model"]["mode"] == model_mode.value assert isinstance(llm_node["data"]["prompt_template"], list) assert len(llm_node["data"]["prompt_template"]) == len(prompt_template.advanced_chat_prompt_template.messages) template = prompt_template.advanced_chat_prompt_template.messages[0].text for v in default_variables: - template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}') - assert llm_node["data"]["prompt_template"][0]['text'] == template + template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") + assert llm_node["data"]["prompt_template"][0]["text"] == template def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables): @@ -429,14 +387,12 @@ def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_var workflow_converter = WorkflowConverter() start_node = workflow_converter._convert_to_start_node(default_variables) graph = { - "nodes": [ - start_node - ], - "edges": [] # no need + "nodes": [start_node], + "edges": [], # no need } model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = 'openai' + model_config_mock.provider = "openai" model_config_mock.model = model model_config_mock.mode = model_mode.value model_config_mock.parameters = {} @@ -446,12 +402,9 @@ def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_var prompt_type=PromptTemplateEntity.PromptType.ADVANCED, advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( prompt="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}\n\n" - "Human: hi\nAssistant: ", - role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity( - user="Human", - assistant="Assistant" - ) - ) + "Human: hi\nAssistant: ", + role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="Human", assistant="Assistant"), + ), ) llm_node = workflow_converter._convert_to_llm_node( @@ -459,14 +412,14 @@ def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_var new_app_mode=new_app_mode, model_config=model_config_mock, graph=graph, - prompt_template=prompt_template + prompt_template=prompt_template, ) assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]['name'] == model - assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["model"]["name"] == model + assert llm_node["data"]["model"]["mode"] == model_mode.value assert isinstance(llm_node["data"]["prompt_template"], dict) template = prompt_template.advanced_completion_prompt_template.prompt for v in default_variables: - template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}') - assert llm_node["data"]["prompt_template"]['text'] == template + template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") + assert llm_node["data"]["prompt_template"]["text"] == template diff --git a/api/tests/unit_tests/utils/position_helper/test_position_helper.py b/api/tests/unit_tests/utils/position_helper/test_position_helper.py index 22373199043d6c..29558a93c242a8 100644 --- a/api/tests/unit_tests/utils/position_helper/test_position_helper.py +++ b/api/tests/unit_tests/utils/position_helper/test_position_helper.py @@ -2,54 +2,122 @@ import pytest -from core.helper.position_helper import get_position_map +from core.helper.position_helper import get_position_map, is_filtered, pin_position_map, sort_by_position_map @pytest.fixture def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str: monkeypatch.chdir(tmp_path) - tmp_path.joinpath("example_positions.yaml").write_text(dedent( - """\ + tmp_path.joinpath("example_positions.yaml").write_text( + dedent( + """\ - first - second # - commented - third - + - 9999999999999 - forth - """)) + """ + ) + ) return str(tmp_path) @pytest.fixture def prepare_empty_commented_positions_yaml(tmp_path, monkeypatch) -> str: monkeypatch.chdir(tmp_path) - tmp_path.joinpath("example_positions_all_commented.yaml").write_text(dedent( - """\ + tmp_path.joinpath("example_positions_all_commented.yaml").write_text( + dedent( + """\ # - commented1 # - commented2 - - - - - - """)) + - + - + + """ + ) + ) return str(tmp_path) def test_position_helper(prepare_example_positions_yaml): - position_map = get_position_map( - folder_path=prepare_example_positions_yaml, - file_name='example_positions.yaml') + position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml") assert len(position_map) == 4 assert position_map == { - 'first': 0, - 'second': 1, - 'third': 2, - 'forth': 3, + "first": 0, + "second": 1, + "third": 2, + "forth": 3, } def test_position_helper_with_all_commented(prepare_empty_commented_positions_yaml): position_map = get_position_map( - folder_path=prepare_empty_commented_positions_yaml, - file_name='example_positions_all_commented.yaml') + folder_path=prepare_empty_commented_positions_yaml, file_name="example_positions_all_commented.yaml" + ) assert position_map == {} + + +def test_excluded_position_data(prepare_example_positions_yaml): + position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml") + pin_list = ["forth", "first"] + include_set = set() + exclude_set = {"9999999999999"} + + position_map = pin_position_map(original_position_map=position_map, pin_list=pin_list) + + data = [ + "forth", + "first", + "second", + "third", + "9999999999999", + "extra1", + "extra2", + ] + + # filter out the data + data = [item for item in data if not is_filtered(include_set, exclude_set, item, lambda x: x)] + + # sort data by position map + sorted_data = sort_by_position_map( + position_map=position_map, + data=data, + name_func=lambda x: x, + ) + + # assert the result in the correct order + assert sorted_data == ["forth", "first", "second", "third", "extra1", "extra2"] + + +def test_included_position_data(prepare_example_positions_yaml): + position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml") + pin_list = ["forth", "first"] + include_set = {"forth", "first"} + exclude_set = {} + + position_map = pin_position_map(original_position_map=position_map, pin_list=pin_list) + + data = [ + "forth", + "first", + "second", + "third", + "9999999999999", + "extra1", + "extra2", + ] + + # filter out the data + data = [item for item in data if not is_filtered(include_set, exclude_set, item, lambda x: x)] + + # sort data by position map + sorted_data = sort_by_position_map( + position_map=position_map, + data=data, + name_func=lambda x: x, + ) + + # assert the result in the correct order + assert sorted_data == ["forth", "first"] diff --git a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py index c0452b4e4d803a..95b93651d57f80 100644 --- a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py +++ b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py @@ -5,17 +5,18 @@ from core.tools.utils.yaml_utils import load_yaml_file -EXAMPLE_YAML_FILE = 'example_yaml.yaml' -INVALID_YAML_FILE = 'invalid_yaml.yaml' -NON_EXISTING_YAML_FILE = 'non_existing_file.yaml' +EXAMPLE_YAML_FILE = "example_yaml.yaml" +INVALID_YAML_FILE = "invalid_yaml.yaml" +NON_EXISTING_YAML_FILE = "non_existing_file.yaml" @pytest.fixture def prepare_example_yaml_file(tmp_path, monkeypatch) -> str: monkeypatch.chdir(tmp_path) file_path = tmp_path.joinpath(EXAMPLE_YAML_FILE) - file_path.write_text(dedent( - """\ + file_path.write_text( + dedent( + """\ address: city: Example City country: Example Country @@ -26,7 +27,9 @@ def prepare_example_yaml_file(tmp_path, monkeypatch) -> str: - Java - C++ empty_key: - """)) + """ + ) + ) return str(file_path) @@ -34,8 +37,9 @@ def prepare_example_yaml_file(tmp_path, monkeypatch) -> str: def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str: monkeypatch.chdir(tmp_path) file_path = tmp_path.joinpath(INVALID_YAML_FILE) - file_path.write_text(dedent( - """\ + file_path.write_text( + dedent( + """\ address: city: Example City country: Example Country @@ -45,13 +49,15 @@ def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str: - Python - Java - C++ - """)) + """ + ) + ) return str(file_path) def test_load_yaml_non_existing_file(): assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {} - assert load_yaml_file(file_path='') == {} + assert load_yaml_file(file_path="") == {} with pytest.raises(FileNotFoundError): load_yaml_file(file_path=NON_EXISTING_YAML_FILE, ignore_error=False) @@ -60,12 +66,12 @@ def test_load_yaml_non_existing_file(): def test_load_valid_yaml_file(prepare_example_yaml_file): yaml_data = load_yaml_file(file_path=prepare_example_yaml_file) assert len(yaml_data) > 0 - assert yaml_data['age'] == 30 - assert yaml_data['gender'] == 'male' - assert yaml_data['address']['city'] == 'Example City' - assert set(yaml_data['languages']) == {'Python', 'Java', 'C++'} - assert yaml_data.get('empty_key') is None - assert yaml_data.get('non_existed_key') is None + assert yaml_data["age"] == 30 + assert yaml_data["gender"] == "male" + assert yaml_data["address"]["city"] == "Example City" + assert set(yaml_data["languages"]) == {"Python", "Java", "C++"} + assert yaml_data.get("empty_key") is None + assert yaml_data.get("non_existed_key") is None def test_load_invalid_yaml_file(prepare_invalid_yaml_file): diff --git a/dev/pytest/pytest_vdb.sh b/dev/pytest/pytest_vdb.sh index c954c528fb2499..0b23200dc33b0e 100755 --- a/dev/pytest/pytest_vdb.sh +++ b/dev/pytest/pytest_vdb.sh @@ -7,4 +7,5 @@ pytest api/tests/integration_tests/vdb/chroma \ api/tests/integration_tests/vdb/pgvector \ api/tests/integration_tests/vdb/qdrant \ api/tests/integration_tests/vdb/weaviate \ + api/tests/integration_tests/vdb/elasticsearch \ api/tests/integration_tests/vdb/test_vector_store.py \ No newline at end of file diff --git a/dev/reformat b/dev/reformat index f50ccb04c44ed1..ad83e897d978bd 100755 --- a/dev/reformat +++ b/dev/reformat @@ -11,5 +11,8 @@ fi # run ruff linter ruff check --fix ./api +# run ruff formatter +ruff format ./api + # run dotenv-linter linter dotenv-linter ./api/.env.example ./web/.env.example diff --git a/docker-legacy/docker-compose.yaml b/docker-legacy/docker-compose.yaml index 9c2efdd817175f..1f23dc84e0c901 100644 --- a/docker-legacy/docker-compose.yaml +++ b/docker-legacy/docker-compose.yaml @@ -2,7 +2,7 @@ version: '3' services: # API service api: - image: langgenius/dify-api:0.6.15 + image: langgenius/dify-api:0.7.2 restart: always environment: # Startup mode, 'api' starts the API server. @@ -169,6 +169,11 @@ services: CHROMA_DATABASE: default_database CHROMA_AUTH_PROVIDER: chromadb.auth.token_authn.TokenAuthClientProvider CHROMA_AUTH_CREDENTIALS: xxxxxx + # ElasticSearch Config + ELASTICSEARCH_HOST: 127.0.0.1 + ELASTICSEARCH_PORT: 9200 + ELASTICSEARCH_USERNAME: elastic + ELASTICSEARCH_PASSWORD: elastic # Mail configuration, support: resend, smtp MAIL_TYPE: '' # default send from email address, if not specified @@ -224,7 +229,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.6.15 + image: langgenius/dify-api:0.7.2 restart: always environment: CONSOLE_WEB_URL: '' @@ -371,6 +376,11 @@ services: CHROMA_DATABASE: default_database CHROMA_AUTH_PROVIDER: chromadb.auth.token_authn.TokenAuthClientProvider CHROMA_AUTH_CREDENTIALS: xxxxxx + # ElasticSearch Config + ELASTICSEARCH_HOST: 127.0.0.1 + ELASTICSEARCH_PORT: 9200 + ELASTICSEARCH_USERNAME: elastic + ELASTICSEARCH_PASSWORD: elastic # Notion import configuration, support public and internal NOTION_INTEGRATION_TYPE: public NOTION_CLIENT_SECRET: you-client-secret @@ -390,7 +400,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.6.15 + image: langgenius/dify-web:0.7.2 restart: always environment: # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is diff --git a/docker/.env.example b/docker/.env.example index c463bf1bec8f41..d8fa14f7c0ad5e 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -173,6 +173,36 @@ SQLALCHEMY_POOL_RECYCLE=3600 # Whether to print SQL, default is false. SQLALCHEMY_ECHO=false +# Maximum number of connections to the database +# Default is 100 +# +# Reference: https://www.postgresql.org/docs/current/runtime-config-connection.html#GUC-MAX-CONNECTIONS +POSTGRES_MAX_CONNECTIONS=100 + +# Sets the amount of shared memory used for postgres's shared buffers. +# Default is 128MB +# Recommended value: 25% of available memory +# Reference: https://www.postgresql.org/docs/current/runtime-config-resource.html#GUC-SHARED-BUFFERS +POSTGRES_SHARED_BUFFERS=128MB + +# Sets the amount of memory used by each database worker for working space. +# Default is 4MB +# +# Reference: https://www.postgresql.org/docs/current/runtime-config-resource.html#GUC-WORK-MEM +POSTGRES_WORK_MEM=4MB + +# Sets the amount of memory reserved for maintenance activities. +# Default is 64MB +# +# Reference: https://www.postgresql.org/docs/current/runtime-config-resource.html#GUC-MAINTENANCE-WORK-MEM +POSTGRES_MAINTENANCE_WORK_MEM=64MB + +# Sets the planner's assumption about the effective cache size. +# Default is 4096MB +# +# Reference: https://www.postgresql.org/docs/current/runtime-config-query.html#GUC-EFFECTIVE-CACHE-SIZE +POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB + # ------------------------------ # Redis Configuration # This Redis configuration is used for caching and for pub/sub during conversation. @@ -255,6 +285,8 @@ ALIYUN_OSS_SECRET_KEY=your-secret-key ALIYUN_OSS_ENDPOINT=https://oss-ap-southeast-1-internal.aliyuncs.com ALIYUN_OSS_REGION=ap-southeast-1 ALIYUN_OSS_AUTH_VERSION=v4 +# Don't start with '/'. OSS doesn't support leading slash in object names. +ALIYUN_OSS_PATH=your-path # Tencent COS Configuration # The name of the Tencent COS bucket to use for storing files. @@ -273,7 +305,7 @@ TENCENT_COS_SCHEME=your-scheme # ------------------------------ # The type of vector store to use. -# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`. +# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`. VECTOR_STORE=weaviate # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`. @@ -366,6 +398,12 @@ TENCENT_VECTOR_DB_DATABASE=dify TENCENT_VECTOR_DB_SHARD=1 TENCENT_VECTOR_DB_REPLICAS=2 +# ElasticSearch configuration, only available when VECTOR_STORE is `elasticsearch` +ELASTICSEARCH_HOST=0.0.0.0 +ELASTICSEARCH_PORT=9200 +ELASTICSEARCH_USERNAME=elastic +ELASTICSEARCH_PASSWORD=elastic + # ------------------------------ # Knowledge Configuration # ------------------------------ @@ -478,6 +516,8 @@ RESET_PASSWORD_TOKEN_EXPIRY_HOURS=24 CODE_EXECUTION_ENDPOINT=http://sandbox:8194 CODE_MAX_NUMBER=9223372036854775807 CODE_MIN_NUMBER=-9223372036854775808 +CODE_MAX_DEPTH=5 +CODE_MAX_PRECISION=20 CODE_MAX_STRING_LENGTH=80000 TEMPLATE_TRANSFORM_MAX_LENGTH=80000 CODE_MAX_STRING_ARRAY_LENGTH=30 @@ -627,6 +667,7 @@ NGINX_KEEPALIVE_TIMEOUT=65 NGINX_PROXY_READ_TIMEOUT=3600s NGINX_PROXY_SEND_TIMEOUT=3600s +# Set true to accept requests for /.well-known/acme-challenge/ NGINX_ENABLE_CERTBOT_CHALLENGE=false # ------------------------------ @@ -664,3 +705,22 @@ COMPOSE_PROFILES=${VECTOR_STORE:-weaviate} # ------------------------------ EXPOSE_NGINX_PORT=80 EXPOSE_NGINX_SSL_PORT=443 + +# ---------------------------------------------------------------------------- +# ModelProvider & Tool Position Configuration +# Used to specify the model providers and tools that can be used in the app. +# ---------------------------------------------------------------------------- + +# Pin, include, and exclude tools +# Use comma-separated values with no spaces between items. +# Example: POSITION_TOOL_PINS=bing,google +POSITION_TOOL_PINS= +POSITION_TOOL_INCLUDES= +POSITION_TOOL_EXCLUDES= + +# Pin, include, and exclude model providers +# Use comma-separated values with no spaces between items. +# Example: POSITION_PROVIDER_PINS=openai,openllm +POSITION_PROVIDER_PINS= +POSITION_PROVIDER_INCLUDES= +POSITION_PROVIDER_EXCLUDES= \ No newline at end of file diff --git a/docker/certbot/README.md b/docker/certbot/README.md index 3fab2f4bb79a68..c6f73ae699df15 100644 --- a/docker/certbot/README.md +++ b/docker/certbot/README.md @@ -16,7 +16,7 @@ Use `docker-compose --profile certbot up` to use this features. CERTBOT_DOMAIN=your_domain.com CERTBOT_EMAIL=example@your_domain.com ``` - excecute command: + execute command: ```shell sudo docker network prune sudo docker-compose --profile certbot up --force-recreate -d @@ -30,7 +30,7 @@ Use `docker-compose --profile certbot up` to use this features. ```properties NGINX_HTTPS_ENABLED=true ``` - excecute command: + execute command: ```shell sudo docker-compose --profile certbot up -d --no-deps --force-recreate nginx ``` @@ -73,4 +73,4 @@ To use cert files dir `nginx/ssl` as before, simply launch containers WITHOUT `- ```shell sudo docker-compose up -d -``` \ No newline at end of file +``` diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 6ab003ceabe5d5..9a4c40448bf112 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -9,6 +9,12 @@ services: POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456} POSTGRES_DB: ${POSTGRES_DB:-dify} PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata} + command: > + postgres -c 'max_connections=${POSTGRES_MAX_CONNECTIONS:-100}' + -c 'shared_buffers=${POSTGRES_SHARED_BUFFERS:-128MB}' + -c 'work_mem=${POSTGRES_WORK_MEM:-4MB}' + -c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}' + -c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}' volumes: - ./volumes/db/data:/var/lib/postgresql/data ports: @@ -28,7 +34,7 @@ services: # The DifySandbox sandbox: - image: langgenius/dify-sandbox:0.2.1 + image: langgenius/dify-sandbox:0.2.6 restart: always environment: # The DifySandbox configurations diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 1caf244fb1e213..7a68bb38750067 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -1,5 +1,6 @@ x-shared-env: &shared-api-worker-env LOG_LEVEL: ${LOG_LEVEL:-INFO} + LOG_FILE: ${LOG_FILE:-} DEBUG: ${DEBUG:-false} FLASK_DEBUG: ${FLASK_DEBUG:-false} SECRET_KEY: ${SECRET_KEY:-sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U} @@ -65,6 +66,7 @@ x-shared-env: &shared-api-worker-env ALIYUN_OSS_ENDPOINT: ${ALIYUN_OSS_ENDPOINT:-} ALIYUN_OSS_REGION: ${ALIYUN_OSS_REGION:-} ALIYUN_OSS_AUTH_VERSION: ${ALIYUN_OSS_AUTH_VERSION:-v4} + ALIYUN_OSS_PATHS: ${ALIYUN_OSS_PATH:-} TENCENT_COS_BUCKET_NAME: ${TENCENT_COS_BUCKET_NAME:-} TENCENT_COS_SECRET_KEY: ${TENCENT_COS_SECRET_KEY:-} TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-} @@ -120,6 +122,11 @@ x-shared-env: &shared-api-worker-env CHROMA_DATABASE: ${CHROMA_DATABASE:-default_database} CHROMA_AUTH_PROVIDER: ${CHROMA_AUTH_PROVIDER:-chromadb.auth.token_authn.TokenAuthClientProvider} CHROMA_AUTH_CREDENTIALS: ${CHROMA_AUTH_CREDENTIALS:-} + ELASTICSEARCH_HOST: ${ELASTICSEARCH_HOST:-0.0.0.0} + ELASTICSEARCH_PORT: ${ELASTICSEARCH_PORT:-9200} + ELASTICSEARCH_USERNAME: ${ELASTICSEARCH_USERNAME:-elastic} + ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic} + KIBANA_PORT: ${KIBANA_PORT:-5601} # AnalyticDB configuration ANALYTICDB_KEY_ID: ${ANALYTICDB_KEY_ID:-} ANALYTICDB_KEY_SECRET: ${ANALYTICDB_KEY_SECRET:-} @@ -171,6 +178,8 @@ x-shared-env: &shared-api-worker-env CODE_EXECUTION_API_KEY: ${SANDBOX_API_KEY:-dify-sandbox} CODE_MAX_NUMBER: ${CODE_MAX_NUMBER:-9223372036854775807} CODE_MIN_NUMBER: ${CODE_MIN_NUMBER:--9223372036854775808} + CODE_MAX_DEPTH: ${CODE_MAX_DEPTH:-5} + CODE_MAX_PRECISION: ${CODE_MAX_PRECISION:-20} CODE_MAX_STRING_LENGTH: ${CODE_MAX_STRING_LENGTH:-80000} TEMPLATE_TRANSFORM_MAX_LENGTH: ${TEMPLATE_TRANSFORM_MAX_LENGTH:-80000} CODE_MAX_STRING_ARRAY_LENGTH: ${CODE_MAX_STRING_ARRAY_LENGTH:-30} @@ -182,7 +191,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:0.6.15 + image: langgenius/dify-api:0.7.2 restart: always environment: # Use the shared environment variables. @@ -202,7 +211,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.6.15 + image: langgenius/dify-api:0.7.2 restart: always environment: # Use the shared environment variables. @@ -221,12 +230,13 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.6.15 + image: langgenius/dify-web:0.7.2 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} APP_API_URL: ${APP_API_URL:-} SENTRY_DSN: ${WEB_SENTRY_DSN:-} + NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0} # The postgres database. db: @@ -237,6 +247,12 @@ services: POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456} POSTGRES_DB: ${POSTGRES_DB:-dify} PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata} + command: > + postgres -c 'max_connections=${POSTGRES_MAX_CONNECTIONS:-100}' + -c 'shared_buffers=${POSTGRES_SHARED_BUFFERS:-128MB}' + -c 'work_mem=${POSTGRES_WORK_MEM:-4MB}' + -c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}' + -c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}' volumes: - ./volumes/db/data:/var/lib/postgresql/data healthcheck: @@ -259,7 +275,7 @@ services: # The DifySandbox sandbox: - image: langgenius/dify-sandbox:0.2.1 + image: langgenius/dify-sandbox:0.2.6 restart: always environment: # The DifySandbox configurations @@ -279,7 +295,7 @@ services: # ssrf_proxy server # for more information, please refer to - # https://docs.dify.ai/learn-more/faq/self-host-faq#id-18.-why-is-ssrf_proxy-needed + # https://docs.dify.ai/learn-more/faq/install-faq#id-18.-why-is-ssrf_proxy-needed ssrf_proxy: image: ubuntu/squid:latest restart: always @@ -571,7 +587,7 @@ services: # MyScale vector database myscale: container_name: myscale - image: myscale/myscaledb:1.6 + image: myscale/myscaledb:1.6.4 profiles: - myscale restart: always @@ -583,6 +599,61 @@ services: ports: - "${MYSCALE_PORT:-8123}:${MYSCALE_PORT:-8123}" + # https://www.elastic.co/guide/en/elasticsearch/reference/current/settings.html + # https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html#docker-prod-prerequisites + elasticsearch: + image: docker.elastic.co/elasticsearch/elasticsearch:8.14.3 + container_name: elasticsearch + profiles: + - elasticsearch + restart: always + volumes: + - dify_es01_data:/usr/share/elasticsearch/data + environment: + - ELASTIC_PASSWORD=${ELASTICSEARCH_PASSWORD:-elastic} + - cluster.name=dify-es-cluster + - node.name=dify-es0 + - discovery.type=single-node + - xpack.license.self_generated.type=trial + - xpack.security.enabled=true + - xpack.security.enrollment.enabled=false + - xpack.security.http.ssl.enabled=false + ports: + - ${ELASTICSEARCH_PORT:-9200}:9200 + healthcheck: + test: ["CMD", "curl", "-s", "http://localhost:9200/_cluster/health?pretty"] + interval: 30s + timeout: 10s + retries: 50 + + # https://www.elastic.co/guide/en/kibana/current/docker.html + # https://www.elastic.co/guide/en/kibana/current/settings.html + kibana: + image: docker.elastic.co/kibana/kibana:8.14.3 + container_name: kibana + profiles: + - elasticsearch + depends_on: + - elasticsearch + restart: always + environment: + - XPACK_ENCRYPTEDSAVEDOBJECTS_ENCRYPTIONKEY=d1a66dfd-c4d3-4a0a-8290-2abcb83ab3aa + - NO_PROXY=localhost,127.0.0.1,elasticsearch,kibana + - XPACK_SECURITY_ENABLED=true + - XPACK_SECURITY_ENROLLMENT_ENABLED=false + - XPACK_SECURITY_HTTP_SSL_ENABLED=false + - XPACK_FLEET_ISAIRGAPPED=true + - I18N_LOCALE=zh-CN + - SERVER_PORT=5601 + - ELASTICSEARCH_HOSTS="http://elasticsearch:9200" + ports: + - ${KIBANA_PORT:-5601}:5601 + healthcheck: + test: [ "CMD-SHELL", "curl -s http://localhost:5601 >/dev/null || exit 1" ] + interval: 30s + timeout: 10s + retries: 3 + # unstructured . # (if used, you need to set ETL_TYPE to Unstructured in the api & worker service.) unstructured: @@ -606,3 +677,4 @@ networks: volumes: oradata: + dify_es01_data: diff --git a/docker/middleware.env.example b/docker/middleware.env.example index 750dcfe9502f11..04d0fb5ed30cf5 100644 --- a/docker/middleware.env.example +++ b/docker/middleware.env.example @@ -9,6 +9,35 @@ POSTGRES_DB=dify # postgres data directory PGDATA=/var/lib/postgresql/data/pgdata +# Maximum number of connections to the database +# Default is 100 +# +# Reference: https://www.postgresql.org/docs/current/runtime-config-connection.html#GUC-MAX-CONNECTIONS +POSTGRES_MAX_CONNECTIONS=100 + +# Sets the amount of shared memory used for postgres's shared buffers. +# Default is 128MB +# Recommended value: 25% of available memory +# Reference: https://www.postgresql.org/docs/current/runtime-config-resource.html#GUC-SHARED-BUFFERS +POSTGRES_SHARED_BUFFERS=128MB + +# Sets the amount of memory used by each database worker for working space. +# Default is 4MB +# +# Reference: https://www.postgresql.org/docs/current/runtime-config-resource.html#GUC-WORK-MEM +POSTGRES_WORK_MEM=4MB + +# Sets the amount of memory reserved for maintenance activities. +# Default is 64MB +# +# Reference: https://www.postgresql.org/docs/current/runtime-config-resource.html#GUC-MAINTENANCE-WORK-MEM +POSTGRES_MAINTENANCE_WORK_MEM=64MB + +# Sets the planner's assumption about the effective cache size. +# Default is 4096MB +# +# Reference: https://www.postgresql.org/docs/current/runtime-config-query.html#GUC-EFFECTIVE-CACHE-SIZE +POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB # ------------------------------ # Environment Variables for sandbox Service diff --git a/sdks/nodejs-client/index.d.ts b/sdks/nodejs-client/index.d.ts index 7fdd943f637d82..a2e6e50aef315c 100644 --- a/sdks/nodejs-client/index.d.ts +++ b/sdks/nodejs-client/index.d.ts @@ -33,6 +33,10 @@ export declare class DifyClient { getApplicationParameters(user: User): Promise; fileUpload(data: FormData): Promise; + + textToAudio(text: string ,user: string, streaming?: boolean): Promise; + + getMeta(user: User): Promise; } export declare class CompletionClient extends DifyClient { @@ -54,6 +58,18 @@ export declare class ChatClient extends DifyClient { files?: File[] | null ): Promise; + getSuggested(message_id: string, user: User): Promise; + + stopMessage(task_id: string, user: User) : Promise; + + + getConversations( + user: User, + first_id?: string | null, + limit?: number | null, + pinned?: boolean | null + ): Promise; + getConversationMessages( user: User, conversation_id?: string, @@ -61,9 +77,15 @@ export declare class ChatClient extends DifyClient { limit?: number | null ): Promise; - getConversations(user: User, first_id?: string | null, limit?: number | null, pinned?: boolean | null): Promise; - - renameConversation(conversation_id: string, name: string, user: User): Promise; + renameConversation(conversation_id: string, name: string, user: User,auto_generate:boolean): Promise; deleteConversation(conversation_id: string, user: User): Promise; + + audioToText(data: FormData): Promise; +} + +export declare class WorkflowClient extends DifyClient { + run(inputs: any, user: User, stream?: boolean,): Promise; + + stop(task_id: string, user: User): Promise; } \ No newline at end of file diff --git a/sdks/nodejs-client/index.js b/sdks/nodejs-client/index.js index 93491fa16b711e..858241ce5aef4b 100644 --- a/sdks/nodejs-client/index.js +++ b/sdks/nodejs-client/index.js @@ -2,30 +2,55 @@ import axios from "axios"; export const BASE_URL = "https://api.dify.ai/v1"; export const routes = { + // app's + feedback: { + method: "POST", + url: (message_id) => `/messages/${message_id}/feedbacks`, + }, application: { method: "GET", url: () => `/parameters`, }, - feedback: { + fileUpload: { method: "POST", - url: (message_id) => `/messages/${message_id}/feedbacks`, + url: () => `/files/upload`, + }, + textToAudio: { + method: "POST", + url: () => `/text-to-audio`, }, + getMeta: { + method: "GET", + url: () => `/meta`, + }, + + // completion's createCompletionMessage: { method: "POST", url: () => `/completion-messages`, }, + + // chat's createChatMessage: { method: "POST", url: () => `/chat-messages`, }, - getConversationMessages: { + getSuggested:{ method: "GET", - url: () => `/messages`, + url: (message_id) => `/messages/${message_id}/suggested`, + }, + stopChatMessage: { + method: "POST", + url: (task_id) => `/chat-messages/${task_id}/stop`, }, getConversations: { method: "GET", url: () => `/conversations`, }, + getConversationMessages: { + method: "GET", + url: () => `/messages`, + }, renameConversation: { method: "POST", url: (conversation_id) => `/conversations/${conversation_id}/name`, @@ -34,14 +59,21 @@ export const routes = { method: "DELETE", url: (conversation_id) => `/conversations/${conversation_id}`, }, - fileUpload: { + audioToText: { method: "POST", - url: () => `/files/upload`, + url: () => `/audio-to-text`, }, + + // workflow‘s runWorkflow: { method: "POST", url: () => `/workflows/run`, }, + stopWorkflow: { + method: "POST", + url: (task_id) => `/workflows/${task_id}/stop`, + } + }; export class DifyClient { @@ -129,6 +161,31 @@ export class DifyClient { } ); } + + textToAudio(text, user, streaming = false) { + const data = { + text, + user, + streaming + }; + return this.sendRequest( + routes.textToAudio.method, + routes.textToAudio.url(), + data, + null, + streaming + ); + } + + getMeta(user) { + const params = { user }; + return this.sendRequest( + routes.meta.method, + routes.meta.url(), + null, + params + ); + } } export class CompletionClient extends DifyClient { @@ -191,6 +248,34 @@ export class ChatClient extends DifyClient { ); } + getSuggested(message_id, user) { + const data = { user }; + return this.sendRequest( + routes.getSuggested.method, + routes.getSuggested.url(message_id), + data + ); + } + + stopMessage(task_id, user) { + const data = { user }; + return this.sendRequest( + routes.stopChatMessage.method, + routes.stopChatMessage.url(task_id), + data + ); + } + + getConversations(user, first_id = null, limit = null, pinned = null) { + const params = { user, first_id: first_id, limit, pinned }; + return this.sendRequest( + routes.getConversations.method, + routes.getConversations.url(), + null, + params + ); + } + getConversationMessages( user, conversation_id = "", @@ -213,16 +298,6 @@ export class ChatClient extends DifyClient { ); } - getConversations(user, first_id = null, limit = null, pinned = null) { - const params = { user, first_id: first_id, limit, pinned }; - return this.sendRequest( - routes.getConversations.method, - routes.getConversations.url(), - null, - params - ); - } - renameConversation(conversation_id, name, user, auto_generate) { const data = { name, user, auto_generate }; return this.sendRequest( @@ -240,4 +315,46 @@ export class ChatClient extends DifyClient { data ); } + + + audioToText(data) { + return this.sendRequest( + routes.audioToText.method, + routes.audioToText.url(), + data, + null, + false, + { + "Content-Type": 'multipart/form-data' + } + ); + } + +} + +export class WorkflowClient extends DifyClient { + run(inputs,user,stream) { + const data = { + inputs, + response_mode: stream ? "streaming" : "blocking", + user + }; + + return this.sendRequest( + routes.runWorkflow.method, + routes.runWorkflow.url(), + data, + null, + stream + ); + } + + stop(task_id, user) { + const data = { user }; + return this.sendRequest( + routes.stopWorkflow.method, + routes.stopWorkflow.url(task_id), + data + ); + } } \ No newline at end of file diff --git a/sdks/php-client/dify-client.php b/sdks/php-client/dify-client.php index c5ddcc3a87ffbe..ccd61f091a9f3b 100644 --- a/sdks/php-client/dify-client.php +++ b/sdks/php-client/dify-client.php @@ -9,9 +9,9 @@ class DifyClient { protected $base_url; protected $client; - public function __construct($api_key) { + public function __construct($api_key, $base_url = null) { $this->api_key = $api_key; - $this->base_url = "https://api.dify.ai/v1/"; + $this->base_url = $base_url ?? "https://api.dify.ai/v1/"; $this->client = new Client([ 'base_uri' => $this->base_url, 'headers' => [ @@ -80,6 +80,25 @@ protected function prepareMultipart($data, $files) { return $multipart; } + + + public function text_to_audio($text, $user, $streaming = false) { + $data = [ + 'text' => $text, + 'user' => $user, + 'streaming' => $streaming + ]; + + return $this->send_request('POST', 'text-to-audio', $data); + } + + public function get_meta($user) { + $params = [ + 'user' => $user + ]; + + return $this->send_request('GET', 'meta',null, $params); + } } class CompletionClient extends DifyClient { @@ -109,6 +128,28 @@ public function create_chat_message($inputs, $query, $user, $response_mode = 'bl return $this->send_request('POST', 'chat-messages', $data, null, $response_mode === 'streaming'); } + + public function get_suggestions($message_id, $user) { + $params = [ + 'user' => $user + ] + return $this->send_request('GET', "messages/{$message_id}/suggested", null, $params); + } + + public function stop_message($task_id, $user) { + $data = ['user' => $user]; + return $this->send_request('POST', "chat-messages/{$task_id}/stop", $data); + } + + public function get_conversations($user, $first_id = null, $limit = null, $pinned = null) { + $params = [ + 'user' => $user, + 'first_id' => $first_id, + 'limit' => $limit, + 'pinned'=> $pinned, + ]; + return $this->send_request('GET', 'conversations', null, $params); + } public function get_conversation_messages($user, $conversation_id = null, $first_id = null, $limit = null) { $params = ['user' => $user]; @@ -125,22 +166,51 @@ public function get_conversation_messages($user, $conversation_id = null, $first return $this->send_request('GET', 'messages', null, $params); } + + public function rename_conversation($conversation_id, $name,$auto_generate, $user) { + $data = [ + 'name' => $name, + 'user' => $user, + 'auto_generate' => $auto_generate + ]; + return $this->send_request('PATCH', "conversations/{$conversation_id}", $data); + } - public function get_conversations($user, $first_id = null, $limit = null, $pinned = null) { - $params = [ + public function delete_conversation($conversation_id, $user) { + $data = [ 'user' => $user, - 'first_id' => $first_id, - 'limit' => $limit, - 'pinned'=> $pinned, ]; - return $this->send_request('GET', 'conversations', null, $params); + return $this->send_request('DELETE', "conversations/{$conversation_id}", $data); } - public function rename_conversation($conversation_id, $name, $user) { + public function audio_to_text($audio_file, $user) { $data = [ - 'name' => $name, 'user' => $user, ]; - return $this->send_request('PATCH', "conversations/{$conversation_id}", $data); + $options = [ + 'multipart' => $this->prepareMultipart($data, $files) + ]; + return $this->file_client->request('POST', 'audio-to-text', $options); + } + } + +class WorkflowClient extends DifyClient{ + public function run($inputs, $response_mode, $user) { + $data = [ + 'inputs' => $inputs, + 'response_mode' => $response_mode, + 'user' => $user, + ]; + return $this->send_request('POST', 'workflows/run', $data); + } + + public function stop($task_id, $user) { + $data = [ + 'user' => $user, + ]; + return $this->send_request('POST', "workflows/tasks/{$task_id}/stop",$data); + } + +} \ No newline at end of file diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py index 22d4e6189a5a42..a8da1d7cae7ab3 100644 --- a/sdks/python-client/dify_client/client.py +++ b/sdks/python-client/dify_client/client.py @@ -16,6 +16,7 @@ def _send_request(self, method, endpoint, json=None, params=None, stream=False): response = requests.request(method, url, json=json, params=params, headers=headers, stream=stream) return response + def _send_request_with_files(self, method, endpoint, data, files): headers = { @@ -26,24 +27,36 @@ def _send_request_with_files(self, method, endpoint, data, files): response = requests.request(method, url, data=data, headers=headers, files=files) return response - + def message_feedback(self, message_id, rating, user): data = { "rating": rating, "user": user } return self._send_request("POST", f"/messages/{message_id}/feedbacks", data) - + def get_application_parameters(self, user): params = {"user": user} return self._send_request("GET", "/parameters", params=params) - + def file_upload(self, user, files): data = { "user": user } return self._send_request_with_files("POST", "/files/upload", data=data, files=files) + def text_to_audio(self, text:str, user:str, streaming:bool=False): + data = { + "text": text, + "user": user, + "streaming": streaming + } + return self._send_request("POST", "/text-to-audio", data=data) + + def get_meta(self,user): + params = { "user": user} + return self._send_request("GET", f"/meta", params=params) + class CompletionClient(DifyClient): def create_completion_message(self, inputs, response_mode, user, files=None): @@ -71,7 +84,19 @@ def create_chat_message(self, inputs, query, user, response_mode="blocking", con return self._send_request("POST", "/chat-messages", data, stream=True if response_mode == "streaming" else False) + + def get_suggested(self, message_id, user:str): + params = {"user": user} + return self._send_request("GET", f"/messages/{message_id}/suggested", params=params) + + def stop_message(self, task_id, user): + data = {"user": user} + return self._send_request("POST", f"/chat-messages/{task_id}/stop", data) + def get_conversations(self, user, last_id=None, limit=None, pinned=None): + params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned} + return self._send_request("GET", "/conversations", params=params) + def get_conversation_messages(self, user, conversation_id=None, first_id=None, limit=None): params = {"user": user} @@ -83,11 +108,26 @@ def get_conversation_messages(self, user, conversation_id=None, first_id=None, l params["limit"] = limit return self._send_request("GET", "/messages", params=params) - - def get_conversations(self, user, last_id=None, limit=None, pinned=None): - params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned} - return self._send_request("GET", "/conversations", params=params) - - def rename_conversation(self, conversation_id, name, user): - data = {"name": name, "user": user} + + def rename_conversation(self, conversation_id, name,auto_generate:bool, user:str): + data = {"name": name, "auto_generate": auto_generate,"user": user} return self._send_request("POST", f"/conversations/{conversation_id}/name", data) + + def delete_conversation(self, conversation_id, user): + data = {"user": user} + return self._send_request("DELETE", f"/conversations/{conversation_id}", data) + + def audio_to_text(self, audio_file, user): + data = {"user": user} + files = {"audio_file": audio_file} + return self._send_request_with_files("POST", "/audio-to-text", data, files) + + +class WorkflowClient(DifyClient): + def run(self, inputs:dict, response_mode:str="streaming", user:str="abc-123"): + data = {"inputs": inputs, "response_mode": response_mode, "user": user} + return self._send_request("POST", "/workflows/run", data) + + def stop(self, task_id, user): + data = {"user": user} + return self._send_request("POST", f"/workflows/tasks/{task_id}/stop", data) diff --git a/web/.env.example b/web/.env.example index 653913033d8d05..3045cef2f9d258 100644 --- a/web/.env.example +++ b/web/.env.example @@ -13,3 +13,9 @@ NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api # SENTRY NEXT_PUBLIC_SENTRY_DSN= + +# Disable Next.js Telemetry (https://nextjs.org/telemetry) +NEXT_TELEMETRY_DISABLED=1 + +# Disable Upload Image as WebApp icon default is false +NEXT_PUBLIC_UPLOAD_IMAGE_AS_ICON=false diff --git a/web/Dockerfile b/web/Dockerfile index 56957f0927010f..48bdb2301ad206 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -39,6 +39,7 @@ ENV DEPLOY_ENV=PRODUCTION ENV CONSOLE_API_URL=http://127.0.0.1:5001 ENV APP_API_URL=http://127.0.0.1:5001 ENV PORT=3000 +ENV NEXT_TELEMETRY_DISABLED=1 # set timezone ENV TZ=UTC diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx index 09569df8bf10c5..8723420d849e19 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx @@ -15,13 +15,14 @@ import { } from '@remixicon/react' import { useTranslation } from 'react-i18next' import { useShallow } from 'zustand/react/shallow' +import { useContextSelector } from 'use-context-selector' import s from './style.module.css' import cn from '@/utils/classnames' import { useStore } from '@/app/components/app/store' import AppSideBar from '@/app/components/app-sidebar' import type { NavIcon } from '@/app/components/app-sidebar/navLink' -import { fetchAppDetail } from '@/service/apps' -import { useAppContext } from '@/context/app-context' +import { fetchAppDetail, fetchAppSSO } from '@/service/apps' +import AppContext, { useAppContext } from '@/context/app-context' import Loading from '@/app/components/base/loading' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' @@ -52,6 +53,7 @@ const AppDetailLayout: FC = (props) => { icon: NavIcon selectedIcon: NavIcon }>>([]) + const systemFeatures = useContextSelector(AppContext, state => state.systemFeatures) const getNavigations = useCallback((appId: string, isCurrentWorkspaceEditor: boolean, mode: string) => { const navs = [ @@ -114,14 +116,19 @@ const AppDetailLayout: FC = (props) => { router.replace(`/app/${appId}/configuration`) } else { - setAppDetail(res) + setAppDetail({ ...res, enable_sso: false }) setNavigation(getNavigations(appId, isCurrentWorkspaceEditor, res.mode)) + if (systemFeatures.enable_web_sso_switch_component) { + fetchAppSSO({ appId }).then((ssoRes) => { + setAppDetail({ ...res, enable_sso: ssoRes.enabled }) + }) + } } }).catch((e: any) => { if (e.status === 404) router.replace('/apps') }) - }, [appId, isCurrentWorkspaceEditor]) + }, [appId, isCurrentWorkspaceEditor, systemFeatures]) useUnmount(() => { setAppDetail() diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx index 5fa9a2e40641bd..8f3ee510b86036 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx @@ -2,22 +2,25 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' +import { useContext, useContextSelector } from 'use-context-selector' import AppCard from '@/app/components/app/overview/appCard' import Loading from '@/app/components/base/loading' import { ToastContext } from '@/app/components/base/toast' import { fetchAppDetail, + fetchAppSSO, + updateAppSSO, updateAppSiteAccessToken, updateAppSiteConfig, updateAppSiteStatus, } from '@/service/apps' -import type { App } from '@/types/app' +import type { App, AppSSO } from '@/types/app' import type { UpdateAppSiteCodeResponse } from '@/models/app' import { asyncRunSafe } from '@/utils' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import type { IAppCardProps } from '@/app/components/app/overview/appCard' import { useStore as useAppStore } from '@/app/components/app/store' +import AppContext from '@/context/app-context' export type ICardViewProps = { appId: string @@ -28,11 +31,20 @@ const CardView: FC = ({ appId }) => { const { notify } = useContext(ToastContext) const appDetail = useAppStore(state => state.appDetail) const setAppDetail = useAppStore(state => state.setAppDetail) + const systemFeatures = useContextSelector(AppContext, state => state.systemFeatures) const updateAppDetail = async () => { - fetchAppDetail({ url: '/apps', id: appId }).then((res) => { - setAppDetail(res) - }) + try { + const res = await fetchAppDetail({ url: '/apps', id: appId }) + if (systemFeatures.enable_web_sso_switch_component) { + const ssoRes = await fetchAppSSO({ appId }) + setAppDetail({ ...res, enable_sso: ssoRes.enabled }) + } + else { + setAppDetail({ ...res }) + } + } + catch (error) { console.error(error) } } const handleCallbackResult = (err: Error | null, message?: string) => { @@ -81,6 +93,16 @@ const CardView: FC = ({ appId }) => { if (!err) localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') + if (systemFeatures.enable_web_sso_switch_component) { + const [sso_err] = await asyncRunSafe( + updateAppSSO({ id: appId, enabled: Boolean(params.enable_sso) }) as Promise, + ) + if (sso_err) { + handleCallbackResult(sso_err) + return + } + } + handleCallbackResult(err) } diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx index 74e9140fe50dff..b01bc1b8563f00 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx @@ -4,7 +4,7 @@ import dayjs from 'dayjs' import quarterOfYear from 'dayjs/plugin/quarterOfYear' import { useTranslation } from 'react-i18next' import type { PeriodParams } from '@/app/components/app/overview/appChart' -import { AvgResponseTime, AvgSessionInteractions, AvgUserInteractions, ConversationsChart, CostChart, EndUsersChart, TokenPerSecond, UserSatisfactionRate, WorkflowCostChart, WorkflowDailyTerminalsChart, WorkflowMessagesChart } from '@/app/components/app/overview/appChart' +import { AvgResponseTime, AvgSessionInteractions, AvgUserInteractions, ConversationsChart, CostChart, EndUsersChart, MessagesChart, TokenPerSecond, UserSatisfactionRate, WorkflowCostChart, WorkflowDailyTerminalsChart, WorkflowMessagesChart } from '@/app/components/app/overview/appChart' import type { Item } from '@/app/components/base/select' import { SimpleSelect } from '@/app/components/base/select' import { TIME_PERIOD_LIST } from '@/app/components/app/log/filter' @@ -25,7 +25,7 @@ export default function ChartView({ appId }: IChartViewProps) { const appDetail = useAppStore(state => state.appDetail) const isChatApp = appDetail?.mode !== 'completion' && appDetail?.mode !== 'workflow' const isWorkflow = appDetail?.mode === 'workflow' - const [period, setPeriod] = useState({ name: t('appLog.filter.period.last7days'), query: { start: today.subtract(7, 'day').format(queryDateFormat), end: today.format(queryDateFormat) } }) + const [period, setPeriod] = useState({ name: t('appLog.filter.period.last7days'), query: { start: today.subtract(7, 'day').startOf('day').format(queryDateFormat), end: today.endOf('day').format(queryDateFormat) } }) const onSelect = (item: Item) => { if (item.value === 'all') { @@ -37,7 +37,7 @@ export default function ChartView({ appId }: IChartViewProps) { setPeriod({ name: item.name, query: { start: startOfToday, end: endOfToday } }) } else { - setPeriod({ name: item.name, query: { start: today.subtract(item.value as number, 'day').format(queryDateFormat), end: today.format(queryDateFormat) } }) + setPeriod({ name: item.name, query: { start: today.subtract(item.value as number, 'day').startOf('day').format(queryDateFormat), end: today.endOf('day').format(queryDateFormat) } }) } } @@ -79,6 +79,11 @@ export default function ChartView({ appId }: IChartViewProps) { )} + {!isWorkflow && isChatApp && ( +
+ +
+ )} {isWorkflow && (
diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx index 7aa1fca96d8a3f..8e3d8f9ec6583f 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx @@ -10,7 +10,7 @@ import { TracingProvider } from './type' import ProviderConfigModal from './provider-config-modal' import Indicator from '@/app/components/header/indicator' import Switch from '@/app/components/base/switch' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' const I18N_PREFIX = 'app.tracing' @@ -85,6 +85,7 @@ const ConfigPopup: FC = ({ = ({ = ({ <> {providerAllNotConfigured ? ( - {switchContent} - - + ) : switchContent} diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx index 88c37d0b125f72..bc724c1449af64 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx @@ -117,7 +117,6 @@ const Panel: FC = () => { // eslint-disable-next-line react-hooks/exhaustive-deps }, []) - const [isFold, setFold] = useState(false) const [controlShowPopup, setControlShowPopup] = useState(0) const showPopup = useCallback(() => { setControlShowPopup(Date.now()) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx index 2411d2baa489b9..e7ecd2f4ce3d64 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx @@ -14,7 +14,7 @@ import { import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' import Button from '@/app/components/base/button' import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general' -import ConfirmUi from '@/app/components/base/confirm' +import Confirm from '@/app/components/base/confirm' import { addTracingConfig, removeTracingConfig, updateTracingConfig } from '@/service/apps' import Toast from '@/app/components/base/toast' @@ -276,9 +276,8 @@ const ProviderConfigModal: FC = ({ ) : ( - void hasConfigured: boolean onConfig: () => void @@ -29,6 +31,7 @@ const ProviderPanel: FC = ({ type, readOnly, isChosen, + config, onChoose, hasConfigured, onConfig, @@ -41,6 +44,14 @@ const ProviderPanel: FC = ({ onConfig() }, [onConfig]) + const viewBtnClick = useCallback((e: React.MouseEvent) => { + e.preventDefault() + e.stopPropagation() + + const url = `${config?.host}/project/${config?.project_key}` + window.open(url, '_blank', 'noopener,noreferrer') + }, []) + const handleChosen = useCallback((e: React.MouseEvent) => { e.stopPropagation() if (isChosen || !hasConfigured || readOnly) @@ -58,12 +69,20 @@ const ProviderPanel: FC = ({ {isChosen &&
{t(`${I18N_PREFIX}.inUse`)}
}
{!readOnly && ( -
- -
{t(`${I18N_PREFIX}.config`)}
+
+ {hasConfigured && ( +
+ +
{t(`${I18N_PREFIX}.view`)}
+
+ )} +
+ +
{t(`${I18N_PREFIX}.config`)}
+
)} diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/toggle-fold-btn.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/toggle-fold-btn.tsx index 9119deede879b2..934eb681b9c85e 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/toggle-fold-btn.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/toggle-fold-btn.tsx @@ -3,7 +3,7 @@ import { ChevronDoubleDownIcon } from '@heroicons/react/20/solid' import type { FC } from 'react' import { useTranslation } from 'react-i18next' import React, { useCallback } from 'react' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' const I18N_PREFIX = 'app.tracing' @@ -25,9 +25,8 @@ const ToggleFoldBtn: FC = ({ return ( // text-[0px] to hide spacing between tooltip elements
- {isFold && (
@@ -39,7 +38,7 @@ const ToggleFoldBtn: FC = ({
)} -
+
) } diff --git a/web/app/(commonLayout)/apps/AppCard.tsx b/web/app/(commonLayout)/apps/AppCard.tsx index 34279f816ebdc0..6dcc39046c4003 100644 --- a/web/app/(commonLayout)/apps/AppCard.tsx +++ b/web/app/(commonLayout)/apps/AppCard.tsx @@ -75,6 +75,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({ name, + icon_type, icon, icon_background, description, @@ -83,6 +84,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { await updateAppInfo({ appID: app.id, name, + icon_type, icon, icon_background, description, @@ -101,11 +103,12 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { } }, [app.id, mutateApps, notify, onRefresh, t]) - const onCopy: DuplicateAppModalProps['onConfirm'] = async ({ name, icon, icon_background }) => { + const onCopy: DuplicateAppModalProps['onConfirm'] = async ({ name, icon_type, icon, icon_background }) => { try { const newApp = await copyApp({ appID: app.id, name, + icon_type, icon, icon_background, mode: app.mode, @@ -252,14 +255,16 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { e.preventDefault() getRedirection(isCurrentWorkspaceEditor, app, push) }} - className='group flex col-span-1 bg-white border-2 border-solid border-transparent rounded-xl shadow-sm min-h-[160px] flex flex-col transition-all duration-200 ease-in-out cursor-pointer hover:shadow-lg' + className='relative group col-span-1 bg-white border-2 border-solid border-transparent rounded-xl shadow-sm flex flex-col transition-all duration-200 ease-in-out cursor-pointer hover:shadow-lg' >
{app.mode === 'advanced-chat' && ( @@ -292,17 +297,16 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
-
- {app.description} +
+
+ {app.description} +
{isCurrentWorkspaceEditor && ( @@ -360,9 +364,11 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { {showEditModal && ( { {showDuplicateModal && ( setShowDuplicateModal(false)} @@ -392,7 +400,6 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { title={t('app.deleteAppConfirmTitle')} content={t('app.deleteAppConfirmContent')} isShow={showConfirmDelete} - onClose={() => setShowConfirmDelete(false)} onConfirm={onConfirmDelete} onCancel={() => setShowConfirmDelete(false)} /> diff --git a/web/app/(commonLayout)/apps/Apps.tsx b/web/app/(commonLayout)/apps/Apps.tsx index c16512bd50db1f..132096c6b455ee 100644 --- a/web/app/(commonLayout)/apps/Apps.tsx +++ b/web/app/(commonLayout)/apps/Apps.tsx @@ -139,7 +139,7 @@ const Apps = () => {
- +
{showRenameModal && ( setShowConfirmDelete(false)} onConfirm={onConfirmDelete} onCancel={() => setShowConfirmDelete(false)} /> diff --git a/web/app/(commonLayout)/datasets/template/template.en.mdx b/web/app/(commonLayout)/datasets/template/template.en.mdx index 36395d391de1b3..33451b8a0bdb27 100644 --- a/web/app/(commonLayout)/datasets/template/template.en.mdx +++ b/web/app/(commonLayout)/datasets/template/template.en.mdx @@ -236,6 +236,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from Knowledge name + + Permission + - only_me Only me + - all_team_members All team members + - partial_members Partial members + @@ -243,14 +249,15 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from title="Request" tag="POST" label="/datasets" - targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"name": "name"}'`} + targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"name": "name", "permission": "only_me"}'`} > ```bash {{ title: 'cURL' }} curl --location --request POST '${apiBaseUrl}/v1/datasets' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ - "name": "name" + "name": "name", + "permission": "only_me" }' ``` @@ -922,6 +929,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from Knowledge ID + + Document ID + Document Segment ID @@ -965,6 +975,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from Knowledge ID + + Document ID + Document Segment ID diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx index a624c0594feffb..dc48f92f185495 100644 --- a/web/app/(commonLayout)/datasets/template/template.zh.mdx +++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx @@ -236,6 +236,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from 知识库名称 + + 权限 + - only_me 仅自己 + - all_team_members 所有团队成员 + - partial_members 部分团队成员 + @@ -243,14 +249,15 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from title="Request" tag="POST" label="/datasets" - targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"name": "name"}'`} + targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"name": "name", "permission": "only_me"}'`} > ```bash {{ title: 'cURL' }} curl --location --request POST '${props.apiBaseUrl}/datasets' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ - "name": "name" + "name": "name", + "permission": "only_me" }' ``` @@ -922,6 +929,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from 知识库 ID + + 文档 ID + 文档分段ID @@ -965,6 +975,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from 知识库 ID + + 文档 ID + 文档分段ID diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index 3f52e9d6f5b7b3..c5bc3bb210b3a8 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -59,6 +59,7 @@ const AppInfo = ({ expand }: IAppInfoProps) => { const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({ name, + icon_type, icon, icon_background, description, @@ -69,6 +70,7 @@ const AppInfo = ({ expand }: IAppInfoProps) => { const app = await updateAppInfo({ appID: appDetail.id, name, + icon_type, icon, icon_background, description, @@ -86,13 +88,14 @@ const AppInfo = ({ expand }: IAppInfoProps) => { } }, [appDetail, mutateApps, notify, setAppDetail, t]) - const onCopy: DuplicateAppModalProps['onConfirm'] = async ({ name, icon, icon_background }) => { + const onCopy: DuplicateAppModalProps['onConfirm'] = async ({ name, icon_type, icon, icon_background }) => { if (!appDetail) return try { const newApp = await copyApp({ appID: appDetail.id, name, + icon_type, icon, icon_background, mode: appDetail.mode, @@ -194,7 +197,13 @@ const AppInfo = ({ expand }: IAppInfoProps) => { >
- + { {/* header */}
- + {appDetail.mode === 'advanced-chat' && ( @@ -402,9 +417,11 @@ const AppInfo = ({ expand }: IAppInfoProps) => { {showEditModal && ( { {showDuplicateModal && ( setShowDuplicateModal(false)} @@ -426,7 +445,6 @@ const AppInfo = ({ expand }: IAppInfoProps) => { title={t('app.deleteAppConfirmTitle')} content={t('app.deleteAppConfirmContent')} isShow={showConfirmDelete} - onClose={() => setShowConfirmDelete(false)} onConfirm={onConfirmDelete} onCancel={() => setShowConfirmDelete(false)} /> diff --git a/web/app/components/app-sidebar/basic.tsx b/web/app/components/app-sidebar/basic.tsx index 09f978b04b4f01..c939cb7bb3a5ac 100644 --- a/web/app/components/app-sidebar/basic.tsx +++ b/web/app/components/app-sidebar/basic.tsx @@ -1,15 +1,11 @@ import React from 'react' -import { - InformationCircleIcon, -} from '@heroicons/react/24/outline' -import Tooltip from '../base/tooltip' import AppIcon from '../base/app-icon' -import { randomString } from '@/utils' +import Tooltip from '@/app/components/base/tooltip' export type IAppBasicProps = { iconType?: 'app' | 'api' | 'dataset' | 'webapp' | 'notion' icon?: string - icon_background?: string + icon_background?: string | null name: string type: string | React.ReactNode hoverTip?: string @@ -74,9 +70,17 @@ export default function AppBasic({ icon, icon_background, name, type, hoverTip,
{name} {hoverTip - && - - } + && + {hoverTip} +
+ } + popupClassName='ml-1' + triggerClassName='w-4 h-4 ml-1' + position='top' + /> + }
{type}
} diff --git a/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx b/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx index 63788447de2a89..8e88a4b5f335ad 100644 --- a/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx +++ b/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx @@ -79,7 +79,7 @@ const EditItem: FC = ({ {!readonly && (
{ + onClick={() => { setIsEdit(true) }} > diff --git a/web/app/components/app/annotation/edit-annotation-modal/index.tsx b/web/app/components/app/annotation/edit-annotation-modal/index.tsx index 548ecb2be3a971..df80196cf191bb 100644 --- a/web/app/components/app/annotation/edit-annotation-modal/index.tsx +++ b/web/app/components/app/annotation/edit-annotation-modal/index.tsx @@ -5,7 +5,7 @@ import { useTranslation } from 'react-i18next' import EditItem, { EditItemType } from './edit-item' import Drawer from '@/app/components/base/drawer-plus' import { MessageCheckRemove } from '@/app/components/base/icons/src/vender/line/communication' -import DeleteConfirmModal from '@/app/components/base/modal/delete-confirm-modal' +import Confirm from '@/app/components/base/confirm' import { addAnnotation, editAnnotation } from '@/service/annotation' import Toast from '@/app/components/base/toast' import { useProviderContext } from '@/context/provider-context' @@ -85,19 +85,31 @@ const EditAnnotationModal: FC = ({ maxWidthClassName='!max-w-[480px]' title={t('appAnnotation.editModal.title') as string} body={( -
- handleSave(EditItemType.Query, editedContent)} - /> - handleSave(EditItemType.Answer, editedContent)} - /> +
+
+ handleSave(EditItemType.Query, editedContent)} + /> + handleSave(EditItemType.Answer, editedContent)} + /> + setShowModal(false)} + onConfirm={() => { + onRemove() + setShowModal(false) + onHide() + }} + title={t('appDebug.feature.annotation.removeConfirm')} + /> +
)} foot={ @@ -127,16 +139,6 @@ const EditAnnotationModal: FC = ({
} /> - setShowModal(false)} - onRemove={() => { - onRemove() - setShowModal(false) - onHide() - }} - text={t('appDebug.feature.annotation.removeConfirm') as string} - />
) diff --git a/web/app/components/app/annotation/index.tsx b/web/app/components/app/annotation/index.tsx index 1e65d7a94f4f05..0f54f5bfc32102 100644 --- a/web/app/components/app/annotation/index.tsx +++ b/web/app/components/app/annotation/index.tsx @@ -280,7 +280,7 @@ const Annotation: FC = ({ onSave={async (embeddingModel, score) => { if ( embeddingModel.embedding_model_name !== annotationConfig?.embedding_model?.embedding_model_name - && embeddingModel.embedding_provider_name !== annotationConfig?.embedding_model?.embedding_provider_name + || embeddingModel.embedding_provider_name !== annotationConfig?.embedding_model?.embedding_provider_name ) { const { job_id: jobId }: any = await updateAnnotationStatus(appDetail.id, AnnotationEnableStatus.enable, embeddingModel, score) await ensureJobCompleted(jobId, AnnotationEnableStatus.enable) diff --git a/web/app/components/app/annotation/remove-annotation-confirm-modal/index.tsx b/web/app/components/app/annotation/remove-annotation-confirm-modal/index.tsx index a9c9d408366cbe..a6ade49a7973df 100644 --- a/web/app/components/app/annotation/remove-annotation-confirm-modal/index.tsx +++ b/web/app/components/app/annotation/remove-annotation-confirm-modal/index.tsx @@ -2,7 +2,7 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import DeleteConfirmModal from '@/app/components/base/modal/delete-confirm-modal' +import Confirm from '@/app/components/base/confirm' type Props = { isShow: boolean @@ -18,11 +18,11 @@ const RemoveAnnotationConfirmModal: FC = ({ const { t } = useTranslation() return ( - ) } diff --git a/web/app/components/app/annotation/view-annotation-modal/index.tsx b/web/app/components/app/annotation/view-annotation-modal/index.tsx index 3abc477d350eca..7b96ad134e6655 100644 --- a/web/app/components/app/annotation/view-annotation-modal/index.tsx +++ b/web/app/components/app/annotation/view-annotation-modal/index.tsx @@ -11,7 +11,7 @@ import HitHistoryNoData from './hit-history-no-data' import cn from '@/utils/classnames' import Drawer from '@/app/components/base/drawer-plus' import { MessageCheckRemove } from '@/app/components/base/icons/src/vender/line/communication' -import DeleteConfirmModal from '@/app/components/base/modal/delete-confirm-modal' +import Confirm from '@/app/components/base/confirm' import TabSlider from '@/app/components/base/tab-slider-plain' import { fetchHitHistoryList } from '@/service/annotation' import { APP_PAGE_LIMIT } from '@/config' @@ -201,8 +201,20 @@ const ViewAnnotationModal: FC = ({ /> } body={( -
- {activeTab === TabType.annotation ? annotationTab : hitHistoryTab} +
+
+ {activeTab === TabType.annotation ? annotationTab : hitHistoryTab} +
+ setShowModal(false)} + onConfirm={async () => { + await onRemove() + setShowModal(false) + onHide() + }} + title={t('appDebug.feature.annotation.removeConfirm')} + />
)} foot={id @@ -220,16 +232,6 @@ const ViewAnnotationModal: FC = ({ ) : undefined} /> - setShowModal(false)} - onRemove={async () => { - await onRemove() - setShowModal(false) - onHide() - }} - text={t('appDebug.feature.annotation.removeConfirm') as string} - />
) diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx index e971274a71b3bc..0558e299560396 100644 --- a/web/app/components/app/app-publisher/index.tsx +++ b/web/app/components/app/app-publisher/index.tsx @@ -24,6 +24,7 @@ import { LeftIndent02 } from '@/app/components/base/icons/src/vender/line/editor import { FileText } from '@/app/components/base/icons/src/vender/line/files' import WorkflowToolConfigureButton from '@/app/components/tools/workflow-tool/configure-button' import type { InputVar } from '@/app/components/workflow/types' +import { appDefaultIconBackground } from '@/config' export type AppPublisherProps = { disabled?: boolean @@ -212,8 +213,8 @@ const AppPublisher = ({ detailNeedUpdate={!!toolPublished && published} workflowAppId={appDetail?.id} icon={{ - content: appDetail?.icon, - background: appDetail?.icon_background, + content: (appDetail.icon_type === 'image' ? '🤖' : appDetail?.icon) || '🤖', + background: (appDetail.icon_type === 'image' ? appDefaultIconBackground : appDetail?.icon_background) || appDefaultIconBackground, }} name={appDetail?.name} description={appDetail?.description} diff --git a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx index fb3ceadc0d8bcf..a07fe6982152a4 100644 --- a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx @@ -9,7 +9,6 @@ import produce from 'immer' import { RiDeleteBinLine, RiErrorWarningFill, - RiQuestionLine, } from '@remixicon/react' import s from './style.module.css' import MessageTypeSelector from './message-type-selector' @@ -174,12 +173,12 @@ const AdvancedPromptInput: FC = ({
{t('appDebug.pageTitle.line1')}
- {t('appDebug.promptTip')} -
} - selector='config-prompt-tooltip'> - - + popupContent={ +
+ {t('appDebug.promptTip')} +
+ } + />
)}
{canDelete && ( diff --git a/web/app/components/app/configuration/config-prompt/confirm-add-var/index.tsx b/web/app/components/app/configuration/config-prompt/confirm-add-var/index.tsx index bfe51379655c0d..f08f2ffc6968e7 100644 --- a/web/app/components/app/configuration/config-prompt/confirm-add-var/index.tsx +++ b/web/app/components/app/configuration/config-prompt/confirm-add-var/index.tsx @@ -24,7 +24,7 @@ const ConfirmAddVar: FC = ({ varNameArr, onConfrim, onCancel, - onHide, + // onHide, }) => { const { t } = useTranslation() const mainContentRef = useRef(null) diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index 2811c4240238ec..69e01a8e22478d 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -3,9 +3,6 @@ import type { FC } from 'react' import React, { useState } from 'react' import { useTranslation } from 'react-i18next' import { useBoolean } from 'ahooks' -import { - RiQuestionLine, -} from '@remixicon/react' import produce from 'immer' import { useContext } from 'use-context-selector' import ConfirmAddVar from './confirm-add-var' @@ -69,7 +66,6 @@ const Prompt: FC = ({ hasSetBlockStatus, showSelectDataSet, externalDataToolsConfig, - isAgent, } = useContext(ConfigContext) const { notify } = useToastContext() const { setShowExternalDataToolModal } = useModalContext() @@ -157,16 +153,16 @@ const Prompt: FC = ({
{mode !== AppType.completion ? t('appDebug.chatSubTitle') : t('appDebug.completionSubTitle')}
{!readonly && ( - {t('appDebug.promptTip')} -
} - selector='config-prompt-tooltip'> - - + popupContent={ +
+ {t('appDebug.promptTip')} +
+ } + /> )}
- {!isAgent && !readonly && !isMobile && ( + {!readonly && !isMobile && ( )}
diff --git a/web/app/components/app/configuration/config-var/config-modal/index.tsx b/web/app/components/app/configuration/config-var/config-modal/index.tsx index 20fcf49de18df5..3296c77fb23ae5 100644 --- a/web/app/components/app/configuration/config-var/config-modal/index.tsx +++ b/web/app/components/app/configuration/config-var/config-modal/index.tsx @@ -42,18 +42,19 @@ const ConfigModal: FC = ({ const { type, label, variable, options, max_length } = tempPayload const isStringInput = type === InputVarType.textInput || type === InputVarType.paragraph + const checkVariableName = useCallback((value: string) => { + const { isValid, errorMessageKey } = checkKeys([value], false) + if (!isValid) { + Toast.notify({ + type: 'error', + message: t(`appDebug.varKeyError.${errorMessageKey}`, { key: t('appDebug.variableConig.varName') }), + }) + return false + } + return true + }, [t]) const handlePayloadChange = useCallback((key: string) => { return (value: any) => { - if (key === 'variable') { - const { isValid, errorKey, errorMessageKey } = checkKeys([value], true) - if (!isValid) { - Toast.notify({ - type: 'error', - message: t(`appDebug.varKeyError.${errorMessageKey}`, { key: errorKey }), - }) - return - } - } setTempPayload((prev) => { const newPayload = { ...prev, @@ -63,19 +64,20 @@ const ConfigModal: FC = ({ return newPayload }) } - }, [t]) + }, []) const handleVarKeyBlur = useCallback((e: any) => { - if (tempPayload.label) + const varName = e.target.value + if (!checkVariableName(varName) || tempPayload.label) return setTempPayload((prev) => { return { ...prev, - label: e.target.value, + label: varName, } }) - }, [tempPayload]) + }, [checkVariableName, tempPayload.label]) const handleConfirm = () => { const moreInfo = tempPayload.variable === payload?.variable @@ -84,10 +86,11 @@ const ConfigModal: FC = ({ type: ChangeType.changeVarName, payload: { beforeKey: payload?.variable || '', afterKey: tempPayload.variable }, } - if (!tempPayload.variable) { - Toast.notify({ type: 'error', message: t('appDebug.variableConig.errorMsg.varNameRequired') }) + + const isVariableNameValid = checkVariableName(tempPayload.variable) + if (!isVariableNameValid) return - } + // TODO: check if key already exists. should the consider the edit case // if (varKeys.map(key => key?.trim()).includes(tempPayload.variable.trim())) { // Toast.notify({ diff --git a/web/app/components/app/configuration/config-var/index.tsx b/web/app/components/app/configuration/config-var/index.tsx index 9ef624ed97a025..802528e0af7a4b 100644 --- a/web/app/components/app/configuration/config-var/index.tsx +++ b/web/app/components/app/configuration/config-var/index.tsx @@ -8,7 +8,6 @@ import { useContext } from 'use-context-selector' import produce from 'immer' import { RiDeleteBinLine, - RiQuestionLine, } from '@remixicon/react' import Panel from '../base/feature-panel' import EditModal from './config-modal' @@ -24,7 +23,7 @@ import { checkKeys, getNewVar } from '@/utils/var' import Switch from '@/app/components/base/switch' import Toast from '@/app/components/base/toast' import { Settings01 } from '@/app/components/base/icons/src/vender/line/general' -import ConfirmModal from '@/app/components/base/confirm/common' +import Confirm from '@/app/components/base/confirm' import ConfigContext from '@/context/debug-configuration' import { AppType } from '@/types/app' import type { ExternalDataTool } from '@/models/common' @@ -282,11 +281,13 @@ const ConfigVar: FC = ({ promptVariables, readonly, onPromptVar
{t('appDebug.variableTitle')}
{!readonly && ( - - {t('appDebug.variableTip')} -
} selector='config-var-tooltip'> - - + + {t('appDebug.variableTip')} + + } + /> )} } @@ -389,11 +390,10 @@ const ConfigVar: FC = ({ promptVariables, readonly, onPromptVar )} {isShowDeleteContextVarModal && ( - { didRemoveVar(removeIndex as number) hideDeleteContextVarModal() diff --git a/web/app/components/app/configuration/config-vision/index.tsx b/web/app/components/app/configuration/config-vision/index.tsx index 9b12e059b57cfb..515709bff1bf91 100644 --- a/web/app/components/app/configuration/config-vision/index.tsx +++ b/web/app/components/app/configuration/config-vision/index.tsx @@ -2,9 +2,6 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' import { useContext } from 'use-context-selector' import Panel from '../base/feature-panel' import ParamConfig from './param-config' @@ -33,11 +30,13 @@ const ConfigVision: FC = () => { title={
{t('appDebug.vision.name')}
- - {t('appDebug.vision.description')} -
} selector='config-vision-tooltip'> - -
+ + {t('appDebug.vision.description')} + + } + /> } headerRight={ diff --git a/web/app/components/app/configuration/config-vision/param-config-content.tsx b/web/app/components/app/configuration/config-vision/param-config-content.tsx index 89fad411e70f47..cb01f572541913 100644 --- a/web/app/components/app/configuration/config-vision/param-config-content.tsx +++ b/web/app/components/app/configuration/config-vision/param-config-content.tsx @@ -3,9 +3,6 @@ import type { FC } from 'react' import React from 'react' import { useContext } from 'use-context-selector' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' import RadioGroup from './radio-group' import ConfigContext from '@/context/debug-configuration' import { Resolution, TransferMethod } from '@/types/app' @@ -37,13 +34,15 @@ const ParamConfigContent: FC = () => {
{t('appDebug.vision.visionSettings.resolution')}
- - {t('appDebug.vision.visionSettings.resolutionTooltip').split('\n').map(item => ( -
{item}
- ))} -
} selector='config-resolution-tooltip'> - - + + {t('appDebug.vision.visionSettings.resolutionTooltip').split('\n').map(item => ( +
{item}
+ ))} +
+ } + /> {
{t('appDebug.voice.voiceSettings.language')}
- - {t('appDebug.voice.voiceSettings.resolutionTooltip').split('\n').map(item => ( -
{item}
- ))} -
} selector='config-resolution-tooltip'> - -
+ + {t('appDebug.voice.voiceSettings.resolutionTooltip').split('\n').map(item => ( +
{item}
+ ))} + + } + /> = ({ {icon}
{name}
{description} } - selector={`agent-setting-tooltip-${name}`} > -
diff --git a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx index 16f2257c38d175..1280fd8928dd35 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx @@ -7,13 +7,11 @@ import produce from 'immer' import { RiDeleteBinLine, RiHammerFill, - RiQuestionLine, } from '@remixicon/react' import { useFormattingChangedDispatcher } from '../../../debug/hooks' import SettingBuiltInTool from './setting-built-in-tool' import cn from '@/utils/classnames' import Panel from '@/app/components/app/configuration/base/feature-panel' -import Tooltip from '@/app/components/base/tooltip' import { InfoCircle } from '@/app/components/base/icons/src/vender/line/general' import OperationBtn from '@/app/components/app/configuration/base/operation-btn' import AppIcon from '@/app/components/base/app-icon' @@ -23,7 +21,7 @@ import type { AgentTool } from '@/types/app' import { type Collection, CollectionType } from '@/app/components/tools/types' import { MAX_TOOLS_NUM } from '@/config' import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { DefaultToolIcon } from '@/app/components/base/icons/src/public/other' import AddToolModal from '@/app/components/tools/add-tool-modal' @@ -68,11 +66,13 @@ const AgentTools: FC = () => { title={
{t('appDebug.agent.tools.name')}
- - {t('appDebug.agent.tools.description')} -
} selector='config-tools-tooltip'> - - + + {t('appDebug.agent.tools.description')} +
+ } + /> } headerRight={ @@ -119,19 +119,20 @@ const AgentTools: FC = () => { className={cn((item.isDeleted || item.notAuthor) ? 'line-through opacity-50' : '', 'grow w-0 ml-2 leading-[18px] text-[13px] font-medium text-gray-800 truncate')} > {item.provider_type === CollectionType.builtIn ? item.provider_name : item.tool_label} - {item.tool_name} - +
{(item.isDeleted || item.notAuthor) ? (
-
{ if (item.notAuthor) @@ -139,7 +140,7 @@ const AgentTools: FC = () => { }}>
-
+
{ const newModelConfig = produce(modelConfig, (draft) => { @@ -155,16 +156,17 @@ const AgentTools: FC = () => { ) : (
- -
{ +
{ setCurrentTool(item) setIsShowSettingTool(true) }}>
- +
{ const newModelConfig = produce(modelConfig, (draft) => { diff --git a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx index 1939dd3ad7a68d..55d5661036c69b 100644 --- a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx +++ b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx @@ -282,7 +282,6 @@ const GetAutomaticRes: FC = ({ title={t('appDebug.generate.overwriteTitle')} content={t('appDebug.generate.overwriteMessage')} isShow={showConfirmOverwrite} - onClose={() => setShowConfirmOverwrite(false)} onConfirm={() => { setShowConfirmOverwrite(false) onFinished(res!) diff --git a/web/app/components/app/configuration/dataset-config/card-item/index.tsx b/web/app/components/app/configuration/dataset-config/card-item/index.tsx index 7b369d9d79847f..a528b2288cb028 100644 --- a/web/app/components/app/configuration/dataset-config/card-item/index.tsx +++ b/web/app/components/app/configuration/dataset-config/card-item/index.tsx @@ -39,10 +39,9 @@ const CardItem: FC = ({
{config.name}
{!config.embedding_available && ( - {t('dataset.unavailable')} + {t('dataset.unavailable')} )}
diff --git a/web/app/components/app/configuration/dataset-config/context-var/index.tsx b/web/app/components/app/configuration/dataset-config/context-var/index.tsx index be0ae472423863..0de182b0a97fcb 100644 --- a/web/app/components/app/configuration/dataset-config/context-var/index.tsx +++ b/web/app/components/app/configuration/dataset-config/context-var/index.tsx @@ -2,9 +2,6 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' import type { Props } from './var-picker' import VarPicker from './var-picker' import cn from '@/utils/classnames' @@ -24,13 +21,12 @@ const ContextVar: FC = (props) => {
{t('appDebug.feature.dataSet.queryVariable.title')}
- {t('appDebug.feature.dataSet.queryVariable.tip')} -
} - selector='context-var-tooltip' - > - - + popupContent={ +
+ {t('appDebug.feature.dataSet.queryVariable.tip')} +
+ } + />
diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index 3656bf6ea76b54..7f55649dab89cd 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -5,7 +5,6 @@ import type { FC } from 'react' import { useTranslation } from 'react-i18next' import { RiAlertFill, - RiQuestionLine, } from '@remixicon/react' import WeightedScore from './weighted-score' import TopKItem from '@/app/components/base/param-item/top-k-item' @@ -23,7 +22,7 @@ import ModelSelector from '@/app/components/header/account-setting/model-provide import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import type { ModelConfig } from '@/app/components/workflow/types' import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import type { DataSet, @@ -173,7 +172,7 @@ const ConfigContent: FC = ({ title={(
{t('appDebug.datasetConfig.retrieveOneWay.title')} - {t('dataset.nTo1RetrievalLegacy')} @@ -181,7 +180,7 @@ const ConfigContent: FC = ({ )} >
legacy
-
+
)} description={t('appDebug.datasetConfig.retrieveOneWay.description')} @@ -250,12 +249,15 @@ const ConfigContent: FC = ({ onClick={() => handleRerankModeChange(option.value)} >
{option.label}
- {option.tips}
} - hideArrow - > - - + + {option.tips} +
+ } + popupClassName='ml-0.5' + triggerClassName='ml-0.5 w-3.5 h-3.5' + /> )) } @@ -281,9 +283,15 @@ const ConfigContent: FC = ({ ) }
{t('common.modelProvider.rerankModel.key')}
- {t('common.modelProvider.rerankModel.tip')}}> - - + + {t('common.modelProvider.rerankModel.tip')} + + } + popupClassName='ml-0.5' + triggerClassName='ml-0.5 w-3.5 h-3.5' + />
= ({
{t('common.modelProvider.systemReasoningModel.key')}
- - - + />
= ({ { currentModel && currentModel.status !== ModelStatusEnum.active && ( - + - + ) }
diff --git a/web/app/components/app/configuration/features/chat-group/suggested-questions-after-answer/index.tsx b/web/app/components/app/configuration/features/chat-group/suggested-questions-after-answer/index.tsx index e27eec46c8b683..199558f4aa5ec1 100644 --- a/web/app/components/app/configuration/features/chat-group/suggested-questions-after-answer/index.tsx +++ b/web/app/components/app/configuration/features/chat-group/suggested-questions-after-answer/index.tsx @@ -2,9 +2,6 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' import Panel from '@/app/components/app/configuration/base/feature-panel' import SuggestedQuestionsAfterAnswerIcon from '@/app/components/app/configuration/base/icons/suggested-questions-after-answer-icon' import Tooltip from '@/app/components/base/tooltip' @@ -15,13 +12,15 @@ const SuggestedQuestionsAfterAnswer: FC = () => { return ( +
{t('appDebug.feature.suggestedQuestionsAfterAnswer.title')}
- - {t('appDebug.feature.suggestedQuestionsAfterAnswer.description')} -
} selector='suggestion-question-tooltip'> - - + + {t('appDebug.feature.suggestedQuestionsAfterAnswer.description')} +
+ } + /> } headerIcon={} diff --git a/web/app/components/app/configuration/index.tsx b/web/app/components/app/configuration/index.tsx index 82d25bb3aca1d1..432accb0d21049 100644 --- a/web/app/components/app/configuration/index.tsx +++ b/web/app/components/app/configuration/index.tsx @@ -880,7 +880,6 @@ const Configuration: FC = () => { title={t('appDebug.resetConfig.title')} content={t('appDebug.resetConfig.message')} isShow={restoreConfirmOpen} - onClose={() => setRestoreConfirmOpen(false)} onConfirm={resetAppConfig} onCancel={() => setRestoreConfirmOpen(false)} /> @@ -890,7 +889,6 @@ const Configuration: FC = () => { title={t('appDebug.trailUseGPT4Info.title')} content={t('appDebug.trailUseGPT4Info.description')} isShow={showUseGPT4Confirm} - onClose={() => setShowUseGPT4Confirm(false)} onConfirm={() => { setShowAccountSettingModal({ payload: 'provider' }) setShowUseGPT4Confirm(false) diff --git a/web/app/components/app/configuration/prompt-value-panel/index.tsx b/web/app/components/app/configuration/prompt-value-panel/index.tsx index 0192024c83e527..dc9b0a4333c085 100644 --- a/web/app/components/app/configuration/prompt-value-panel/index.tsx +++ b/web/app/components/app/configuration/prompt-value-panel/index.tsx @@ -16,7 +16,7 @@ import { AppType, ModelModeType } from '@/types/app' import Select from '@/app/components/base/select' import { DEFAULT_VALUE_MAX_LEN } from '@/config' import Button from '@/app/components/base/button' -import Tooltip from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import TextGenerationImageUploader from '@/app/components/base/image-uploader/text-generation-image-uploader' import type { VisionFile, VisionSettings } from '@/types/app' @@ -207,6 +207,7 @@ const PromptValuePanel: FC = ({ {canNotRun ? ( {renderRunButton()} ) diff --git a/web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx b/web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx index b2c6792107e322..111c380afc0f5c 100644 --- a/web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx +++ b/web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx @@ -8,7 +8,7 @@ import { MessageCheckRemove, MessageFastPlus } from '@/app/components/base/icons import { MessageFast } from '@/app/components/base/icons/src/vender/solid/communication' import { Edit04 } from '@/app/components/base/icons/src/vender/line/general' import RemoveAnnotationConfirmModal from '@/app/components/app/annotation/remove-annotation-confirm-modal' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { addAnnotation, delAnnotation } from '@/service/annotation' import Toast from '@/app/components/base/toast' import { useProviderContext } from '@/context/provider-context' @@ -99,8 +99,9 @@ const CacheCtrlBtn: FC = ({ ) : answer ? ( -
= ({ >
-
+
) : null } -
= ({ >
-
+
{title}
- {tooltip}
} - > - - + />
{children}
@@ -103,7 +98,7 @@ const AnnotationReplyConfig: FC = ({ let isEmbeddingModelChanged = false if ( embeddingModel.embedding_model_name !== annotationConfig.embedding_model.embedding_model_name - && embeddingModel.embedding_provider_name !== annotationConfig.embedding_model.embedding_provider_name + || embeddingModel.embedding_provider_name !== annotationConfig.embedding_model.embedding_provider_name ) { await onEmbeddingChange(embeddingModel) isEmbeddingModelChanged = true diff --git a/web/app/components/app/configuration/toolbox/index.tsx b/web/app/components/app/configuration/toolbox/index.tsx index 488ca86b905c1a..00ea301a42d8a5 100644 --- a/web/app/components/app/configuration/toolbox/index.tsx +++ b/web/app/components/app/configuration/toolbox/index.tsx @@ -32,7 +32,7 @@ const Toolbox: FC = ({ ) } { - (showAnnotation || true) && ( + showAnnotation && ( {
{t('appDebug.feature.tools.title')}
- {t('appDebug.feature.tools.tips')}}> - - + + {t('appDebug.feature.tools.tips')} + + } + /> { !expanded && !!externalDataToolsConfig.length && ( @@ -143,7 +146,7 @@ const Tools = () => { background={item.icon_background} />
{item.label}
-
{ > {item.variable}
-
+
{ const [appMode, setAppMode] = useState('chat') const [showChatBotType, setShowChatBotType] = useState(true) - const [emoji, setEmoji] = useState({ icon: '🤖', icon_background: '#FFEAD5' }) - const [showEmojiPicker, setShowEmojiPicker] = useState(false) + const [appIcon, setAppIcon] = useState({ type: 'emoji', icon: '🤖', background: '#FFEAD5' }) + const [showAppIconPicker, setShowAppIconPicker] = useState(false) const [name, setName] = useState('') const [description, setDescription] = useState('') @@ -66,8 +67,9 @@ const CreateAppModal = ({ show, onSuccess, onClose }: CreateAppDialogProps) => { const app = await createApp({ name, description, - icon: emoji.icon, - icon_background: emoji.icon_background, + icon_type: appIcon.type, + icon: appIcon.type === 'emoji' ? appIcon.icon : appIcon.fileId, + icon_background: appIcon.type === 'emoji' ? appIcon.background : undefined, mode: appMode, }) notify({ type: 'success', message: t('app.newApp.appCreated') }) @@ -81,7 +83,7 @@ const CreateAppModal = ({ show, onSuccess, onClose }: CreateAppDialogProps) => { notify({ type: 'error', message: t('app.newApp.appCreateFailed') }) } isCreatingRef.current = false - }, [name, notify, t, appMode, emoji.icon, emoji.icon_background, description, onSuccess, onClose, mutateApps, push, isCurrentWorkspaceEditor]) + }, [name, notify, t, appMode, appIcon, description, onSuccess, onClose, mutateApps, push, isCurrentWorkspaceEditor]) return ( {
{t('app.newApp.captionAppType')}
- {t('app.newApp.chatbotDescription')}
} @@ -118,9 +119,8 @@ const CreateAppModal = ({ show, onSuccess, onClose }: CreateAppDialogProps) => {
{t('app.types.chatbot')}
- - +
{t('app.newApp.completionDescription')}
@@ -141,9 +141,8 @@ const CreateAppModal = ({ show, onSuccess, onClose }: CreateAppDialogProps) => {
{t('app.newApp.completeApp')}
- - + {t('app.newApp.agentDescription')} } @@ -162,9 +161,8 @@ const CreateAppModal = ({ show, onSuccess, onClose }: CreateAppDialogProps) => {
{t('app.types.agent')}
-
- +
{t('app.newApp.workflowDescription')}
@@ -186,7 +184,7 @@ const CreateAppModal = ({ show, onSuccess, onClose }: CreateAppDialogProps) => {
{t('app.types.workflow')}
BETA -
+ {showChatBotType && ( @@ -269,7 +267,14 @@ const CreateAppModal = ({ show, onSuccess, onClose }: CreateAppDialogProps) => {
{t('app.newApp.captionName')}
- { setShowEmojiPicker(true) }} className='cursor-pointer' icon={emoji.icon} background={emoji.icon_background} /> + { setShowAppIconPicker(true) }} + /> setName(e.target.value)} @@ -277,14 +282,13 @@ const CreateAppModal = ({ show, onSuccess, onClose }: CreateAppDialogProps) => { className='grow h-10 px-3 text-sm font-normal bg-gray-100 rounded-lg border border-transparent outline-none appearance-none caret-primary-600 placeholder:text-gray-400 hover:bg-gray-50 hover:border hover:border-gray-300 focus:bg-gray-50 focus:border focus:border-gray-300 focus:shadow-xs' />
- {showEmojiPicker && { - setEmoji({ icon, icon_background }) - setShowEmojiPicker(false) + {showAppIconPicker && { + setAppIcon(payload) + setShowAppIconPicker(false) }} onClose={() => { - setEmoji({ icon: '🤖', icon_background: '#FFEAD5' }) - setShowEmojiPicker(false) + setShowAppIconPicker(false) }} />}
diff --git a/web/app/components/app/duplicate-modal/index.tsx b/web/app/components/app/duplicate-modal/index.tsx index 6595972de615da..59a53101dbed2d 100644 --- a/web/app/components/app/duplicate-modal/index.tsx +++ b/web/app/components/app/duplicate-modal/index.tsx @@ -1,32 +1,39 @@ 'use client' import React, { useState } from 'react' import { useTranslation } from 'react-i18next' +import AppIconPicker from '../../base/app-icon-picker' import s from './style.module.css' import cn from '@/utils/classnames' import Modal from '@/app/components/base/modal' import Button from '@/app/components/base/button' import Toast from '@/app/components/base/toast' import AppIcon from '@/app/components/base/app-icon' -import EmojiPicker from '@/app/components/base/emoji-picker' import { useProviderContext } from '@/context/provider-context' import AppsFull from '@/app/components/billing/apps-full-in-dialog' +import type { AppIconType } from '@/types/app' + export type DuplicateAppModalProps = { appName: string + icon_type: AppIconType | null icon: string - icon_background: string + icon_background?: string | null + icon_url?: string | null show: boolean onConfirm: (info: { name: string + icon_type: AppIconType icon: string - icon_background: string + icon_background?: string | null }) => Promise onHide: () => void } const DuplicateAppModal = ({ appName, + icon_type, icon, icon_background, + icon_url, show = false, onConfirm, onHide, @@ -35,8 +42,12 @@ const DuplicateAppModal = ({ const [name, setName] = React.useState(appName) - const [showEmojiPicker, setShowEmojiPicker] = useState(false) - const [emoji, setEmoji] = useState({ icon, icon_background }) + const [showAppIconPicker, setShowAppIconPicker] = useState(false) + const [appIcon, setAppIcon] = useState( + icon_type === 'image' + ? { type: 'image' as const, url: icon_url, fileId: icon } + : { type: 'emoji' as const, icon, background: icon_background }, + ) const { plan, enableBilling } = useProviderContext() const isAppsFull = (enableBilling && plan.usage.buildApps >= plan.total.buildApps) @@ -48,7 +59,9 @@ const DuplicateAppModal = ({ } onConfirm({ name, - ...emoji, + icon_type: appIcon.type, + icon: appIcon.type === 'emoji' ? appIcon.icon : appIcon.fileId, + icon_background: appIcon.type === 'emoji' ? appIcon.background : undefined, }) onHide() } @@ -65,7 +78,15 @@ const DuplicateAppModal = ({
{t('explore.appCustomize.subTitle')}
- { setShowEmojiPicker(true) }} className='cursor-pointer' icon={emoji.icon} background={emoji.icon_background} /> + { setShowAppIconPicker(true) }} + className='cursor-pointer' + iconType={appIcon.type} + icon={appIcon.type === 'image' ? appIcon.fileId : appIcon.icon} + background={appIcon.type === 'image' ? undefined : appIcon.background} + imageUrl={appIcon.type === 'image' ? appIcon.url : undefined} + /> setName(e.target.value)} @@ -79,14 +100,16 @@ const DuplicateAppModal = ({
- {showEmojiPicker && { - setEmoji({ icon, icon_background }) - setShowEmojiPicker(false) + {showAppIconPicker && { + setAppIcon(payload) + setShowAppIconPicker(false) }} onClose={() => { - setEmoji({ icon, icon_background }) - setShowEmojiPicker(false) + setAppIcon(icon_type === 'image' + ? { type: 'image', url: icon_url!, fileId: icon } + : { type: 'emoji', icon, background: icon_background! }) + setShowAppIconPicker(false) }} />} diff --git a/web/app/components/app/log-annotation/index.tsx b/web/app/components/app/log-annotation/index.tsx index 852e57035c36e1..e0ace764f0a44c 100644 --- a/web/app/components/app/log-annotation/index.tsx +++ b/web/app/components/app/log-annotation/index.tsx @@ -1,6 +1,6 @@ 'use client' import type { FC } from 'react' -import React from 'react' +import React, { useMemo } from 'react' import { useTranslation } from 'react-i18next' import { useRouter } from 'next/navigation' import cn from '@/utils/classnames' @@ -23,10 +23,14 @@ const LogAnnotation: FC = ({ const router = useRouter() const appDetail = useAppStore(state => state.appDetail) - const options = [ - { value: PageType.log, text: t('appLog.title') }, - { value: PageType.annotation, text: t('appAnnotation.title') }, - ] + const options = useMemo(() => { + if (appDetail?.mode === 'completion') + return [{ value: PageType.log, text: t('appLog.title') }] + return [ + { value: PageType.log, text: t('appLog.title') }, + { value: PageType.annotation, text: t('appAnnotation.title') }, + ] + }, [appDetail]) if (!appDetail) { return ( diff --git a/web/app/components/app/log/filter.tsx b/web/app/components/app/log/filter.tsx index 80a58bb5a2a841..0552b44d16b3cb 100644 --- a/web/app/components/app/log/filter.tsx +++ b/web/app/components/app/log/filter.tsx @@ -10,6 +10,7 @@ import dayjs from 'dayjs' import quarterOfYear from 'dayjs/plugin/quarterOfYear' import type { QueryParam } from './index' import { SimpleSelect } from '@/app/components/base/select' +import Sort from '@/app/components/base/sort' import { fetchAnnotationsCount } from '@/service/log' dayjs.extend(quarterOfYear) @@ -28,18 +29,19 @@ export const TIME_PERIOD_LIST = [ ] type IFilterProps = { + isChatMode?: boolean appId: string queryParams: QueryParam setQueryParams: (v: QueryParam) => void } -const Filter: FC = ({ appId, queryParams, setQueryParams }: IFilterProps) => { +const Filter: FC = ({ isChatMode, appId, queryParams, setQueryParams }: IFilterProps) => { const { data } = useSWR({ url: `/apps/${appId}/annotations/count` }, fetchAnnotationsCount) const { t } = useTranslation() if (!data) return null return ( -
+
({ value: item.value, name: t(`appLog.filter.period.${item.name}`) }))} className='mt-0 !w-40' @@ -68,7 +70,7 @@ const Filter: FC = ({ appId, queryParams, setQueryParams }: IFilte { @@ -76,6 +78,22 @@ const Filter: FC = ({ appId, queryParams, setQueryParams }: IFilte }} />
+ {isChatMode && ( + <> +
+ { + setQueryParams({ ...queryParams, sort_by: value as string }) + }} + /> + + )}
) } diff --git a/web/app/components/app/log/index.tsx b/web/app/components/app/log/index.tsx index 828004d6f377ec..dd6ebd08f02041 100644 --- a/web/app/components/app/log/index.tsx +++ b/web/app/components/app/log/index.tsx @@ -24,6 +24,7 @@ export type QueryParam = { period?: number | string annotation_status?: string keyword?: string + sort_by?: string } const ThreeDotsIcon = ({ className }: SVGProps) => { @@ -52,18 +53,26 @@ const EmptyElement: FC<{ appUrl: string }> = ({ appUrl }) => { const Logs: FC = ({ appDetail }) => { const { t } = useTranslation() - const [queryParams, setQueryParams] = useState({ period: 7, annotation_status: 'all' }) + const [queryParams, setQueryParams] = useState({ + period: 7, + annotation_status: 'all', + sort_by: '-created_at', + }) const [currPage, setCurrPage] = React.useState(0) + // Get the app type first + const isChatMode = appDetail.mode !== 'completion' + const query = { page: currPage + 1, limit: APP_PAGE_LIMIT, ...(queryParams.period !== 'all' ? { start: dayjs().subtract(queryParams.period as number, 'day').startOf('day').format('YYYY-MM-DD HH:mm'), - end: dayjs().format('YYYY-MM-DD HH:mm'), + end: dayjs().endOf('day').format('YYYY-MM-DD HH:mm'), } : {}), + ...(isChatMode ? { sort_by: queryParams.sort_by } : {}), ...omit(queryParams, ['period']), } @@ -73,9 +82,6 @@ const Logs: FC = ({ appDetail }) => { return appType } - // Get the app type first - const isChatMode = appDetail.mode !== 'completion' - // When the details are obtained, proceed to the next request const { data: chatConversations, mutate: mutateChatList } = useSWR(() => isChatMode ? { @@ -97,7 +103,7 @@ const Logs: FC = ({ appDetail }) => {

{t('appLog.description')}

- + {total === undefined ? : total > 0 diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index 00d7ffb0fbb206..0bc118c46fdb49 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -5,10 +5,9 @@ import useSWR from 'swr' import { HandThumbDownIcon, HandThumbUpIcon, - InformationCircleIcon, XMarkIcon, } from '@heroicons/react/24/outline' -import { RiEditFill } from '@remixicon/react' +import { RiEditFill, RiQuestionLine } from '@remixicon/react' import { get } from 'lodash-es' import InfiniteScroll from 'react-infinite-scroll-component' import dayjs from 'dayjs' @@ -20,7 +19,6 @@ import { useTranslation } from 'react-i18next' import s from './style.module.css' import VarPanel from './var-panel' import cn from '@/utils/classnames' -import { randomString } from '@/utils' import type { FeedbackFunc, Feedbacktype, IChatItem, SubmitAnnotationFunc } from '@/app/components/base/chat/chat/type' import type { Annotation, ChatConversationFullDetailResponse, ChatConversationGeneralDetail, ChatConversationsResponse, ChatMessage, ChatMessagesRequest, CompletionConversationFullDetailResponse, CompletionConversationGeneralDetail, CompletionConversationsResponse, LogAnnotation } from '@/models/log' import type { App } from '@/types/app' @@ -28,7 +26,6 @@ import Loading from '@/app/components/base/loading' import Drawer from '@/app/components/base/drawer' import Popover from '@/app/components/base/popover' import Chat from '@/app/components/base/chat/chat' -import Tooltip from '@/app/components/base/tooltip' import { ToastContext } from '@/app/components/base/toast' import { fetchChatConversationDetail, fetchChatMessages, fetchCompletionConversationDetail, updateLogMessageAnnotations, updateLogMessageFeedbacks } from '@/service/log' import { TONE_LIST } from '@/config' @@ -42,7 +39,7 @@ import MessageLogModal from '@/app/components/base/message-log-modal' import { useStore as useAppStore } from '@/app/components/app/store' import { useAppContext } from '@/context/app-context' import useTimestamp from '@/hooks/use-timestamp' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { CopyIcon } from '@/app/components/base/copy-icon' dayjs.extend(utc) @@ -126,6 +123,7 @@ const getFormattedChatList = (messages: ChatMessage[], conversationId: string, t tokens: item.answer_tokens + item.message_tokens, latency: item.provider_response_latency.toFixed(2), }, + citation: item.metadata?.retriever_resources, annotation: (() => { if (item.annotation_hit_history) { return { @@ -345,11 +343,11 @@ function DetailPanel{isChatMode ? t('appLog.detail.conversationId') : t('appLog.detail.time')}
{isChatMode && (
- +
{detail.id}
-
+
)} @@ -379,7 +377,7 @@ function DetailPanel {targetTone} - + } htmlContent={
@@ -640,13 +638,12 @@ const ConversationList: FC = ({ logs, appDetail, onRefresh }) const renderTdValue = (value: string | number | null, isEmptyStyle: boolean, isHighlight = false, annotation?: LogAnnotation) => { return ( {`${t('appLog.detail.annotationTip', { user: annotation?.account?.name })} ${formatTime(annotation?.created_at || dayjs().unix(), 'MM-DD hh:mm A')}`} } - className={(isHighlight && !isChatMode) ? '' : '!hidden'} - selector={`highlight-${randomString(16)}`} + popupClassName={(isHighlight && !isChatMode) ? '' : '!hidden'} >
{value || '-'} @@ -670,17 +667,18 @@ const ConversationList: FC = ({ logs, appDetail, onRefresh }) - {t('appLog.table.header.time')} - {t('appLog.table.header.endUser')} {isChatMode ? t('appLog.table.header.summary') : t('appLog.table.header.input')} + {t('appLog.table.header.endUser')} {isChatMode ? t('appLog.table.header.messageCount') : t('appLog.table.header.output')} {t('appLog.table.header.userRate')} {t('appLog.table.header.adminRate')} + {t('appLog.table.header.updatedTime')} + {t('appLog.table.header.time')} {logs.data.map((log: any) => { - const endUser = log.from_end_user_session_id + const endUser = log.from_end_user_session_id || log.from_account_name const leftValue = get(log, isChatMode ? 'name' : 'message.inputs.query') || (!isChatMode ? (get(log, 'message.query') || get(log, 'message.inputs.default_input')) : '') || '' const rightValue = get(log, isChatMode ? 'message_count' : 'message.answer') return = ({ logs, appDetail, onRefresh }) setCurrentConversation(log) }}> {!log.read_at && } - {formatTime(log.created_at, t('appLog.dateTimeFormat') as string)} - {renderTdValue(endUser || defaultValue, !endUser)} {renderTdValue(leftValue || t('appLog.table.empty.noChat'), !leftValue, isChatMode && log.annotated)} + {renderTdValue(endUser || defaultValue, !endUser)} {renderTdValue(rightValue === 0 ? 0 : (rightValue || t('appLog.table.empty.noOutput')), !rightValue, !isChatMode && !!log.annotation?.content, log.annotation)} @@ -717,6 +714,8 @@ const ConversationList: FC = ({ logs, appDetail, onRefresh }) } + {formatTime(log.updated_at, t('appLog.dateTimeFormat') as string)} + {formatTime(log.created_at, t('appLog.dateTimeFormat') as string)} })} diff --git a/web/app/components/app/overview/appCard.tsx b/web/app/components/app/overview/appCard.tsx index ae44664f9fb413..f9f5c1fbfff03a 100644 --- a/web/app/components/app/overview/appCard.tsx +++ b/web/app/components/app/overview/appCard.tsx @@ -27,10 +27,11 @@ import ShareQRCode from '@/app/components/base/qrcode' import SecretKeyButton from '@/app/components/develop/secret-key/secret-key-button' import type { AppDetailResponse } from '@/models/app' import { useAppContext } from '@/context/app-context' +import type { AppSSO } from '@/types/app' export type IAppCardProps = { className?: string - appInfo: AppDetailResponse + appInfo: AppDetailResponse & Partial cardType?: 'api' | 'webapp' customBgColor?: string onChangeStatus: (val: boolean) => Promise @@ -53,7 +54,7 @@ function AppCard({ }: IAppCardProps) { const router = useRouter() const pathname = usePathname() - const { currentWorkspace, isCurrentWorkspaceManager, isCurrentWorkspaceEditor } = useAppContext() + const { isCurrentWorkspaceManager, isCurrentWorkspaceEditor } = useAppContext() const [showSettingsModal, setShowSettingsModal] = useState(false) const [showEmbedded, setShowEmbedded] = useState(false) const [showCustomizeModal, setShowCustomizeModal] = useState(false) @@ -133,8 +134,8 @@ function AppCard({ return (
@@ -175,7 +176,6 @@ function AppCard({ {isApp && } {/* button copy link/ button regenerate */} @@ -185,7 +185,6 @@ function AppCard({ title={t('appOverview.overview.appInfo.regenerate')} content={t('appOverview.overview.appInfo.regenerateNotice')} isShow={showConfirmDelete} - onClose={() => setShowConfirmDelete(false)} onConfirm={() => { onGenCode() setShowConfirmDelete(false) @@ -195,16 +194,15 @@ function AppCard({ )} {isApp && isCurrentWorkspaceManager && (
setShowConfirmDelete(true)} >
@@ -227,11 +225,10 @@ function AppCard({ disabled={disabled} >
diff --git a/web/app/components/app/overview/appChart.tsx b/web/app/components/app/overview/appChart.tsx index 7fd316a34b028e..e0788bcda3bed6 100644 --- a/web/app/components/app/overview/appChart.tsx +++ b/web/app/components/app/overview/appChart.tsx @@ -10,8 +10,8 @@ import { useTranslation } from 'react-i18next' import { formatNumber } from '@/utils/format' import Basic from '@/app/components/app-sidebar/basic' import Loading from '@/app/components/base/loading' -import type { AppDailyConversationsResponse, AppDailyEndUsersResponse, AppTokenCostsResponse } from '@/models/app' -import { getAppDailyConversations, getAppDailyEndUsers, getAppStatistics, getAppTokenCosts, getWorkflowDailyConversations } from '@/service/apps' +import type { AppDailyConversationsResponse, AppDailyEndUsersResponse, AppDailyMessagesResponse, AppTokenCostsResponse } from '@/models/app' +import { getAppDailyConversations, getAppDailyEndUsers, getAppDailyMessages, getAppStatistics, getAppTokenCosts, getWorkflowDailyConversations } from '@/service/apps' const valueFormatter = (v: string | number) => v const COLOR_TYPE_MAP = { @@ -36,12 +36,15 @@ const COMMON_COLOR_MAP = { } type IColorType = 'green' | 'orange' | 'blue' -type IChartType = 'conversations' | 'endUsers' | 'costs' | 'workflowCosts' +type IChartType = 'messages' | 'conversations' | 'endUsers' | 'costs' | 'workflowCosts' type IChartConfigType = { colorType: IColorType; showTokens?: boolean } const commonDateFormat = 'MMM D, YYYY' const CHART_TYPE_CONFIG: Record = { + messages: { + colorType: 'green', + }, conversations: { colorType: 'green', }, @@ -89,7 +92,7 @@ export type IChartProps = { unit?: string yMax?: number chartType: IChartType - chartData: AppDailyConversationsResponse | AppDailyEndUsersResponse | AppTokenCostsResponse | { data: Array<{ date: string; count: number }> } + chartData: AppDailyMessagesResponse | AppDailyConversationsResponse | AppDailyEndUsersResponse | AppTokenCostsResponse | { data: Array<{ date: string; count: number }> } } const Chart: React.FC = ({ @@ -258,6 +261,20 @@ const getDefaultChartData = ({ start, end, key = 'count' }: { start: string; end }) } +export const MessagesChart: FC = ({ id, period }) => { + const { t } = useTranslation() + const { data: response } = useSWR({ url: `/apps/${id}/statistics/daily-messages`, params: period.query }, getAppDailyMessages) + if (!response) + return + const noDataFlag = !response.data || response.data.length === 0 + return +} + export const ConversationsChart: FC = ({ id, period }) => { const { t } = useTranslation() const { data: response } = useSWR({ url: `/apps/${id}/statistics/daily-conversations`, params: period.query }, getAppDailyConversations) @@ -265,7 +282,7 @@ export const ConversationsChart: FC = ({ id, period }) => { return const noDataFlag = !response.data || response.data.length === 0 return
diff --git a/web/app/components/app/overview/settings/index.tsx b/web/app/components/app/overview/settings/index.tsx index 88d5c2d9094585..02eff06c0a7b41 100644 --- a/web/app/components/app/overview/settings/index.tsx +++ b/web/app/components/app/overview/settings/index.tsx @@ -4,21 +4,25 @@ import React, { useEffect, useState } from 'react' import { ChevronRightIcon } from '@heroicons/react/20/solid' import Link from 'next/link' import { Trans, useTranslation } from 'react-i18next' +import { useContextSelector } from 'use-context-selector' import s from './style.module.css' import Modal from '@/app/components/base/modal' import Button from '@/app/components/base/button' import AppIcon from '@/app/components/base/app-icon' +import Switch from '@/app/components/base/switch' import { SimpleSelect } from '@/app/components/base/select' import type { AppDetailResponse } from '@/models/app' -import type { Language } from '@/types/app' -import EmojiPicker from '@/app/components/base/emoji-picker' +import type { AppIconType, AppSSO, Language } from '@/types/app' import { useToastContext } from '@/app/components/base/toast' - import { languages } from '@/i18n/language' +import Tooltip from '@/app/components/base/tooltip' +import AppContext from '@/context/app-context' +import type { AppIconSelection } from '@/app/components/base/app-icon-picker' +import AppIconPicker from '@/app/components/base/app-icon-picker' export type ISettingsModalProps = { isChat: boolean - appInfo: AppDetailResponse + appInfo: AppDetailResponse & Partial isShow: boolean defaultValue?: string onClose: () => void @@ -35,9 +39,11 @@ export type ConfigParams = { copyright: string privacy_policy: string custom_disclaimer: string + icon_type: AppIconType icon: string - icon_background: string + icon_background?: string show_workflow_steps: boolean + enable_sso?: boolean } const prefixSettings = 'appOverview.overview.appInfo.settings' @@ -49,11 +55,15 @@ const SettingsModal: FC = ({ onClose, onSave, }) => { + const systemFeatures = useContextSelector(AppContext, state => state.systemFeatures) const { notify } = useToastContext() const [isShowMore, setIsShowMore] = useState(false) - const { icon, icon_background } = appInfo const { title, + icon_type, + icon, + icon_background, + icon_url, description, chat_color_theme, chat_color_theme_inverted, @@ -72,13 +82,18 @@ const SettingsModal: FC = ({ privacyPolicy: privacy_policy, customDisclaimer: custom_disclaimer, show_workflow_steps, + enable_sso: appInfo.enable_sso, }) const [language, setLanguage] = useState(default_language) const [saveLoading, setSaveLoading] = useState(false) const { t } = useTranslation() - // Emoji Picker - const [showEmojiPicker, setShowEmojiPicker] = useState(false) - const [emoji, setEmoji] = useState({ icon, icon_background }) + + const [showAppIconPicker, setShowAppIconPicker] = useState(false) + const [appIcon, setAppIcon] = useState( + icon_type === 'image' + ? { type: 'image', url: icon_url!, fileId: icon } + : { type: 'emoji', icon, background: icon_background! }, + ) useEffect(() => { setInputInfo({ @@ -90,9 +105,12 @@ const SettingsModal: FC = ({ privacyPolicy: privacy_policy, customDisclaimer: custom_disclaimer, show_workflow_steps, + enable_sso: appInfo.enable_sso, }) setLanguage(default_language) - setEmoji({ icon, icon_background }) + setAppIcon(icon_type === 'image' + ? { type: 'image', url: icon_url!, fileId: icon } + : { type: 'emoji', icon, background: icon_background! }) }, [appInfo]) const onHide = () => { @@ -135,9 +153,11 @@ const SettingsModal: FC = ({ copyright: inputInfo.copyright, privacy_policy: inputInfo.privacyPolicy, custom_disclaimer: inputInfo.customDisclaimer, - icon: emoji.icon, - icon_background: emoji.icon_background, + icon_type: appIcon.type, + icon: appIcon.type === 'emoji' ? appIcon.icon : appIcon.fileId, + icon_background: appIcon.type === 'emoji' ? appIcon.background : undefined, show_workflow_steps: inputInfo.show_workflow_steps, + enable_sso: inputInfo.enable_sso, } await onSave?.(params) setSaveLoading(false) @@ -167,10 +187,12 @@ const SettingsModal: FC = ({
{t(`${prefixSettings}.webName`)}
{ setShowEmojiPicker(true) }} + onClick={() => { setShowAppIconPicker(true) }} className='cursor-pointer !mr-3 self-center' - icon={emoji.icon} - background={emoji.icon_background} + iconType={appIcon.type} + icon={appIcon.type === 'image' ? appIcon.fileId : appIcon.icon} + background={appIcon.type === 'image' ? undefined : appIcon.background} + imageUrl={appIcon.type === 'image' ? appIcon.url : undefined} /> = ({ defaultValue={language} onSelect={item => setLanguage(item.value as Language)} /> - {(appInfo.mode === 'workflow' || appInfo.mode === 'advanced-chat') && <> -
{t(`${prefixSettings}.workflow.title`)}
- setInputInfo({ ...inputInfo, show_workflow_steps: item.value === 'true' })} - /> - } +
+

{t(`${prefixSettings}.workflow.title`)}

+
+
{t(`${prefixSettings}.workflow.subTitle`)}
+ setInputInfo({ ...inputInfo, show_workflow_steps: v })} + /> +
+

{t(`${prefixSettings}.workflow.showDesc`)}

+
+ {isChat && <>
{t(`${prefixSettings}.chatColorTheme`)}

{t(`${prefixSettings}.chatColorThemeDesc`)}

} + {systemFeatures.enable_web_sso_switch_component &&
+

{t(`${prefixSettings}.sso.label`)}

+
+
{t(`${prefixSettings}.sso.title`)}
+ {t(`${prefixSettings}.sso.tooltip`)}
+ } + asChild={false} + > + setInputInfo({ ...inputInfo, enable_sso: v })}> + +
+

{t(`${prefixSettings}.sso.description`)}

+
} {!isShowMore &&
setIsShowMore(true)}>
{t(`${prefixSettings}.more.entry`)}
@@ -250,14 +293,16 @@ const SettingsModal: FC = ({
- {showEmojiPicker && { - setEmoji({ icon, icon_background }) - setShowEmojiPicker(false) + {showAppIconPicker && { + setAppIcon(payload) + setShowAppIconPicker(false) }} onClose={() => { - setEmoji({ icon: appInfo.site.icon, icon_background: appInfo.site.icon_background }) - setShowEmojiPicker(false) + setAppIcon(icon_type === 'image' + ? { type: 'image', url: icon_url!, fileId: icon } + : { type: 'emoji', icon, background: icon_background! }) + setShowAppIconPicker(false) }} />} diff --git a/web/app/components/app/store.ts b/web/app/components/app/store.ts index a89b96d65dfbef..0209102372d9c6 100644 --- a/web/app/components/app/store.ts +++ b/web/app/components/app/store.ts @@ -1,9 +1,9 @@ import { create } from 'zustand' -import type { App } from '@/types/app' +import type { App, AppSSO } from '@/types/app' import type { IChatItem } from '@/app/components/base/chat/chat/type' type State = { - appDetail?: App + appDetail?: App & Partial appSidebarExpand: string currentLogItem?: IChatItem currentLogModalActiveTab: string @@ -13,7 +13,7 @@ type State = { } type Action = { - setAppDetail: (appDetail?: App) => void + setAppDetail: (appDetail?: App & Partial) => void setAppSiderbarExpand: (state: string) => void setCurrentLogItem: (item?: IChatItem) => void setCurrentLogModalActiveTab: (tab: string) => void diff --git a/web/app/components/app/switch-app-modal/index.tsx b/web/app/components/app/switch-app-modal/index.tsx index e5ac6ed55ec8f1..82f57e1f5a107c 100644 --- a/web/app/components/app/switch-app-modal/index.tsx +++ b/web/app/components/app/switch-app-modal/index.tsx @@ -5,6 +5,7 @@ import { useRouter } from 'next/navigation' import { useContext } from 'use-context-selector' import { useTranslation } from 'react-i18next' import { RiCloseLine } from '@remixicon/react' +import AppIconPicker from '../../base/app-icon-picker' import s from './style.module.css' import cn from '@/utils/classnames' import Button from '@/app/components/base/button' @@ -15,7 +16,6 @@ import { deleteApp, switchApp } from '@/service/apps' import { useAppContext } from '@/context/app-context' import { useProviderContext } from '@/context/provider-context' import AppsFull from '@/app/components/billing/apps-full-in-dialog' -import EmojiPicker from '@/app/components/base/emoji-picker' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { getRedirection } from '@/utils/app-redirection' import type { App } from '@/types/app' @@ -41,8 +41,13 @@ const SwitchAppModal = ({ show, appDetail, inAppDetail = false, onSuccess, onClo const { plan, enableBilling } = useProviderContext() const isAppsFull = (enableBilling && plan.usage.buildApps >= plan.total.buildApps) - const [emoji, setEmoji] = useState({ icon: appDetail.icon, icon_background: appDetail.icon_background }) - const [showEmojiPicker, setShowEmojiPicker] = useState(false) + const [showAppIconPicker, setShowAppIconPicker] = useState(false) + const [appIcon, setAppIcon] = useState( + appDetail.icon_type === 'image' + ? { type: 'image' as const, url: appDetail.icon_url, fileId: appDetail.icon } + : { type: 'emoji' as const, icon: appDetail.icon, background: appDetail.icon_background }, + ) + const [name, setName] = useState(`${appDetail.name}(copy)`) const [removeOriginal, setRemoveOriginal] = useState(false) const [showConfirmDelete, setShowConfirmDelete] = useState(false) @@ -52,8 +57,9 @@ const SwitchAppModal = ({ show, appDetail, inAppDetail = false, onSuccess, onClo const { new_app_id: newAppID } = await switchApp({ appID: appDetail.id, name, - icon: emoji.icon, - icon_background: emoji.icon_background, + icon_type: appIcon.type, + icon: appIcon.type === 'emoji' ? appIcon.icon : appIcon.fileId, + icon_background: appIcon.type === 'emoji' ? appIcon.background : undefined, }) if (onSuccess) onSuccess() @@ -106,7 +112,15 @@ const SwitchAppModal = ({ show, appDetail, inAppDetail = false, onSuccess, onClo
{t('app.switchLabel')}
- { setShowEmojiPicker(true) }} className='cursor-pointer' icon={emoji.icon} background={emoji.icon_background} /> + { setShowAppIconPicker(true) }} + className='cursor-pointer' + iconType={appIcon.type} + icon={appIcon.type === 'image' ? appIcon.fileId : appIcon.icon} + background={appIcon.type === 'image' ? undefined : appIcon.background} + imageUrl={appIcon.type === 'image' ? appIcon.url : undefined} + /> setName(e.target.value)} @@ -114,14 +128,16 @@ const SwitchAppModal = ({ show, appDetail, inAppDetail = false, onSuccess, onClo className='grow h-10 px-3 text-sm font-normal bg-gray-100 rounded-lg border border-transparent outline-none appearance-none caret-primary-600 placeholder:text-gray-400 hover:bg-gray-50 hover:border hover:border-gray-300 focus:bg-gray-50 focus:border focus:border-gray-300 focus:shadow-xs' />
- {showEmojiPicker && { - setEmoji({ icon, icon_background }) - setShowEmojiPicker(false) + {showAppIconPicker && { + setAppIcon(payload) + setShowAppIconPicker(false) }} onClose={() => { - setEmoji({ icon: appDetail.icon, icon_background: appDetail.icon_background }) - setShowEmojiPicker(false) + setAppIcon(appDetail.icon_type === 'image' + ? { type: 'image' as const, url: appDetail.icon_url, fileId: appDetail.icon } + : { type: 'emoji' as const, icon: appDetail.icon, background: appDetail.icon_background }) + setShowAppIconPicker(false) }} />}
@@ -147,10 +163,6 @@ const SwitchAppModal = ({ show, appDetail, inAppDetail = false, onSuccess, onClo setShowConfirmDelete(false) setRemoveOriginal(false) }} - onClose={() => { - setShowConfirmDelete(false) - setRemoveOriginal(false) - }} /> )} diff --git a/web/app/components/app/workflow-log/list.tsx b/web/app/components/app/workflow-log/list.tsx index f4707dce59d637..5a81c8cd8b5a2a 100644 --- a/web/app/components/app/workflow-log/list.tsx +++ b/web/app/components/app/workflow-log/list.tsx @@ -91,7 +91,7 @@ const WorkflowAppLogList: FC = ({ logs, appDetail, onRefresh }) => { {logs.data.map((log: WorkflowAppLogDetail) => { - const endUser = log.created_by_end_user ? log.created_by_end_user.session_id : defaultValue + const endUser = log.created_by_end_user ? log.created_by_end_user.session_id : log.created_by_account ? log.created_by_account.name : defaultValue return = ({ - status, elapsed_time, total_tokens, error, diff --git a/web/app/components/base/app-icon-picker/Uploader.tsx b/web/app/components/base/app-icon-picker/Uploader.tsx new file mode 100644 index 00000000000000..4ddaa404475e6d --- /dev/null +++ b/web/app/components/base/app-icon-picker/Uploader.tsx @@ -0,0 +1,97 @@ +'use client' + +import type { ChangeEvent, FC } from 'react' +import { createRef, useEffect, useState } from 'react' +import type { Area } from 'react-easy-crop' +import Cropper from 'react-easy-crop' +import classNames from 'classnames' + +import { ImagePlus } from '../icons/src/vender/line/images' +import { useDraggableUploader } from './hooks' +import { ALLOW_FILE_EXTENSIONS } from '@/types/app' + +type UploaderProps = { + className?: string + onImageCropped?: (tempUrl: string, croppedAreaPixels: Area, fileName: string) => void +} + +const Uploader: FC = ({ + className, + onImageCropped, +}) => { + const [inputImage, setInputImage] = useState<{ file: File; url: string }>() + useEffect(() => { + return () => { + if (inputImage) + URL.revokeObjectURL(inputImage.url) + } + }, [inputImage]) + + const [crop, setCrop] = useState({ x: 0, y: 0 }) + const [zoom, setZoom] = useState(1) + + const onCropComplete = async (_: Area, croppedAreaPixels: Area) => { + if (!inputImage) + return + onImageCropped?.(inputImage.url, croppedAreaPixels, inputImage.file.name) + } + + const handleLocalFileInput = (e: ChangeEvent) => { + const file = e.target.files?.[0] + if (file) + setInputImage({ file, url: URL.createObjectURL(file) }) + } + + const { + isDragActive, + handleDragEnter, + handleDragOver, + handleDragLeave, + handleDrop, + } = useDraggableUploader((file: File) => setInputImage({ file, url: URL.createObjectURL(file) })) + + const inputRef = createRef() + + return ( +
+
+ { + !inputImage + ? <> + +
+ Drop your image here, or  + + ((e.target as HTMLInputElement).value = '')} + accept={ALLOW_FILE_EXTENSIONS.map(ext => `.${ext}`).join(',')} + onChange={handleLocalFileInput} + /> +
+
Supports PNG, JPG, JPEG, WEBP and GIF
+ + : + } +
+
+ ) +} + +export default Uploader diff --git a/web/app/components/base/app-icon-picker/hooks.tsx b/web/app/components/base/app-icon-picker/hooks.tsx new file mode 100644 index 00000000000000..b3f67c0dcae930 --- /dev/null +++ b/web/app/components/base/app-icon-picker/hooks.tsx @@ -0,0 +1,43 @@ +import { useCallback, useState } from 'react' + +export const useDraggableUploader = (setImageFn: (file: File) => void) => { + const [isDragActive, setIsDragActive] = useState(false) + + const handleDragEnter = useCallback((e: React.DragEvent) => { + e.preventDefault() + e.stopPropagation() + setIsDragActive(true) + }, []) + + const handleDragOver = useCallback((e: React.DragEvent) => { + e.preventDefault() + e.stopPropagation() + }, []) + + const handleDragLeave = useCallback((e: React.DragEvent) => { + e.preventDefault() + e.stopPropagation() + setIsDragActive(false) + }, []) + + const handleDrop = useCallback((e: React.DragEvent) => { + e.preventDefault() + e.stopPropagation() + setIsDragActive(false) + + const file = e.dataTransfer.files[0] + + if (!file) + return + + setImageFn(file) + }, [setImageFn]) + + return { + handleDragEnter, + handleDragOver, + handleDragLeave, + handleDrop, + isDragActive, + } +} diff --git a/web/app/components/base/app-icon-picker/index.tsx b/web/app/components/base/app-icon-picker/index.tsx new file mode 100644 index 00000000000000..825481547501a8 --- /dev/null +++ b/web/app/components/base/app-icon-picker/index.tsx @@ -0,0 +1,139 @@ +import type { FC } from 'react' +import { useCallback, useState } from 'react' +import { useTranslation } from 'react-i18next' +import type { Area } from 'react-easy-crop' +import Modal from '../modal' +import Divider from '../divider' +import Button from '../button' +import { ImagePlus } from '../icons/src/vender/line/images' +import { useLocalFileUploader } from '../image-uploader/hooks' +import EmojiPickerInner from '../emoji-picker/Inner' +import Uploader from './Uploader' +import s from './style.module.css' +import getCroppedImg from './utils' +import type { AppIconType, ImageFile } from '@/types/app' +import cn from '@/utils/classnames' +import { DISABLE_UPLOAD_IMAGE_AS_ICON } from '@/config' +export type AppIconEmojiSelection = { + type: 'emoji' + icon: string + background: string +} + +export type AppIconImageSelection = { + type: 'image' + fileId: string + url: string +} + +export type AppIconSelection = AppIconEmojiSelection | AppIconImageSelection + +type AppIconPickerProps = { + onSelect?: (payload: AppIconSelection) => void + onClose?: () => void + className?: string +} + +const AppIconPicker: FC = ({ + onSelect, + onClose, + className, +}) => { + const { t } = useTranslation() + + const tabs = [ + { key: 'emoji', label: t('app.iconPicker.emoji'), icon: 🤖 }, + { key: 'image', label: t('app.iconPicker.image'), icon: }, + ] + const [activeTab, setActiveTab] = useState('emoji') + + const [emoji, setEmoji] = useState<{ emoji: string; background: string }>() + const handleSelectEmoji = useCallback((emoji: string, background: string) => { + setEmoji({ emoji, background }) + }, [setEmoji]) + + const [uploading, setUploading] = useState() + + const { handleLocalFileUpload } = useLocalFileUploader({ + limit: 3, + disabled: false, + onUpload: (imageFile: ImageFile) => { + if (imageFile.fileId) { + setUploading(false) + onSelect?.({ + type: 'image', + fileId: imageFile.fileId, + url: imageFile.url, + }) + } + }, + }) + + const [imageCropInfo, setImageCropInfo] = useState<{ tempUrl: string; croppedAreaPixels: Area; fileName: string }>() + const handleImageCropped = async (tempUrl: string, croppedAreaPixels: Area, fileName: string) => { + setImageCropInfo({ tempUrl, croppedAreaPixels, fileName }) + } + + const handleSelect = async () => { + if (activeTab === 'emoji') { + if (emoji) { + onSelect?.({ + type: 'emoji', + icon: emoji.emoji, + background: emoji.background, + }) + } + } + else { + if (!imageCropInfo) + return + setUploading(true) + const blob = await getCroppedImg(imageCropInfo.tempUrl, imageCropInfo.croppedAreaPixels) + const file = new File([blob], imageCropInfo.fileName, { type: blob.type }) + handleLocalFileUpload(file) + } + } + + return { }} + isShow + closable={false} + wrapperClassName={className} + className={cn(s.container, '!w-[362px] !p-0')} + > + {!DISABLE_UPLOAD_IMAGE_AS_ICON &&
+
+ {tabs.map(tab => ( + + ))} +
+
} + + + + + + + +
+ + + +
+
+} + +export default AppIconPicker diff --git a/web/app/components/base/app-icon-picker/style.module.css b/web/app/components/base/app-icon-picker/style.module.css new file mode 100644 index 00000000000000..5facb3560a04d3 --- /dev/null +++ b/web/app/components/base/app-icon-picker/style.module.css @@ -0,0 +1,12 @@ +.container { + display: flex; + flex-direction: column; + align-items: flex-start; + width: 362px; + max-height: 552px; + + border: 0.5px solid #EAECF0; + box-shadow: 0px 12px 16px -4px rgba(16, 24, 40, 0.08), 0px 4px 6px -2px rgba(16, 24, 40, 0.03); + border-radius: 12px; + background: #fff; +} diff --git a/web/app/components/base/app-icon-picker/utils.ts b/web/app/components/base/app-icon-picker/utils.ts new file mode 100644 index 00000000000000..0c90e96febda83 --- /dev/null +++ b/web/app/components/base/app-icon-picker/utils.ts @@ -0,0 +1,98 @@ +export const createImage = (url: string) => + new Promise((resolve, reject) => { + const image = new Image() + image.addEventListener('load', () => resolve(image)) + image.addEventListener('error', error => reject(error)) + image.setAttribute('crossOrigin', 'anonymous') // needed to avoid cross-origin issues on CodeSandbox + image.src = url + }) + +export function getRadianAngle(degreeValue: number) { + return (degreeValue * Math.PI) / 180 +} + +/** + * Returns the new bounding area of a rotated rectangle. + */ +export function rotateSize(width: number, height: number, rotation: number) { + const rotRad = getRadianAngle(rotation) + + return { + width: + Math.abs(Math.cos(rotRad) * width) + Math.abs(Math.sin(rotRad) * height), + height: + Math.abs(Math.sin(rotRad) * width) + Math.abs(Math.cos(rotRad) * height), + } +} + +/** + * This function was adapted from the one in the ReadMe of https://github.com/DominicTobias/react-image-crop + */ +export default async function getCroppedImg( + imageSrc: string, + pixelCrop: { x: number; y: number; width: number; height: number }, + rotation = 0, + flip = { horizontal: false, vertical: false }, +): Promise { + const image = await createImage(imageSrc) + const canvas = document.createElement('canvas') + const ctx = canvas.getContext('2d') + + if (!ctx) + throw new Error('Could not create a canvas context') + + const rotRad = getRadianAngle(rotation) + + // calculate bounding box of the rotated image + const { width: bBoxWidth, height: bBoxHeight } = rotateSize( + image.width, + image.height, + rotation, + ) + + // set canvas size to match the bounding box + canvas.width = bBoxWidth + canvas.height = bBoxHeight + + // translate canvas context to a central location to allow rotating and flipping around the center + ctx.translate(bBoxWidth / 2, bBoxHeight / 2) + ctx.rotate(rotRad) + ctx.scale(flip.horizontal ? -1 : 1, flip.vertical ? -1 : 1) + ctx.translate(-image.width / 2, -image.height / 2) + + // draw rotated image + ctx.drawImage(image, 0, 0) + + const croppedCanvas = document.createElement('canvas') + + const croppedCtx = croppedCanvas.getContext('2d') + + if (!croppedCtx) + throw new Error('Could not create a canvas context') + + // Set the size of the cropped canvas + croppedCanvas.width = pixelCrop.width + croppedCanvas.height = pixelCrop.height + + // Draw the cropped image onto the new canvas + croppedCtx.drawImage( + canvas, + pixelCrop.x, + pixelCrop.y, + pixelCrop.width, + pixelCrop.height, + 0, + 0, + pixelCrop.width, + pixelCrop.height, + ) + + return new Promise((resolve, reject) => { + croppedCanvas.toBlob((file) => { + if (file) + resolve(file) + else + reject(new Error('Could not create a blob')) + }, 'image/jpeg') + }) +} diff --git a/web/app/components/base/app-icon/index.tsx b/web/app/components/base/app-icon/index.tsx index 9d8cf28beda7a1..5e7378c08734aa 100644 --- a/web/app/components/base/app-icon/index.tsx +++ b/web/app/components/base/app-icon/index.tsx @@ -1,17 +1,21 @@ -import type { FC } from 'react' +'use client' -import data from '@emoji-mart/data' +import type { FC } from 'react' import { init } from 'emoji-mart' +import data from '@emoji-mart/data' import style from './style.module.css' import classNames from '@/utils/classnames' +import type { AppIconType } from '@/types/app' init({ data }) export type AppIconProps = { size?: 'xs' | 'tiny' | 'small' | 'medium' | 'large' rounded?: boolean + iconType?: AppIconType | null icon?: string - background?: string + background?: string | null + imageUrl?: string | null className?: string innerIcon?: React.ReactNode onClick?: () => void @@ -20,28 +24,34 @@ export type AppIconProps = { const AppIcon: FC = ({ size = 'medium', rounded = false, + iconType, icon, background, + imageUrl, className, innerIcon, onClick, }) => { - return ( - - {innerIcon || ((icon && icon !== '') ? : )} - + const wrapperClassName = classNames( + style.appIcon, + size !== 'medium' && style[size], + rounded && style.rounded, + className ?? '', + 'overflow-hidden', ) + + const isValidImageIcon = iconType === 'image' && imageUrl + + return + {isValidImageIcon + ? app icon + : (innerIcon || ((icon && icon !== '') ? : )) + } + } export default AppIcon diff --git a/web/app/components/base/app-icon/style.module.css b/web/app/components/base/app-icon/style.module.css index 20f76d40995daa..06a2478d41e380 100644 --- a/web/app/components/base/app-icon/style.module.css +++ b/web/app/components/base/app-icon/style.module.css @@ -1,5 +1,5 @@ .appIcon { - @apply flex items-center justify-center relative w-9 h-9 text-lg bg-teal-100 rounded-lg grow-0 shrink-0; + @apply flex items-center justify-center relative w-9 h-9 text-lg rounded-lg grow-0 shrink-0; } .appIcon.large { diff --git a/web/app/components/base/audio-btn/index.tsx b/web/app/components/base/audio-btn/index.tsx index 675f58b53039c1..d57c79b571223d 100644 --- a/web/app/components/base/audio-btn/index.tsx +++ b/web/app/components/base/audio-btn/index.tsx @@ -83,25 +83,25 @@ const AudioBtn = ({ }[audioState] return ( -
+
diff --git a/web/app/components/base/badge.tsx b/web/app/components/base/badge.tsx index 3e5414fa2cba81..c3300a1e67e590 100644 --- a/web/app/components/base/badge.tsx +++ b/web/app/components/base/badge.tsx @@ -4,16 +4,19 @@ import cn from '@/utils/classnames' type BadgeProps = { className?: string text: string + uppercase?: boolean } const Badge = ({ className, text, + uppercase = true, }: BadgeProps) => { return (
diff --git a/web/app/components/base/block-input/index.tsx b/web/app/components/base/block-input/index.tsx index f2b6b5d6dc174c..79ff646bd12039 100644 --- a/web/app/components/base/block-input/index.tsx +++ b/web/app/components/base/block-input/index.tsx @@ -49,8 +49,6 @@ const BlockInput: FC = ({ setCurrentValue(value) }, [value]) - const isContentChanged = value !== currentValue - const contentEditableRef = useRef(null) const [isEditing, setIsEditing] = useState(false) useEffect(() => { diff --git a/web/app/components/base/chat/chat-with-history/config-panel/index.tsx b/web/app/components/base/chat/chat-with-history/config-panel/index.tsx index 6bee44b8efe025..05f253290f021d 100644 --- a/web/app/components/base/chat/chat-with-history/config-panel/index.tsx +++ b/web/app/components/base/chat/chat-with-history/config-panel/index.tsx @@ -43,9 +43,12 @@ const ConfigPanel = () => { <>
{appData?.site.title}
diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx index ab8b3648e7dae9..624cc53a183609 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.tsx @@ -37,11 +37,20 @@ import type { import { addFileInfos, sortAgentSorts } from '@/app/components/tools/utils' import { useToastContext } from '@/app/components/base/toast' import { changeLanguage } from '@/i18n/i18next-config' +import { useAppFavicon } from '@/hooks/use-app-favicon' export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { const isInstalledApp = useMemo(() => !!installedAppInfo, [installedAppInfo]) const { data: appInfo, isLoading: appInfoLoading, error: appInfoError } = useSWR(installedAppInfo ? null : 'appInfo', fetchAppInfo) + useAppFavicon({ + enable: !installedAppInfo, + icon_type: appInfo?.site.icon_type, + icon: appInfo?.site.icon, + icon_background: appInfo?.site.icon_background, + icon_url: appInfo?.site.icon_url, + }) + const appData = useMemo(() => { if (isInstalledApp) { const { id, app } = installedAppInfo! @@ -49,8 +58,10 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { app_id: id, site: { title: app.name, + icon_type: app.icon_type, icon: app.icon, icon_background: app.icon_background, + icon_url: app.icon_url, prompt_public: false, copyright: '', show_workflow_steps: true, diff --git a/web/app/components/base/chat/chat-with-history/sidebar/index.tsx b/web/app/components/base/chat/chat-with-history/sidebar/index.tsx index f5898787111dd9..17a2751f111928 100644 --- a/web/app/components/base/chat/chat-with-history/sidebar/index.tsx +++ b/web/app/components/base/chat/chat-with-history/sidebar/index.tsx @@ -67,8 +67,10 @@ const Sidebar = () => {
{appData?.site.title} @@ -121,7 +123,6 @@ const Sidebar = () => { title={t('share.chat.deleteConversation.title')} content={t('share.chat.deleteConversation.content') || ''} isShow - onClose={handleCancelConfirm} onCancel={handleCancelConfirm} onConfirm={handleDelete} /> diff --git a/web/app/components/base/chat/chat/answer/operation.tsx b/web/app/components/base/chat/chat/answer/operation.tsx index 8ec5c0f3b23893..f54e7368262686 100644 --- a/web/app/components/base/chat/chat/answer/operation.tsx +++ b/web/app/components/base/chat/chat/answer/operation.tsx @@ -17,7 +17,7 @@ import { ThumbsDown, ThumbsUp, } from '@/app/components/base/icons/src/vender/line/alertsAndFeedback' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import Log from '@/app/components/base/chat/chat/log' type OperationProps = { @@ -162,28 +162,34 @@ const Operation: FC = ({ { config?.supportFeedback && !localFeedback?.rating && onFeedback && !isOpeningStatement && (
- +
handleFeedback('like')} >
-
- + +
handleFeedback('dislike')} >
-
+
) } { config?.supportFeedback && localFeedback?.rating && onFeedback && !isOpeningStatement && ( - +
= ({ ) }
-
+ ) }
diff --git a/web/app/components/base/chat/chat/chat-input.tsx b/web/app/components/base/chat/chat/chat-input.tsx index 6d15010f85685c..c4578fab62b011 100644 --- a/web/app/components/base/chat/chat/chat-input.tsx +++ b/web/app/components/base/chat/chat/chat-input.tsx @@ -17,7 +17,7 @@ import { TransferMethod } from '../types' import { useChatWithHistoryContext } from '../chat-with-history/context' import type { Theme } from '../embedded-chatbot/theme/theme-context' import { CssTransform } from '../embedded-chatbot/theme/utils' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { ToastContext } from '@/app/components/base/toast' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import VoiceInput from '@/app/components/base/voice-input' @@ -49,6 +49,7 @@ const ChatInput: FC = ({ const { t } = useTranslation() const { notify } = useContext(ToastContext) const [voiceInputShow, setVoiceInputShow] = useState(false) + const textAreaRef = useRef(null) const { files, onUpload, @@ -89,7 +90,7 @@ const ChatInput: FC = ({ } const handleKeyUp = (e: React.KeyboardEvent) => { - if (e.code === 'Enter') { + if (e.key === 'Enter') { e.preventDefault() // prevent send message when using input method enter if (!e.shiftKey && !isUseInputMethod.current) @@ -99,7 +100,7 @@ const ChatInput: FC = ({ const handleKeyDown = (e: React.KeyboardEvent) => { isUseInputMethod.current = e.nativeEvent.isComposing - if (e.code === 'Enter' && !e.shiftKey) { + if (e.key === 'Enter' && !e.shiftKey) { setQuery(query.replace(/\n$/, '')) e.preventDefault() } @@ -176,6 +177,7 @@ const ChatInput: FC = ({ ) }