From 1d31b3272376046525ed2c29a91808a36d97950d Mon Sep 17 00:00:00 2001
From: Wang Ying
Date: Sun, 15 Sep 2024 16:00:26 +0800
Subject: [PATCH] add KnowledgeBaseClient and test cases.

---
 sdks/python-client/dify_client/client.py | 281 ++++++++++++++++++++++-
 sdks/python-client/setup.py              |   2 +-
 sdks/python-client/tests/test_client.py  | 149 +++++++++++-
 3 files changed, 429 insertions(+), 3 deletions(-)

diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py
index b6b0ced2ce7628..2be079bdf381ce 100644
--- a/sdks/python-client/dify_client/client.py
+++ b/sdks/python-client/dify_client/client.py
@@ -1,3 +1,4 @@
+import json
 import requests
 
 
@@ -133,4 +134,282 @@ def stop(self, task_id, user):
         return self._send_request("POST", f"/workflows/tasks/{task_id}/stop", data)
 
     def get_result(self, workflow_run_id):
-        return self._send_request("GET", f"/workflows/run/{workflow_run_id}")
\ No newline at end of file
+        return self._send_request("GET", f"/workflows/run/{workflow_run_id}")
+
+
+
+class KnowledgeBaseClient(DifyClient):
+
+    def __init__(self, api_key, base_url: str = 'https://api.dify.ai/v1', dataset_id: str = None):
+        """
+        Construct a KnowledgeBaseClient object.
+
+        Args:
+            api_key (str): API key of Dify.
+            base_url (str, optional): Base URL of Dify API. Defaults to 'https://api.dify.ai/v1'.
+            dataset_id (str, optional): ID of the dataset. Defaults to None. It is not needed to create a new
+                dataset or to list datasets; every other method raises ValueError while it is unset.
+        """
+        super().__init__(
+            api_key=api_key,
+            base_url=base_url
+        )
+        self.dataset_id = dataset_id
+
+    def _get_dataset_id(self):
+        if self.dataset_id is None:
+            raise ValueError("dataset_id is not set")
+        return self.dataset_id
+
+    def create_dataset(self, name: str, **kwargs):
+        return self._send_request('POST', '/datasets', {'name': name}, **kwargs)
+
+    def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs):
+        return self._send_request('GET', f'/datasets?page={page}&limit={page_size}', **kwargs)
+
+    def create_document_by_text(self, name, text, extra_params: dict = None, **kwargs):
+        """
+        Create a document by text.
+
+        :param name: Name of the document
+        :param text: Text content of the document
+        :param extra_params: extra parameters passed to the API, such as indexing_technique, process_rule. (optional)
+            e.g.
+            {
+                'indexing_technique': 'high_quality',
+                'process_rule': {
+                    'rules': {
+                        'pre_processing_rules': [
+                            {'id': 'remove_extra_spaces', 'enabled': True},
+                            {'id': 'remove_urls_emails', 'enabled': True}
+                        ],
+                        'segmentation': {
+                            'separator': '\n',
+                            'max_tokens': 500
+                        }
+                    },
+                    'mode': 'custom'
+                }
+            }
+        :return: Response from the API
+        """
+        data = {
+            'indexing_technique': 'high_quality',
+            'process_rule': {
+                'mode': 'automatic'
+            },
+            'name': name,
+            'text': text
+        }
+        if extra_params is not None and isinstance(extra_params, dict):
+            data.update(extra_params)
+        url = f"/datasets/{self._get_dataset_id()}/document/create_by_text"
+        return self._send_request("POST", url, json=data, **kwargs)
+
+    def update_document_by_text(self, document_id, name, text, extra_params: dict = None, **kwargs):
+        """
+        Update a document by text.
+
+        :param document_id: ID of the document
+        :param name: Name of the document
+        :param text: Text content of the document
+        :param extra_params: extra parameters passed to the API, such as indexing_technique, process_rule. (optional)
+            e.g.
+            {
+                'indexing_technique': 'high_quality',
+                'process_rule': {
+                    'rules': {
+                        'pre_processing_rules': [
+                            {'id': 'remove_extra_spaces', 'enabled': True},
+                            {'id': 'remove_urls_emails', 'enabled': True}
+                        ],
+                        'segmentation': {
+                            'separator': '\n',
+                            'max_tokens': 500
+                        }
+                    },
+                    'mode': 'custom'
+                }
+            }
+        :return: Response from the API
+        """
+        data = {
+            'name': name,
+            'text': text
+        }
+        if extra_params is not None and isinstance(extra_params, dict):
+            data.update(extra_params)
+        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text"
+        return self._send_request("POST", url, json=data, **kwargs)
+
+    def create_document_by_file(self, file_path, original_document_id=None, extra_params: dict = None):
+        """
+        Create a document by file.
+
+        :param file_path: Path to the file
+        :param original_document_id: pass this ID if you want to replace the original document (optional)
+        :param extra_params: extra parameters passed to the API, such as indexing_technique, process_rule. (optional)
+            e.g.
+            {
+                'indexing_technique': 'high_quality',
+                'process_rule': {
+                    'rules': {
+                        'pre_processing_rules': [
+                            {'id': 'remove_extra_spaces', 'enabled': True},
+                            {'id': 'remove_urls_emails', 'enabled': True}
+                        ],
+                        'segmentation': {
+                            'separator': '\n',
+                            'max_tokens': 500
+                        }
+                    },
+                    'mode': 'custom'
+                }
+            }
+        :return: Response from the API
+        """
+        data = {
+            'process_rule': {
+                'mode': 'automatic'
+            },
+            'indexing_technique': 'high_quality'
+        }
+        if extra_params is not None and isinstance(extra_params, dict):
+            data.update(extra_params)
+        if original_document_id is not None:
+            data['original_document_id'] = original_document_id
+        url = f"/datasets/{self._get_dataset_id()}/document/create_by_file"
+        with open(file_path, "rb") as file_obj:  # opened last; closed as soon as the upload call returns
+            return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, {"file": file_obj})
+
+    def update_document_by_file(self, document_id, file_path, extra_params: dict = None):
+        """
+        Update a document by file.
+
+        :param document_id: ID of the document
+        :param file_path: Path to the file
+        :param extra_params: extra parameters passed to the API, such as indexing_technique, process_rule. (optional)
+            e.g.
+            {
+                'indexing_technique': 'high_quality',
+                'process_rule': {
+                    'rules': {
+                        'pre_processing_rules': [
+                            {'id': 'remove_extra_spaces', 'enabled': True},
+                            {'id': 'remove_urls_emails', 'enabled': True}
+                        ],
+                        'segmentation': {
+                            'separator': '\n',
+                            'max_tokens': 500
+                        }
+                    },
+                    'mode': 'custom'
+                }
+            }
+        :return: Response from the API
+        """
+        data = {}
+        if extra_params is not None and isinstance(extra_params, dict):
+            data.update(extra_params)
+        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file"
+        with open(file_path, "rb") as file_obj:  # opened last; closed as soon as the upload call returns
+            return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, {"file": file_obj})
+
+    def batch_indexing_status(self, batch_id: str, **kwargs):
+        """
+        Get the status of the batch indexing.
+
+        :param batch_id: ID of the batch uploading
+        :return: Response from the API
+        """
+        url = f"/datasets/{self._get_dataset_id()}/documents/{batch_id}/indexing-status"
+        return self._send_request("GET", url, **kwargs)
+
+    def delete_dataset(self):
+        """
+        Delete this dataset.
+
+        :return: Response from the API
+        """
+        url = f"/datasets/{self._get_dataset_id()}"
+        return self._send_request("DELETE", url)
+
+    def delete_document(self, document_id):
+        """
+        Delete a document.
+
+        :param document_id: ID of the document
+        :return: Response from the API
+        """
+        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}"
+        return self._send_request("DELETE", url)
+
+    def list_documents(self, page: int = None, page_size: int = None, keyword: str = None, **kwargs):
+        """
+        Get a list of documents in this dataset.
+
+        :return: Response from the API
+        """
+        params = {}
+        if page is not None:
+            params['page'] = page
+        if page_size is not None:
+            params['limit'] = page_size
+        if keyword is not None:
+            params['keyword'] = keyword
+        url = f"/datasets/{self._get_dataset_id()}/documents"
+        return self._send_request("GET", url, params=params, **kwargs)
+
+    def add_segments(self, document_id, segments, **kwargs):
+        """
+        Add segments to a document.
+
+        :param document_id: ID of the document
+        :param segments: List of segments to add, example: [{"content": "1", "answer": "1", "keyword": ["a"]}]
+        :return: Response from the API
+        """
+        data = {"segments": segments}
+        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments"
+        return self._send_request("POST", url, json=data, **kwargs)
+
+    def query_segments(self, document_id, keyword: str = None, status: str = None, **kwargs):
+        """
+        Query segments in this document.
+
+        :param document_id: ID of the document
+        :param keyword: query keyword, optional
+        :param status: status of the segment, optional, e.g. completed
+        """
+        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments"
+        params = {}
+        if keyword is not None:
+            params['keyword'] = keyword
+        if status is not None:
+            params['status'] = status
+        if "params" in kwargs:
+            params.update(kwargs.pop("params"))  # pop, or 'params' would be passed twice below (TypeError)
+        return self._send_request("GET", url, params=params, **kwargs)
+
+    def delete_document_segment(self, document_id, segment_id):
+        """
+        Delete a segment from a document.
+
+        :param document_id: ID of the document
+        :param segment_id: ID of the segment
+        :return: Response from the API
+        """
+        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}"
+        return self._send_request("DELETE", url)
+
+    def update_document_segment(self, document_id, segment_id, segment_data, **kwargs):
+        """
+        Update a segment in a document.
+
+        :param document_id: ID of the document
+        :param segment_id: ID of the segment
+        :param segment_data: Data of the segment, example: {"content": "1", "answer": "1", "keyword": ["a"], "enabled": True}
+        :return: Response from the API
+        """
+        data = {"segment": segment_data}
+        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}"
+        return self._send_request("POST", url, json=data, **kwargs)
diff --git a/sdks/python-client/setup.py b/sdks/python-client/setup.py
index e7253f7391823f..bb8ca46d97663b 100644
--- a/sdks/python-client/setup.py
+++ b/sdks/python-client/setup.py
@@ -5,7 +5,7 @@
 
 setup(
     name="dify-client",
-    version="0.1.11",
+    version="0.1.12",
     author="Dify",
     author_email="hello@dify.ai",
     description="A package for interacting with the Dify Service-API",
diff --git a/sdks/python-client/tests/test_client.py b/sdks/python-client/tests/test_client.py
index 5259d082cafa97..301e733b6b008a 100644
--- a/sdks/python-client/tests/test_client.py
+++ b/sdks/python-client/tests/test_client.py
@@ -1,10 +1,157 @@
 import os
+import time
 import unittest
 
-from dify_client.client import ChatClient, CompletionClient, DifyClient
+from dify_client.client import ChatClient, CompletionClient, DifyClient, KnowledgeBaseClient
 
 API_KEY = os.environ.get("API_KEY")
 APP_ID = os.environ.get("APP_ID")
+API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.dify.ai/v1")
+FILE_PATH_BASE = os.path.dirname(__file__)
+
+
+class TestKnowledgeBaseClient(unittest.TestCase):
+    def setUp(self):
+        self.knowledge_base_client = KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL)
+        self.README_FILE_PATH = os.path.abspath(os.path.join(FILE_PATH_BASE, "../README.md"))
+        self.dataset_id = None
+        self.document_id = None
+        self.segment_id = None
+        self.batch_id = None
+
+    def _get_dataset_kb_client(self):
+        self.assertIsNotNone(self.dataset_id)
+        return KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL, dataset_id=self.dataset_id)
+
+    def test_001_create_dataset(self):
+        response = self.knowledge_base_client.create_dataset(name="test_dataset")
+        data = response.json()
+        self.assertIn("id", data)
+        self.dataset_id = data["id"]
+        self.assertEqual("test_dataset", data["name"])
+
+        # the following tests require to be executed in order because they use
+        # the dataset/document/segment ids from the previous test
+        self._test_002_list_datasets()
+        self._test_003_create_document_by_text()
+        time.sleep(1)
+        self._test_004_update_document_by_text()
+        # self._test_005_batch_indexing_status()
+        time.sleep(1)
+        self._test_006_update_document_by_file()
+        time.sleep(1)
+        self._test_007_list_documents()
+        self._test_008_delete_document()
+        self._test_009_create_document_by_file()
+        time.sleep(1)
+        self._test_010_add_segments()
+        self._test_011_query_segments()
+        self._test_012_update_document_segment()
+        self._test_013_delete_document_segment()
+        self._test_014_delete_dataset()
+
+    def _test_002_list_datasets(self):
+        response = self.knowledge_base_client.list_datasets()
+        data = response.json()
+        self.assertIn("data", data)
+        self.assertIn("total", data)
+
+    def _test_003_create_document_by_text(self):
+        client = self._get_dataset_kb_client()
+        response = client.create_document_by_text("test_document", "test_text")
+        data = response.json()
+        self.assertIn("document", data)
+        self.document_id = data["document"]["id"]
+        self.batch_id = data["batch"]
+
+    def _test_004_update_document_by_text(self):
+        client = self._get_dataset_kb_client()
+        self.assertIsNotNone(self.document_id)
+        response = client.update_document_by_text(self.document_id, "test_document_updated", "test_text_updated")
+        data = response.json()
+        self.assertIn("document", data)
+        self.assertIn("batch", data)
+        self.batch_id = data["batch"]
+
+    def _test_005_batch_indexing_status(self):
+        client = self._get_dataset_kb_client()
+        response = client.batch_indexing_status(self.batch_id)
+        data = response.json()
+        self.assertEqual(response.status_code, 200)
+
+    def _test_006_update_document_by_file(self):
+        client = self._get_dataset_kb_client()
+        self.assertIsNotNone(self.document_id)
+        response = client.update_document_by_file(self.document_id, self.README_FILE_PATH)
+        data = response.json()
+        self.assertIn("document", data)
+        self.assertIn("batch", data)
+        self.batch_id = data["batch"]
+
+    def _test_007_list_documents(self):
+        client = self._get_dataset_kb_client()
+        response = client.list_documents()
+        data = response.json()
+        self.assertIn("data", data)
+
+    def _test_008_delete_document(self):
+        client = self._get_dataset_kb_client()
+        self.assertIsNotNone(self.document_id)
+        response = client.delete_document(self.document_id)
+        data = response.json()
+        self.assertIn("result", data)
+        self.assertEqual("success", data["result"])
+
+    def _test_009_create_document_by_file(self):
+        client = self._get_dataset_kb_client()
+        response = client.create_document_by_file(self.README_FILE_PATH)
+        data = response.json()
+        self.assertIn("document", data)
+        self.document_id = data["document"]["id"]
+        self.batch_id = data["batch"]
+
+    def _test_010_add_segments(self):
+        client = self._get_dataset_kb_client()
+        response = client.add_segments(self.document_id, [
+            {"content": "test text segment 1"}
+        ])
+        data = response.json()
+        self.assertIn("data", data)
+        self.assertGreater(len(data["data"]), 0)
+        segment = data["data"][0]
+        self.segment_id = segment["id"]
+
+    def _test_011_query_segments(self):
+        client = self._get_dataset_kb_client()
+        response = client.query_segments(self.document_id)
+        data = response.json()
+        self.assertIn("data", data)
+        self.assertGreater(len(data["data"]), 0)
+
+    def _test_012_update_document_segment(self):
+        client = self._get_dataset_kb_client()
+        self.assertIsNotNone(self.segment_id)
+        response = client.update_document_segment(self.document_id, self.segment_id,
+                                                  {"content": "test text segment 1 updated"}
+                                                  )
+        data = response.json()
+        self.assertIn("data", data)
+        self.assertGreater(len(data["data"]), 0)
+        segment = data["data"]
+        self.assertEqual("test text segment 1 updated", segment["content"])
+
+    def _test_013_delete_document_segment(self):
+        client = self._get_dataset_kb_client()
+        self.assertIsNotNone(self.segment_id)
+        response = client.delete_document_segment(self.document_id, self.segment_id)
+        data = response.json()
+        self.assertIn("result", data)
+        self.assertEqual("success", data["result"])
+
+    def _test_014_delete_dataset(self):
+        client = self._get_dataset_kb_client()
+        response = client.delete_dataset()
+        self.assertEqual(204, response.status_code)
 
 
 class TestChatClient(unittest.TestCase):