-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
232 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from typing import AsyncIterator, Iterator, List, Any | ||
|
||
from langchain_core.document_loaders import BaseLoader | ||
from langchain_core.documents import Document | ||
|
||
from pyensign.ensign import Ensign | ||
|
||
|
||
class EnsignLoader(BaseLoader): | ||
""" | ||
A loader that reads data from an Ensign topic. | ||
""" | ||
|
||
def __init__(self, topic: str, content_field: str = "", **kwargs): | ||
""" | ||
Initialize the loader with a topic name. Loading data from Ensign requires a | ||
topic name and Ensign credentials. Credentials can be passed in as keyword | ||
arguments or set as the ENSIGN_CLIENT_ID and ENSIGN_CLIENT_SECRET environment | ||
variables. | ||
Args: | ||
topic: The name of the topic to read from. | ||
""" | ||
|
||
if len(topic) == 0: | ||
raise ValueError("Topic name is required") | ||
|
||
self.topic = topic | ||
self.content_field = content_field | ||
self.ensign = Ensign(**kwargs) | ||
|
||
def _convert_to_document(self, event: Any) -> Document: | ||
""" | ||
Convert an Ensign event to a Document. | ||
""" | ||
|
||
return Document( | ||
page_content=event.decode_to_str(text_field=self.content_field), | ||
metadata=event.meta, | ||
) | ||
|
||
def load(self) -> Iterator[Document]: | ||
""" | ||
Loads all documents from Ensign into memory. | ||
Note: In production, this should be used with caution since it loads all | ||
documents into memory at once. | ||
""" | ||
|
||
raise NotImplementedError | ||
|
||
async def aload(self) -> List[Document]: | ||
""" | ||
Load all documents from Ensign into memory asynchronously. | ||
TODO: Prevent SQL injection with the topic name. | ||
""" | ||
|
||
cursor = await self.ensign.query("SELECT * FROM {}".format(self.topic)) | ||
events = await cursor.fetchall() | ||
return [self._convert_to_document(event) for event in events] | ||
|
||
def lazy_load(self) -> Iterator[Document]: | ||
""" | ||
Load documents from Ensign one by one lazily. | ||
""" | ||
|
||
raise NotImplementedError | ||
|
||
async def alazy_load( | ||
self, | ||
) -> AsyncIterator[Document]: # <-- Does not take any arguments | ||
""" | ||
Load documents from Ensign one by one lazily. This is an async generator. | ||
""" | ||
|
||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import pytest | ||
from asyncmock import patch | ||
|
||
from pyensign.events import Event | ||
from pyensign.ml.loader import EnsignLoader | ||
|
||
|
||
class TestEnsignLoader: | ||
""" | ||
Tests for the EnsignLoader class. | ||
""" | ||
|
||
@pytest.mark.parametrize( | ||
"topic, kwargs", | ||
[ | ||
("", {"client_id": "my_client_id", "client_secret": "my_client_secret"}), | ||
("otters", {"client_id": "my_client_id"}), | ||
], | ||
) | ||
def test_init_errors(self, topic, kwargs): | ||
""" | ||
Test that the loader raises the correct error when initialized with invalid | ||
arguments. | ||
""" | ||
|
||
with pytest.raises(ValueError): | ||
EnsignLoader(topic, **kwargs) | ||
|
||
@patch("pyensign.ensign.Ensign.query") | ||
@pytest.mark.parametrize( | ||
"events", | ||
[ | ||
([Event(b"Hello, world!", mimetype="text/plain", meta={"id": "23"})]), | ||
( | ||
[ | ||
Event( | ||
b'{"content": "hello"}', | ||
mimetype="application/json", | ||
meta={"page": "42"}, | ||
), | ||
Event( | ||
b"<h1>Hello, world!</h1>", | ||
mimetype="text/html", | ||
meta={"page": "23"}, | ||
), | ||
] | ||
), | ||
], | ||
) | ||
async def test_aload(self, mock_query, events): | ||
""" | ||
Test loading a batch of documents asynchronously from an Ensign topic | ||
""" | ||
|
||
loader = EnsignLoader( | ||
"otters", | ||
client_id="my_client_id", | ||
client_secret="my_client_secret", | ||
) | ||
mock_query.return_value.fetchall.return_value = events | ||
documents = await loader.aload() | ||
args, _ = mock_query.call_args | ||
assert args[0] == "SELECT * FROM otters" | ||
assert len(documents) == len(events) | ||
for i, document in enumerate(documents): | ||
assert document.page_content == events[i].data.decode() | ||
assert document.metadata == events[i].meta |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters