diff --git a/ignore-words.txt b/ignore-words.txt index 877e68a84..c9e478871 100644 --- a/ignore-words.txt +++ b/ignore-words.txt @@ -1,2 +1,3 @@ # ignore-words.txt -selectin \ No newline at end of file +selectin +NotIn \ No newline at end of file diff --git a/pandasai/ee/vectorstores/lanceDB.py b/pandasai/ee/vectorstores/lanceDB.py index daf79b812..2c9047d20 100644 --- a/pandasai/ee/vectorstores/lanceDB.py +++ b/pandasai/ee/vectorstores/lanceDB.py @@ -153,11 +153,8 @@ def add_question_answer( else: data = {"id": ids, "qa": qa_str, "metadata": metadatas} - print("data: ", data) df = pd.DataFrame(data) - print("df: ", df) self._qa_table.add(df) - print("Len of table: ", self._qa_table.head()) return ids @@ -244,7 +241,6 @@ def update_question_answer( "qa": str(qa_str[i]), "metadata": metadatas[i], } - print("updated values: ", updated_values, ids[i]) self._qa_table.update(values=updated_values, where=f"id = '{ids[i]}'") return ids @@ -409,6 +405,8 @@ def _filter_docs_based_on_distance( Returns: _type_: _description_ """ + if not documents: + return documents relevant_column = list( documents[0].keys() - {"id", "vector", "metadata", "_distance"} ) diff --git a/tests/unit_tests/vectorstores/test_lancedb.py b/tests/unit_tests/vectorstores/test_lancedb.py index faa2ec6be..cc3810de6 100644 --- a/tests/unit_tests/vectorstores/test_lancedb.py +++ b/tests/unit_tests/vectorstores/test_lancedb.py @@ -1,3 +1,5 @@ +import os +import shutil import unittest from unittest.mock import MagicMock @@ -12,11 +14,11 @@ def setUp(self): self.vector_store._format_qa = MagicMock( side_effect=lambda q, c: f"Q: {q}\nA: {c}" ) - # self.vector_store._embedding_function = MagicMock( - # return_value=[[1.0, 2.0, 3.0]] * 2 - # ) - self.vector_store._qa_table - self.vector_store._docs_table + + def tearDown(self) -> None: + path = "/tmp/lancedb" + if os.path.exists(path): + shutil.rmtree(path) def test_constructor_default_parameters(self): self.assertEqual(self.vector_store._max_samples, 1) @@ -46,9 +48,9 @@ def test_add_question_answer_with_ids(self): inserted_ids = self.vector_store.add_question_answer( ["What is LanceDB?", "How does it work?"], ["print('Hello')", "for i in range(10): print(i)"], - ["test_id_1", "test_id_2"], + ["test_id_11", "test_id_12"], ) - assert inserted_ids == ["test_id_1", "test_id_2"] + assert inserted_ids == ["test_id_11", "test_id_12"] def test_add_question_answer_different_dimensions(self): with self.assertRaises(ValueError): @@ -92,17 +94,22 @@ def test_delete_docs(self): self.assertEqual(deleted_docs, True) def test_get_relevant_question_answers(self): + self.vector_store.add_question_answer( + ["What is LanceDB?", "How does it work?"], + ["print('Hello')", "for i in range(10): print(i)"], + ["test_id_11", "test_id_12"], + ) result = self.vector_store.get_relevant_question_answers( "What is LanceDB?", k=2 ) - print("result: ", result) + self.assertEqual( result, { "documents": [ [ "Q: What is LanceDB?\nA: print('Hello')", - "Q: What is LanceDB?\nA: print('Hello')", + "Q: How does it work?\nA: for i in range(10): print(i)", ] ], "metadatas": [["None", "None"]], @@ -110,26 +117,45 @@ def test_get_relevant_question_answers(self): ) def test_get_relevant_question_answers_by_ids(self): - result = self.vector_store.get_relevant_question_answers_by_id(["test_id_1"]) + self.vector_store.add_question_answer( + ["What is LanceDB?", "How does it work?"], + ["print('Hello')", "for i in range(10): print(i)"], + ["test_id_11", "test_id_12"], + ) + result = self.vector_store.get_relevant_question_answers_by_id(["test_id_11"]) + print(result) self.assertEqual( result, - [[{"metadata": "None", "qa": "Q: What is LanceDB?\nA: print('Hello')"}]], + [ + [ + { + "metadata": "None", + "qa": "Q: What is LanceDB?\nA: print('Hello')", + } + ] + ], ) def test_get_relevant_docs(self): + self.vector_store.add_docs( + ["Document 1", "Document 2", "Document 3"], + ["test_id_1", "test_id_2", "test_id_3"], + ) result = self.vector_store.get_relevant_docs("What is LanceDB?", k=3) - print("result:", result) self.assertEqual( result, { - "documents": [["Document 1", "Document 1", "Document 2"]], + "documents": [["Document 1", "Document 2", "Document 3"]], "metadatas": [["None", "None", "None"]], }, ) def test_get_relevant_docs_by_ids(self): + self.vector_store.add_docs( + ["Document 1", "Document 2", "Document 3"], + ["test_id_1", "test_id_2", "test_id_3"], + ) result = self.vector_store.get_relevant_docs_by_id(["test_id_1"]) - print("Result docs ids: ", result) self.assertEqual(result, [[{"doc": "Document 1", "metadata": "None"}]])