Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Qdrant memory class #175

Merged
merged 7 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions docs/swarms/memory/qdrant.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Qdrant Client Library

## Overview

The Qdrant Client Library is designed for interacting with the Qdrant vector database, allowing efficient storage and retrieval of high-dimensional vector data. It integrates with machine learning models for embedding and is particularly suited for search and recommendation systems.

## Installation

```python
pip install qdrant-client sentence-transformers httpx
```

## Class Definition: Qdrant

```python
class Qdrant:
def __init__(self, api_key: str, host: str, port: int = 6333, collection_name: str = "qdrant", model_name: str = "BAAI/bge-small-en-v1.5", https: bool = True):
...
```

### Constructor Parameters

| Parameter | Type | Description | Default Value |
|-----------------|---------|--------------------------------------------------|-----------------------|
| api_key | str | API key for authentication. | - |
| host | str | Host address of the Qdrant server. | - |
| port | int | Port number for the Qdrant server. | 6333 |
| collection_name | str | Name of the collection to be used or created. | "qdrant" |
| model_name | str | Name of the sentence transformer model. | "BAAI/bge-small-en-v1.5" |
| https | bool | Flag to use HTTPS for connection. | True |

### Methods

#### `_load_embedding_model(model_name: str)`

Loads the sentence embedding model.

#### `_setup_collection()`

Checks if the specified collection exists in Qdrant; if not, creates it.

#### `add_vectors(docs: List[dict]) -> OperationResponse`

Adds vectors to the Qdrant collection.

#### `search_vectors(query: str, limit: int = 3) -> SearchResult`

Searches the Qdrant collection for vectors similar to the query vector.

## Usage Examples

### Example 1: Setting Up the Qdrant Client

```python
from qdrant_client import Qdrant

qdrant_client = Qdrant(api_key="your_api_key", host="localhost", port=6333)
```

### Example 2: Adding Vectors to a Collection

```python
documents = [
{"page_content": "Sample text 1"},
{"page_content": "Sample text 2"}
]

operation_info = qdrant_client.add_vectors(documents)
print(operation_info)
```

### Example 3: Searching for Vectors

```python
search_result = qdrant_client.search_vectors("Sample search query")
print(search_result)
```

## Further Information

Refer to the [Qdrant Documentation](https://qdrant.tech/docs) for more details on the Qdrant vector database.
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ nav:
- swarms.memory:
- PineconeVectorStoreStore: "swarms/memory/pinecone.md"
- PGVectorStore: "swarms/memory/pg.md"
- Qdrant: "swarms/memory/qdrant.md"
- Guides:
- Overview: "examples/index.md"
- Agents:
Expand Down
18 changes: 18 additions & 0 deletions playground/memory/qdrant/usage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from langchain.document_loaders import CSVLoader
from swarms.memory import qdrant

loader = CSVLoader(file_path="../document_parsing/aipg/aipg.csv", encoding='utf-8-sig')
docs = loader.load()


# Initialize the Qdrant instance
# See qdrant documentation on how to run locally
qdrant_client = qdrant.Qdrant(host ="https://697ea26c-2881-4e17-8af4-817fcb5862e8.europe-west3-0.gcp.cloud.qdrant.io", collection_name="qdrant", api_key="BhG2_yINqNU-aKovSEBadn69Zszhbo5uaqdJ6G_qDkdySjAljvuPqQ")
qdrant_client.add_vectors(docs)

# Perform a search
search_query = "Who is jojo"
search_results = qdrant_client.search_vectors(search_query)
print("Search Results:")
for result in search_results:
print(result)
110 changes: 107 additions & 3 deletions swarms/memory/qdrant.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,110 @@
"""
QDRANT MEMORY CLASS
from typing import List
from sentence_transformers import SentenceTransformer
from httpx import RequestError
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams, PointStruct

class Qdrant:
def __init__(self, api_key: str, host: str, port: int = 6333, collection_name: str = "qdrant", model_name: str = "BAAI/bge-small-en-v1.5", https: bool = True):
"""
Qdrant class for managing collections and performing vector operations using QdrantClient.

Attributes:
client (QdrantClient): The Qdrant client for interacting with the Qdrant server.
collection_name (str): Name of the collection to be managed in Qdrant.
model (SentenceTransformer): The model used for generating sentence embeddings.

"""
Args:
api_key (str): API key for authenticating with Qdrant.
host (str): Host address of the Qdrant server.
port (int): Port number of the Qdrant server. Defaults to 6333.
collection_name (str): Name of the collection to be used or created. Defaults to "qdrant".
model_name (str): Name of the model to be used for embeddings. Defaults to "BAAI/bge-small-en-v1.5".
https (bool): Flag to indicate if HTTPS should be used. Defaults to True.
"""
try:
self.client = QdrantClient(url=host, port=port, api_key=api_key)
self.collection_name = collection_name
self._load_embedding_model(model_name)
self._setup_collection()
except RequestError as e:
print(f"Error setting up QdrantClient: {e}")

def _load_embedding_model(self, model_name: str):
"""
Loads the sentence embedding model specified by the model name.

Args:
model_name (str): The name of the model to load for generating embeddings.
"""
try:
self.model = SentenceTransformer(model_name)
except Exception as e:
print(f"Error loading embedding model: {e}")

def _setup_collection(self):
try:
exists = self.client.get_collection(self.collection_name)
if exists:
print(f"Collection '{self.collection_name}' already exists.")
except Exception as e:
self.client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(size=self.model.get_sentence_embedding_dimension(), distance=Distance.DOT),
)
print(f"Collection '{self.collection_name}' created.")

