From e9d21a52b7736e98e6d83ce435065a34700aa6b7 Mon Sep 17 00:00:00 2001 From: Arslan Saleem Date: Thu, 28 Mar 2024 21:05:00 +0500 Subject: [PATCH] fix: sqlite connector override base equals (#1071) function for sqlite connector --- pandasai/connectors/sql.py | 15 +++++++++++++ tests/unit_tests/connectors/test_sqlite.py | 26 ++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/pandasai/connectors/sql.py b/pandasai/connectors/sql.py index f2fafbe8e..fe041e470 100644 --- a/pandasai/connectors/sql.py +++ b/pandasai/connectors/sql.py @@ -539,6 +539,21 @@ def __repr__(self): f"table={self.config.table}>" ) + def equals(self, other): + if isinstance(other, self.__class__): + print(self.config.database) + print(other.config.database) + return ( + self.config.dialect, + self.config.driver, + self.config.database, + ) == ( + other.config.dialect, + other.config.driver, + other.config.database, + ) + return False + class MySQLConnector(SQLConnector): """ diff --git a/tests/unit_tests/connectors/test_sqlite.py b/tests/unit_tests/connectors/test_sqlite.py index eff82745e..1ad38165f 100644 --- a/tests/unit_tests/connectors/test_sqlite.py +++ b/tests/unit_tests/connectors/test_sqlite.py @@ -83,3 +83,29 @@ def test_fallback_name_property(self): # Test fallback_name property fallback_name = self.connector.fallback_name self.assertEqual(fallback_name, "yourtable") + + @patch("pandasai.connectors.SqliteConnector._init_connection") + def test_two_connector_equal(self, mock_init_connection): + conn1 = SqliteConnector(self.config) + + conn2 = SqliteConnector(self.config) + + assert conn1.equals(conn2) + + config2 = SqliteConnectorConfig( + dialect="sqlite", database="path_todb.db", table="different_table" + ).dict() + conn3 = SqliteConnector(config2) + + assert conn1.equals(conn3) + + @patch("pandasai.connectors.SqliteConnector._init_connection") + def test_two_connector_not_equal(self, mock_init_connection): + conn1 = SqliteConnector(self.config) + + config2 = SqliteConnectorConfig( + dialect="sqlite", database="path_todb2.db", table="yourtable" + ).dict() + conn3 = SqliteConnector(config2) + + assert not conn1.equals(conn3)