Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[DatabricksConnector]: connector to connect to Databricks on Cloud #580

Merged
merged 5 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix: Databricks connection issues
  • Loading branch information
ArslanSaleem committed Sep 20, 2023
commit 70c71ed2c6afbd19f1eb72467f37866e9ad58683
16 changes: 8 additions & 8 deletions examples/from_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,22 @@

databricks_connector = DatabricksConnector(
config={
"host": "ehxzojy-ue47135",
"database": "SNOWFLAKE_SAMPLE_DATA",
"token": "",
"host": "adb-*****.azuredatabricks.net",
"database": "default",
"token": "dapidfd412321",
"port": 443,
"table": "lineitem",
"httpPath": "tpch_sf1",
"table": "loan_payments_data",
"httpPath": "/sql/1.0/warehouses/213421312",
"where": [
# this is optional and filters the data to
# reduce the size of the dataframe
["l_quantity", ">", "49"]
["loan_status", "=", "PAIDOFF"],
],
}
)

llm = OpenAI(api_token="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")
llm = OpenAI("OPEN_API_KEY")
gventuri marked this conversation as resolved.
Show resolved Hide resolved
df = SmartDataframe(databricks_connector, config={"llm": llm})

response = df.chat("How many records has status 'F'?")
response = df.chat("How many people from the United states?")
print(response)
14 changes: 2 additions & 12 deletions pandasai/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,7 @@ def rows_count(self):
)

# Run a SQL query to get the number of rows
query = sql.text(
"SELECT COUNT(*) FROM information_schema.columns "
"WHERE table_name = :table_name"
).bindparams(table_name=self._config.table)
query = sql.text(f"SELECT COUNT(*) FROM {self._config.table}")

# Return the number of rows
self._rows_count = self._connection.execute(query).fetchone()[0]
Copy link
Contributor

@coderabbitai coderabbitai bot Sep 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new hunk simplifies the SQL query for row count, which is a good change. However, it assumes that the table name in the configuration does not contain any special characters or reserved words that might need to be escaped. This could lead to SQL injection vulnerabilities if the table name is user-supplied and not properly sanitized. Consider using bind parameters as in the old hunk to mitigate this risk.

- query = sql.text(f"SELECT COUNT(*) FROM {self._config.table}")
+ query = sql.text("SELECT COUNT(*) FROM :table_name").bindparams(table_name=self._config.table)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately this could create a potential attack. We are using bindparams to validate the table name. Also the suggestion from @coderabbitai might work!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gventuri: I'm glad you agree with the suggestion. Using bind parameters is indeed a good practice to prevent SQL injection attacks. It ensures that the table name is properly escaped, preventing malicious input from being executed as part of the SQL query. Here's how you can modify your code:

query = sql.text("SELECT COUNT(*) FROM :table_name").bindparams(table_name=self._config.table)

This way, SQLAlchemy will handle the proper escaping of the table name, making your code safer.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ArslanSaleem: You're correct. My apologies for the confusion. SQLAlchemy's bindparams is indeed used to bind parameters, not table names or column names.

To safely include a table name in a query, we need to use SQLAlchemy's Table construct which provides proper quoting around identifiers. Here's how you can modify your code:

from sqlalchemy import Table, select, MetaData

metadata = MetaData()
table = Table(self._config.table, metadata)
query = select([func.count()]).select_from(table)

self._rows_count = self._connection.execute(query).scalar()

This way, SQLAlchemy will handle the proper quoting of the table name, making your code safer.

Expand All @@ -307,14 +304,7 @@ def columns_count(self):
f"{self._config.dialect}"
)

# Run a SQL query to get the number of columns
query = sql.text(
"SELECT COUNT(*) FROM information_schema.columns "
f"WHERE table_name = '{self._config.table}'"
)

# Return the number of columns
self._columns_count = self._connection.execute(query).fetchone()[0]
self._columns_count = len(self.head().columns)
return self._columns_count
gventuri marked this conversation as resolved.
Show resolved Hide resolved
Comment on lines 304 to 308
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The columns_count method now uses the len function on the columns returned by the head method instead of running a SQL query. This change could improve performance by reducing the number of database queries. However, make sure that the head method always returns the correct and complete set of columns.


def _get_column_hash(self, include_additional_filters: bool = False):
Expand Down
392 changes: 234 additions & 158 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ coverage = "^7.2.7"
google-cloud-aiplatform = "^1.26.1"

[tool.poetry.extras]
connectors = ["pymysql", "psycopg2", "snowflake-sqlalchemy", "databricks-sql-connector"]
connectors = ["pymysql", "psycopg2", "snowflake-sqlalchemy", "sqlalchemy-databricks"]
google-ai = ["google-generativeai", "google-cloud-aiplatform"]
google-sheets = ["beautifulsoup4"]
excel = ["openpyxl"]
Expand Down
8 changes: 4 additions & 4 deletions tests/connectors/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ def test_rows_count_property(self):
def test_columns_count_property(self):
# Test columns_count property
self.connector._columns_count = None
self.mock_connection.execute.return_value.fetchone.return_value = (
8,
) # Sample columns count
mock_df = Mock()
mock_df.columns = ["Column1", "Column2"]
self.connector.head = Mock(return_value=mock_df)
columns_count = self.connector.columns_count
self.assertEqual(columns_count, 8)
self.assertEqual(columns_count, 2)

def test_column_hash_property(self):
# Test column_hash property
Expand Down
8 changes: 4 additions & 4 deletions tests/connectors/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ def test_rows_count_property(self):
def test_columns_count_property(self):
# Test columns_count property
self.connector._columns_count = None
self.mock_connection.execute.return_value.fetchone.return_value = (
8,
) # Sample columns count
mock_df = Mock()
mock_df.columns = ["Column1", "Column2"]
self.connector.head = Mock(return_value=mock_df)
columns_count = self.connector.columns_count
self.assertEqual(columns_count, 8)
self.assertEqual(columns_count, 2)
Comment on lines 82 to +89
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The method to test columns_count property has been changed. Previously, it was directly fetching the count from the database, but now it's using the head method to get a dataframe and then counting the columns. This change might affect the performance if the table has a large number of rows. Consider reverting back to the old method if performance is a concern.


def test_column_hash_property(self):
# Test column_hash property
Expand Down
8 changes: 4 additions & 4 deletions tests/connectors/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ def test_rows_count_property(self):
def test_columns_count_property(self):
# Test columns_count property
self.connector._columns_count = None
self.mock_connection.execute.return_value.fetchone.return_value = (
8,
) # Sample columns count
mock_df = Mock()
mock_df.columns = ["Column1", "Column2"]
self.connector.head = Mock(return_value=mock_df)
columns_count = self.connector.columns_count
self.assertEqual(columns_count, 8)
self.assertEqual(columns_count, 2)

def test_column_hash_property(self):
# Test column_hash property
Expand Down