Skip to content

Commit

Permalink
Add langchain document loader
Browse files Browse the repository at this point in the history
  • Loading branch information
pdeziel committed Jun 7, 2024
1 parent 16cfdaf commit 63aa96c
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 3 deletions.
28 changes: 28 additions & 0 deletions pyensign/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,34 @@ def decode(self, decoder=None):
else:
return self.data

def decode_to_str(self, text_field=None, check_valid=False):
"""
Decode the event data into a string if possible. If the mimetype is not
supported for string decoding, a ValueError is raised.
Parameters
----------
text_field : str (optional)
The field to extract from the decoded object. If not provided, the entire
decoded object is returned.
check_valid : bool (optional, default: False)
If True, validate that the event data can be decoded according to the
mimetype (e.g. to check that the payload is valid JSON).
"""

if check_valid or text_field:
obj = self.decode()

if self.mimetype == mtype.ApplicationJSON and text_field:
if text_field not in obj:
raise ValueError("text field not found in decoded object")
val = obj[text_field]
if not isinstance(val, str):
raise ValueError("text field does not have a string value")
return val
else:
return self.data.decode()

def __repr__(self):
repr = "Event("
repr += "data={}, ".format(self.data)
Expand Down
75 changes: 75 additions & 0 deletions pyensign/ml/loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import AsyncIterator, Iterator, List, Any

from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document

from pyensign.ensign import Ensign


class EnsignLoader(BaseLoader):
"""
A loader that reads data from an Ensign topic.
"""

def __init__(self, topic: str, content_field: str = "", **kwargs):
"""
Initialize the loader with a topic name. Loading data from Ensign requires a
topic name and Ensign credentials. Credentials can be passed in as keyword
arguments or set as the ENSIGN_CLIENT_ID and ENSIGN_CLIENT_SECRET environment
variables.
Args:
topic: The name of the topic to read from.
"""

if len(topic) == 0:
raise ValueError("Topic name is required")

self.topic = topic
self.content_field = content_field
self.ensign = Ensign(**kwargs)

def _convert_to_document(self, event: Any) -> Document:
"""
Convert an Ensign event to a Document.
"""

return Document(
page_content=event.decode_to_str(text_field=self.content_field),
metadata=event.meta,
)

def load(self) -> Iterator[Document]:
"""
Loads all documents from Ensign into memory.
Note: In production, this should be used with caution since it loads all
documents into memory at once.
"""

raise NotImplementedError

async def aload(self) -> List[Document]:
"""
Load all documents from Ensign into memory asynchronously.
TODO: Prevent SQL injection with the topic name.
"""

cursor = await self.ensign.query("SELECT * FROM {}".format(self.topic))
events = await cursor.fetchall()
return [self._convert_to_document(event) for event in events]

def lazy_load(self) -> Iterator[Document]:
"""
Load documents from Ensign one by one lazily.
"""

raise NotImplementedError

async def alazy_load(
self,
) -> AsyncIterator[Document]: # <-- Does not take any arguments
"""
Load documents from Ensign one by one lazily. This is an async generator.
"""

raise NotImplementedError
4 changes: 1 addition & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,7 @@ def get_description_type(path=PKG_DESCRIBE):
"download_url": "{}/tarball/v{}".format(REPOSITORY, get_version()),
"packages": find_packages(where=PROJECT, exclude=EXCLUDES),
"install_requires": list(get_requires()),
"extras_require": {
"ml": ["pandas==2.1.2"],
},
"extras_require": {"ml": ["pandas==2.1.2", "langchain==0.2.3"]},
"classifiers": CLASSIFIERS,
"keywords": KEYWORDS,
"zip_safe": False,
Expand Down
67 changes: 67 additions & 0 deletions tests/pyensign/ml/test_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import pytest
from asyncmock import patch

from pyensign.events import Event
from pyensign.ml.loader import EnsignLoader


class TestEnsignLoader:
"""
Tests for the EnsignLoader class.
"""

@pytest.mark.parametrize(
"topic, kwargs",
[
("", {"client_id": "my_client_id", "client_secret": "my_client_secret"}),
("otters", {"client_id": "my_client_id"}),
],
)
def test_init_errors(self, topic, kwargs):
"""
Test that the loader raises the correct error when initialized with invalid
arguments.
"""

with pytest.raises(ValueError):
EnsignLoader(topic, **kwargs)

@patch("pyensign.ensign.Ensign.query")
@pytest.mark.parametrize(
"events",
[
([Event(b"Hello, world!", mimetype="text/plain", meta={"id": "23"})]),
(
[
Event(
b'{"content": "hello"}',
mimetype="application/json",
meta={"page": "42"},
),
Event(
b"<h1>Hello, world!</h1>",
mimetype="text/html",
meta={"page": "23"},
),
]
),
],
)
async def test_aload(self, mock_query, events):
"""
Test loading a batch of documents asynchronously from an Ensign topic
"""

loader = EnsignLoader(
"otters",
client_id="my_client_id",
client_secret="my_client_secret",
)
mock_query.return_value.fetchall.return_value = events
documents = await loader.aload()
args, _ = mock_query.call_args
assert args[0] == "SELECT * FROM otters"
assert len(documents) == len(events)
for i, document in enumerate(documents):
assert document.page_content == events[i].data.decode()
assert document.metadata == events[i].meta
61 changes: 61 additions & 0 deletions tests/pyensign/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,67 @@ def test_event(
created=event.created,
)

@pytest.mark.parametrize(
"event, text_field, expected",
[
(Event(data=b"test", mimetype=mt.TextPlain), None, "test"),
(
Event(data=b'{"document": "hello world"}', mimetype=mt.ApplicationJSON),
"document",
"hello world",
),
(
Event(data=b'{"document": "hello world"}', mimetype=mt.ApplicationJSON),
None,
'{"document": "hello world"}',
),
(
Event(data=b"<h1>Hello World</h1>", mimetype=mt.TextHTML),
None,
"<h1>Hello World</h1>",
),
],
)
def test_decode_to_str(self, event, text_field, expected):
"""
Test that the decode_to_str method always returns a string and handles any
mimetype.
"""

assert event.decode_to_str(text_field=text_field) == expected

@pytest.mark.parametrize(
"event, text_field, check_valid, exception",
[
(
Event(data=b'{"document": "hello world"}', mimetype=mt.ApplicationJSON),
"does_not_exist",
False,
ValueError,
),
(
Event(data=b'{"document": 42}', mimetype=mt.ApplicationJSON),
"document",
False,
ValueError,
),
(
Event(data=b"{document: ", mimetype=mt.ApplicationJSON),
None,
True,
ValueError,
),
],
)
def test_decode_to_str_error(self, event, text_field, check_valid, exception):
"""
Test that the decode_to_str method raises the correct error when the event
payload cannot be decoded to a string.
"""

with pytest.raises(exception):
event.decode_to_str(text_field=text_field, check_valid=check_valid)

def test_modify_event(self):
"""
Ensure events can be modified after creation and updates are present in the
Expand Down

0 comments on commit 63aa96c

Please sign in to comment.