forked from kyegomez/swarms
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request kyegomez#175 from kyegomez/memory
Added Qdrant memory class
- Loading branch information
Showing
5 changed files
with
247 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# Qdrant Client Library | ||
|
||
## Overview | ||
|
||
The Qdrant Client Library is designed for interacting with the Qdrant vector database, allowing efficient storage and retrieval of high-dimensional vector data. It integrates with machine learning models for embedding and is particularly suited for search and recommendation systems. | ||
|
||
## Installation | ||
|
||
```python | ||
pip install qdrant-client sentence-transformers httpx | ||
``` | ||
|
||
## Class Definition: Qdrant | ||
|
||
```python | ||
class Qdrant: | ||
def __init__(self, api_key: str, host: str, port: int = 6333, collection_name: str = "qdrant", model_name: str = "BAAI/bge-small-en-v1.5", https: bool = True): | ||
... | ||
``` | ||
|
||
### Constructor Parameters | ||
|
||
| Parameter | Type | Description | Default Value | | ||
|-----------------|---------|--------------------------------------------------|-----------------------| | ||
| api_key | str | API key for authentication. | - | | ||
| host | str | Host address of the Qdrant server. | - | | ||
| port | int | Port number for the Qdrant server. | 6333 | | ||
| collection_name | str | Name of the collection to be used or created. | "qdrant" | | ||
| model_name | str | Name of the sentence transformer model. | "BAAI/bge-small-en-v1.5" | | ||
| https | bool | Flag to use HTTPS for connection. | True | | ||
|
||
### Methods | ||
|
||
#### `_load_embedding_model(model_name: str)` | ||
|
||
Loads the sentence embedding model. | ||
|
||
#### `_setup_collection()` | ||
|
||
Checks if the specified collection exists in Qdrant; if not, creates it. | ||
|
||
#### `add_vectors(docs: List[dict]) -> OperationResponse` | ||
|
||
Adds vectors to the Qdrant collection. | ||
|
||
#### `search_vectors(query: str, limit: int = 3) -> SearchResult` | ||
|
||
Searches the Qdrant collection for vectors similar to the query vector. | ||
|
||
## Usage Examples | ||
|
||
### Example 1: Setting Up the Qdrant Client | ||
|
||
```python | ||
from qdrant_client import Qdrant | ||
|
||
qdrant_client = Qdrant(api_key="your_api_key", host="localhost", port=6333) | ||
``` | ||
|
||
### Example 2: Adding Vectors to a Collection | ||
|
||
```python | ||
documents = [ | ||
{"page_content": "Sample text 1"}, | ||
{"page_content": "Sample text 2"} | ||
] | ||
|
||
operation_info = qdrant_client.add_vectors(documents) | ||
print(operation_info) | ||
``` | ||
|
||
### Example 3: Searching for Vectors | ||
|
||
```python | ||
search_result = qdrant_client.search_vectors("Sample search query") | ||
print(search_result) | ||
``` | ||
|
||
## Further Information | ||
|
||
Refer to the [Qdrant Documentation](https://qdrant.tech/docs) for more details on the Qdrant vector database. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from langchain.document_loaders import CSVLoader | ||
from swarms.memory import qdrant | ||
|
||
loader = CSVLoader(file_path="../document_parsing/aipg/aipg.csv", encoding='utf-8-sig') | ||
docs = loader.load() | ||
|
||
|
||
# Initialize the Qdrant instance | ||
# See qdrant documentation on how to run locally | ||
qdrant_client = qdrant.Qdrant(host ="https://697ea26c-2881-4e17-8af4-817fcb5862e8.europe-west3-0.gcp.cloud.qdrant.io", collection_name="qdrant", api_key="BhG2_yINqNU-aKovSEBadn69Zszhbo5uaqdJ6G_qDkdySjAljvuPqQ") | ||
qdrant_client.add_vectors(docs) | ||
|
||
# Perform a search | ||
search_query = "Who is jojo" | ||
search_results = qdrant_client.search_vectors(search_query) | ||
print("Search Results:") | ||
for result in search_results: | ||
print(result) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,110 @@ | ||
""" | ||
QDRANT MEMORY CLASS | ||
from typing import List | ||
from sentence_transformers import SentenceTransformer | ||
from httpx import RequestError | ||
from qdrant_client import QdrantClient | ||
from qdrant_client.http.models import Distance, VectorParams, PointStruct | ||
|
||
class Qdrant: | ||
def __init__(self, api_key: str, host: str, port: int = 6333, collection_name: str = "qdrant", model_name: str = "BAAI/bge-small-en-v1.5", https: bool = True): | ||
""" | ||
Qdrant class for managing collections and performing vector operations using QdrantClient. | ||
Attributes: | ||
client (QdrantClient): The Qdrant client for interacting with the Qdrant server. | ||
collection_name (str): Name of the collection to be managed in Qdrant. | ||
model (SentenceTransformer): The model used for generating sentence embeddings. | ||
""" | ||
Args: | ||
api_key (str): API key for authenticating with Qdrant. | ||
host (str): Host address of the Qdrant server. | ||
port (int): Port number of the Qdrant server. Defaults to 6333. | ||
collection_name (str): Name of the collection to be used or created. Defaults to "qdrant". | ||
model_name (str): Name of the model to be used for embeddings. Defaults to "BAAI/bge-small-en-v1.5". | ||
https (bool): Flag to indicate if HTTPS should be used. Defaults to True. | ||
""" | ||
try: | ||
self.client = QdrantClient(url=host, port=port, api_key=api_key) | ||
self.collection_name = collection_name | ||
self._load_embedding_model(model_name) | ||
self._setup_collection() | ||
except RequestError as e: | ||
print(f"Error setting up QdrantClient: {e}") | ||
|
||
def _load_embedding_model(self, model_name: str): | ||
""" | ||
Loads the sentence embedding model specified by the model name. | ||
Args: | ||
model_name (str): The name of the model to load for generating embeddings. | ||
""" | ||
try: | ||
self.model = SentenceTransformer(model_name) | ||
except Exception as e: | ||
print(f"Error loading embedding model: {e}") | ||
|
||
def _setup_collection(self): | ||
try: | ||
exists = self.client.get_collection(self.collection_name) | ||
if exists: | ||
print(f"Collection '{self.collection_name}' already exists.") | ||
except Exception as e: | ||
self.client.create_collection( | ||
collection_name=self.collection_name, | ||
vectors_config=VectorParams(size=self.model.get_sentence_embedding_dimension(), distance=Distance.DOT), | ||
) | ||
print(f"Collection '{self.collection_name}' created.") | ||
|
||
def add_vectors(self, docs: List[dict]): | ||
""" | ||
Adds vector representations of documents to the Qdrant collection. | ||
Args: | ||
docs (List[dict]): A list of documents where each document is a dictionary with at least a 'page_content' key. | ||
Returns: | ||
OperationResponse or None: Returns the operation information if successful, otherwise None. | ||
""" | ||
points = [] | ||
for i, doc in enumerate(docs): | ||
try: | ||
if 'page_content' in doc: | ||
embedding = self.model.encode(doc['page_content'], normalize_embeddings=True) | ||
points.append(PointStruct(id=i + 1, vector=embedding, payload={"content": doc['page_content']})) | ||
else: | ||
print(f"Document at index {i} is missing 'page_content' key") | ||
except Exception as e: | ||
print(f"Error processing document at index {i}: {e}") | ||
|
||
try: | ||
operation_info = self.client.upsert( | ||
collection_name=self.collection_name, | ||
wait=True, | ||
points=points, | ||
) | ||
return operation_info | ||
except Exception as e: | ||
print(f"Error adding vectors: {e}") | ||
return None | ||
|
||
def search_vectors(self, query: str, limit: int = 3): | ||
""" | ||
Searches the collection for vectors similar to the query vector. | ||
Args: | ||
query (str): The query string to be converted into a vector and used for searching. | ||
limit (int): The number of search results to return. Defaults to 3. | ||
Returns: | ||
SearchResult or None: Returns the search results if successful, otherwise None. | ||
""" | ||
try: | ||
query_vector = self.model.encode(query, normalize_embeddings=True) | ||
search_result = self.client.search( | ||
collection_name=self.collection_name, | ||
query_vector=query_vector, | ||
limit=limit | ||
) | ||
return search_result | ||
except Exception as e: | ||
print(f"Error searching vectors: {e}") | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import pytest | ||
from unittest.mock import Mock, patch | ||
|
||
from swarms.memory.qdrant import Qdrant | ||
|
||
|
||
@pytest.fixture | ||
def mock_qdrant_client(): | ||
with patch('your_module.QdrantClient') as MockQdrantClient: | ||
yield MockQdrantClient() | ||
|
||
@pytest.fixture | ||
def mock_sentence_transformer(): | ||
with patch('sentence_transformers.SentenceTransformer') as MockSentenceTransformer: | ||
yield MockSentenceTransformer() | ||
|
||
@pytest.fixture | ||
def qdrant_client(mock_qdrant_client, mock_sentence_transformer): | ||
client = Qdrant(api_key="your_api_key", host="your_host") | ||
yield client | ||
|
||
def test_qdrant_init(qdrant_client, mock_qdrant_client): | ||
assert qdrant_client.client is not None | ||
|
||
def test_load_embedding_model(qdrant_client, mock_sentence_transformer): | ||
qdrant_client._load_embedding_model("model_name") | ||
mock_sentence_transformer.assert_called_once_with("model_name") | ||
|
||
def test_setup_collection(qdrant_client, mock_qdrant_client): | ||
qdrant_client._setup_collection() | ||
mock_qdrant_client.get_collection.assert_called_once_with(qdrant_client.collection_name) | ||
|
||
def test_add_vectors(qdrant_client, mock_qdrant_client): | ||
mock_doc = Mock(page_content="Sample text") | ||
qdrant_client.add_vectors([mock_doc]) | ||
mock_qdrant_client.upsert.assert_called_once() | ||
|
||
def test_search_vectors(qdrant_client, mock_qdrant_client): | ||
qdrant_client.search_vectors("test query") | ||
mock_qdrant_client.search.assert_called_once() |