def add_vectors(self, docs: List[dict]):
"""
Adds vector representations of documents to the Qdrant collection.

Args:
docs (List[dict]): A list of documents where each document is a dictionary with at least a 'page_content' key.

Returns:
OperationResponse or None: Returns the operation information if successful, otherwise None.
"""
points = []
for i, doc in enumerate(docs):
try:
if 'page_content' in doc:
embedding = self.model.encode(doc['page_content'], normalize_embeddings=True)
points.append(PointStruct(id=i + 1, vector=embedding, payload={"content": doc['page_content']}))
else:
print(f"Document at index {i} is missing 'page_content' key")
except Exception as e:
print(f"Error processing document at index {i}: {e}")

try:
operation_info = self.client.upsert(
collection_name=self.collection_name,
wait=True,
points=points,
)
return operation_info
except Exception as e:
print(f"Error adding vectors: {e}")
return None

def search_vectors(self, query: str, limit: int = 3):
"""
Searches the collection for vectors similar to the query vector.

Args:
query (str): The query string to be converted into a vector and used for searching.
limit (int): The number of search results to return. Defaults to 3.

Returns:
SearchResult or None: Returns the search results if successful, otherwise None.
"""
try:
query_vector = self.model.encode(query, normalize_embeddings=True)
search_result = self.client.search(
collection_name=self.collection_name,
query_vector=query_vector,
limit=limit
)
return search_result
except Exception as e:
print(f"Error searching vectors: {e}")
return None
40 changes: 40 additions & 0 deletions tests/memory/qdrant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pytest
from unittest.mock import Mock, patch

from swarms.memory.qdrant import Qdrant


@pytest.fixture
def mock_qdrant_client():
with patch('your_module.QdrantClient') as MockQdrantClient:
yield MockQdrantClient()

@pytest.fixture
def mock_sentence_transformer():
with patch('sentence_transformers.SentenceTransformer') as MockSentenceTransformer:
yield MockSentenceTransformer()

@pytest.fixture
def qdrant_client(mock_qdrant_client, mock_sentence_transformer):
client = Qdrant(api_key="your_api_key", host="your_host")
yield client

def test_qdrant_init(qdrant_client, mock_qdrant_client):
assert qdrant_client.client is not None

def test_load_embedding_model(qdrant_client, mock_sentence_transformer):
qdrant_client._load_embedding_model("model_name")
mock_sentence_transformer.assert_called_once_with("model_name")

def test_setup_collection(qdrant_client, mock_qdrant_client):
qdrant_client._setup_collection()
mock_qdrant_client.get_collection.assert_called_once_with(qdrant_client.collection_name)

def test_add_vectors(qdrant_client, mock_qdrant_client):
mock_doc = Mock(page_content="Sample text")
qdrant_client.add_vectors([mock_doc])
mock_qdrant_client.upsert.assert_called_once()

def test_search_vectors(qdrant_client, mock_qdrant_client):
qdrant_client.search_vectors("test query")
mock_qdrant_client.search.assert_called_once()
Loading