feat(directSqlPrompt): use connector directly if flag is set #731
Changes from all commits
d20204d
742b1b6
d3e896c
5537a7e
451a843
f65cc22
d7a7cc7
593d283
b36f5fb
36da80c
976bf4e
856c9fe
c62bcbb
8409742
a02f332
877ea38
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
"""Example of using PandasAI with PostgreSQL connectors and direct SQL.""" | ||
|
||
from pandasai import SmartDatalake | ||
from pandasai.llm import OpenAI | ||
from pandasai.connectors import PostgreSQLConnector | ||
from pandasai.smart_dataframe import SmartDataframe | ||
|
||
|
||
# With a PostgreSQL database | ||
order = PostgreSQLConnector( | ||
config={ | ||
"host": "localhost", | ||
"port": 5432, | ||
"database": "testdb", | ||
"username": "postgres", | ||
"password": "123456", | ||
"table": "orders", | ||
} | ||
) | ||
|
||
order_details = PostgreSQLConnector( | ||
config={ | ||
"host": "localhost", | ||
"port": 5432, | ||
"database": "testdb", | ||
"username": "postgres", | ||
"password": "123456", | ||
"table": "order_details", | ||
} | ||
) | ||
|
||
products = PostgreSQLConnector( | ||
config={ | ||
"host": "localhost", | ||
"port": 5432, | ||
"database": "testdb", | ||
"username": "postgres", | ||
"password": "123456", | ||
"table": "products", | ||
} | ||
) | ||
|
||
|
||
llm = OpenAI("OPENAI_API_KEY") | ||
|
||
|
||
order_details_smart_df = SmartDataframe( | ||
order_details, | ||
config={"llm": llm, "direct_sql": True}, | ||
description="Contains user order details", | ||
) | ||
|
||
|
||
df = SmartDatalake( | ||
[order_details_smart_df, order, products], | ||
config={"llm": llm, "direct_sql": True}, | ||
) | ||
response = df.chat("return orders with count of distinct products") | ||
print(response) |
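
The three connector configs in this example differ only in their `table` value, and the password is hard-coded. A minimal sketch of one way to factor that out, assuming a hypothetical helper `make_pg_config` and a hypothetical `EXAMPLE_PG_PASSWORD` environment variable (neither is part of PandasAI or of this PR):

```python
import os


def make_pg_config(table: str) -> dict:
    """Build the connector config dict for one table (hypothetical helper)."""
    return {
        "host": "localhost",
        "port": 5432,
        "database": "testdb",
        "username": "postgres",
        # Assumption: the password is read from the environment, not from source.
        "password": os.environ.get("EXAMPLE_PG_PASSWORD", ""),
        "table": table,
    }


# One config per table used in the example.
configs = {
    name: make_pg_config(name)
    for name in ("orders", "order_details", "products")
}
```

Each dict could then be passed the same way the example does, for instance `PostgreSQLConnector(config=configs["orders"])`.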
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
Analyze the data, using the provided dataframes (`dfs`). | ||
1. Prepare: Preprocessing and cleaning data if necessary | ||
2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) | ||
3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) | ||
{viz_library_type} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
You are provided with the following samples of sql tables data: | ||
|
||
<Tables> | ||
{tables} | ||
</Tables> | ||
|
||
<conversation> | ||
{conversation} | ||
</conversation> | ||
|
||
You are provided with the following function that executes the sql query, | ||
<Function> | ||
def execute_sql_query(sql_query: str) -> pd.DataFrame: | ||
"""This method connects to the database, executes the sql query and returns the dataframe""" | ||
</Function> | ||
|
||
This is the initial python function. Do not change the params. | ||
|
||
```python | ||
# TODO import all the dependencies required | ||
import pandas as pd | ||
|
||
def analyze_data() -> dict: | ||
""" | ||
Analyze the data, using the provided dataframes (`dfs`). | ||
1. Prepare: generate sql query to get data for analysis (grouping, filtering, aggregating, etc.) | ||
2. Process: execute the query using execute method available to you which returns dataframe | ||
3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) | ||
{viz_library_type} | ||
At the end, return a dictionary of: | ||
{output_type_hint} | ||
""" | ||
``` | ||
|
||
Take a deep breath and reason step-by-step. Act as a senior data analyst. | ||
In the answer, you must never write the "technical" names of the tables. | ||
Based on the last message in the conversation: | ||
|
||
- return the updated analyze_data function wrapped within `python ` |
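
To make the three steps in the template concrete, here is a rough sketch of the kind of `analyze_data` function a model might return for the question "return orders with count of distinct products". Everything around it is an assumption for illustration: the `order_details` schema is hypothetical, `execute_sql_query` is stubbed with an in-memory SQLite database and returns plain rows rather than a DataFrame, and the returned dictionary shape is a placeholder because the real one is dictated by `{output_type_hint}`.

```python
import sqlite3

# Stand-in database with a hypothetical order_details schema.
_conn = sqlite3.connect(":memory:")
_conn.execute("CREATE TABLE order_details (order_id INTEGER, product_id INTEGER)")
_conn.executemany(
    "INSERT INTO order_details VALUES (?, ?)",
    [(1, 10), (1, 11), (1, 10), (2, 12)],
)


def execute_sql_query(sql_query: str) -> list:
    """Stub for the function the prompt provides; returns rows, not a DataFrame."""
    return _conn.execute(sql_query).fetchall()


def analyze_data() -> dict:
    # 1. Prepare: the SQL itself does the grouping and aggregation.
    sql_query = (
        "SELECT order_id, COUNT(DISTINCT product_id) AS distinct_products "
        "FROM order_details GROUP BY order_id ORDER BY order_id"
    )
    # 2. Process: run the query through the provided execution function.
    rows = execute_sql_query(sql_query)
    # 3. Analyze: no further processing is needed for this question.
    # Assumption: the real return shape comes from {output_type_hint}.
    return {"type": "rows", "value": rows}


result = analyze_data()
```

The point of the direct SQL path is visible here: the aggregation runs in the database, and Python only receives the already reduced result.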
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
If the user requests to create a chart, utilize the Python {library} library to generate high-quality graphics that will be saved directly to a file. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -90,3 +90,18 @@ def __repr__(self): | |
f"database={self._config.database} schema={str(self._config.dbSchema)} " | ||
f"table={self._config.table}>" | ||
) | ||
|
||
def equals(self, other): | ||
if isinstance(other, self.__class__): | ||
return ( | ||
self._config.dialect, | ||
self._config.account, | ||
self._config.username, | ||
self._config.password, | ||
) == ( | ||
other._config.dialect, | ||
other._config.account, | ||
other._config.username, | ||
other._config.password, | ||
) | ||
return False | ||
Comment on lines +94 to +107

Committable suggestion (suggested change):

  def equals(self, other):
      if isinstance(other, self.__class__):
          return (
              self._config.dialect,
              self._config.account,
              self._config.username,
-             self._config.password,
+             hash(self._config.password),
          ) == (
              other._config.dialect,
              other._config.account,
              other._config.username,
-             other._config.password,
+             hash(other._config.password),
          )
      return False
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,8 @@ | |
import re | ||
import os | ||
import pandas as pd | ||
|
||
from pandasai.exceptions import MaliciousQueryError | ||
from .base import BaseConnector, SQLConnectorConfig, SqliteConnectorConfig | ||
from .base import BaseConnectorConfig | ||
from sqlalchemy import create_engine, text, select, asc | ||
|
@@ -360,6 +362,46 @@ def column_hash(self): | |
def fallback_name(self): | ||
return self._config.table | ||
|
||
def equals(self, other): | ||
if isinstance(other, self.__class__): | ||
return ( | ||
self._config.dialect, | ||
self._config.driver, | ||
self._config.host, | ||
self._config.port, | ||
self._config.username, | ||
self._config.password, | ||
) == ( | ||
other._config.dialect, | ||
other._config.driver, | ||
other._config.host, | ||
other._config.port, | ||
other._config.username, | ||
other._config.password, | ||
) | ||
return False | ||
|
||
def _is_sql_query_safe(self, query: str): | ||
infected_keywords = [ | ||
r"\bINSERT\b", | ||
r"\bUPDATE\b", | ||
r"\bDELETE\b", | ||
r"\bDROP\b", | ||
r"\bEXEC\b", | ||
r"\bALTER\b", | ||
r"\bCREATE\b", | ||
] | ||
|
||
return not any( | ||
re.search(keyword, query, re.IGNORECASE) for keyword in infected_keywords | ||
) | ||
Comment on lines +384 to +397

@coderabbitai can you suggest more?

@ArslanSaleem Sure, here are a few more suggestions to improve SQL query safety:

Remember, no single technique is foolproof. A combination of these methods will provide a more robust defense against SQL injection attacks.
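
To illustrate why a keyword blocklist alone is not foolproof, the sketch below reimplements the check from the diff as a standalone function (the name `is_sql_query_safe` and the sample queries are made up for this illustration) and runs it against a few inputs. It shows one false positive and one statement the list does not cover.

```python
import re

# Same keyword patterns as in the diff above.
BLOCKED_KEYWORDS = [
    r"\bINSERT\b",
    r"\bUPDATE\b",
    r"\bDELETE\b",
    r"\bDROP\b",
    r"\bEXEC\b",
    r"\bALTER\b",
    r"\bCREATE\b",
]


def is_sql_query_safe(query: str) -> bool:
    """Return False if any blocked keyword appears as a whole word."""
    return not any(
        re.search(keyword, query, re.IGNORECASE) for keyword in BLOCKED_KEYWORDS
    )


checks = {
    # Plain read query: allowed.
    "plain_select": is_sql_query_safe("SELECT * FROM orders"),
    # Stacked statement with DROP: rejected.
    "stacked_drop": is_sql_query_safe("select 1; drop table orders"),
    # False positive: the keyword only appears inside a string literal.
    "keyword_in_literal": is_sql_query_safe(
        "SELECT * FROM notes WHERE body = 'please update me'"
    ),
    # Not covered: TRUNCATE is destructive but not in the list.
    "truncate": is_sql_query_safe("TRUNCATE TABLE orders"),
    # Word boundaries keep column names such as updated_at from matching.
    "column_name": is_sql_query_safe("SELECT updated_at FROM orders"),
}
```

With these inputs, `plain_select`, `truncate` and `column_name` pass the check, while `stacked_drop` and `keyword_in_literal` are rejected.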
||
|
||
def execute_direct_sql_query(self, sql_query): | ||
if not self._is_sql_query_safe(sql_query): | ||
raise MaliciousQueryError("Malicious query is generated in code") | ||
|
||
return pd.read_sql(sql_query, self._connection) | ||
Comment on lines +399 to +403
||
|
||
|
||
class SqliteConnector(SQLConnector): | ||
""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -155,10 +155,27 @@ class UnSupportedLogicUnit(Exception): | |
Exception (Exception): UnSupportedLogicUnit | ||
""" | ||
|
||
|
||
class InvalidWorkspacePathError(Exception): | ||
""" | ||
Raised when the environment variable of workspace exist but path is invalid | ||
|
||
Args: | ||
Exception (Exception): InvalidWorkspacePathError | ||
""" | ||
|
||
|
||
class InvalidConfigError(Exception): | ||
""" | ||
Raised when config value is not appliable | ||
Args: | ||
Exception (Exception): InvalidConfigError | ||
""" | ||
|
||
|
||
class MaliciousQueryError(Exception): | ||
""" | ||
Raise error if malicious query is generated | ||
Args: | ||
Exception (Excpetion): MaliciousQueryError | ||
""" | ||
Comment on lines 163 to +181

The new exceptions …

Committable suggestion

> [!IMPORTANT]
> Carefully review the code before committing. Make sure it correctly replaces the highlighted code, has no missing lines and indentation issues.

```suggestion
    Args:
        Exception (Exception): InvalidWorkspacePathError
    """


class InvalidConfigError(Exception):
    """
    Raised when the config value is not applicable.

    Args:
        Exception (Exception): InvalidConfigError
    """


class MaliciousQueryError(Exception):
    """
    Raised when a malicious query is generated.

    Args:
        Exception (Exception): MaliciousQueryError
    """
```
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
""" Prompt to explain code generation by the LLM | ||
The previous conversation we had | ||
|
||
<Conversation> | ||
{conversation} | ||
</Conversation> | ||
|
||
Based on the last conversation you generated the following code: | ||
|
||
<Code> | ||
{code} | ||
</Code> | ||
|
||
Explain how you came up with code for non-technical people without | ||
mentioning technical details or mentioning the libraries used? | ||
|
||
""" | ||
Comment on lines +1 to +17

This docstring seems to be misplaced or irrelevant to the …
||
from .file_based_prompt import FileBasedPrompt | ||
|
||
|
||
class DirectSQLPrompt(FileBasedPrompt): | ||
"""Prompt to generate code that queries the database directly through SQL""" | ||
|
||
_path_to_template = "assets/prompt_templates/direct_sql_connector.tmpl" | ||
|
||
def _prepare_tables_data(self, tables): | ||
tables_join = [] | ||
for table in tables: | ||
table_description_tag = ( | ||
f' description="{table.table_description}"' | ||
if table.table_description is not None | ||
else "" | ||
) | ||
table_head_tag = f'<table name="{table.table_name}"{table_description_tag}>' | ||
table_tag = f"{table_head_tag}\n{table.head_csv}\n</table>" | ||
tables_join.append(table_tag) | ||
return "\n\n".join(tables_join) | ||
|
||
def setup(self, tables) -> None: | ||
self.set_var("tables", self._prepare_tables_data(tables)) |
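
To show what ends up in the `{tables}` placeholder, here is a standalone sketch of the same formatting logic. The `Table` namedtuple is a hypothetical stand-in for the objects passed to `setup`; only the three attributes the method reads (`table_name`, `table_description`, `head_csv`) are modelled, and the sample CSV content is invented.

```python
from collections import namedtuple

# Hypothetical stand-in for the table objects passed to setup().
Table = namedtuple("Table", ["table_name", "table_description", "head_csv"])


def prepare_tables_data(tables) -> str:
    """Render each table as a <table> block holding its CSV sample."""
    blocks = []
    for table in tables:
        # The description attribute is only emitted when one is set.
        description = (
            f' description="{table.table_description}"'
            if table.table_description is not None
            else ""
        )
        head = f'<table name="{table.table_name}"{description}>'
        blocks.append(f"{head}\n{table.head_csv}\n</table>")
    # Blocks are separated by a blank line.
    return "\n\n".join(blocks)


rendered = prepare_tables_data(
    [
        Table("orders", None, "id,total\n1,9.5"),
        Table(
            "order_details",
            "Contains user order details",
            "order_id,product_id\n1,10",
        ),
    ]
)
```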
The `equals` method is a good addition for comparing two instances of the `DatabricksConnector` class. However, it's important to note that this method only checks for equality based on a subset of the instance's properties. If there are other properties that could affect the behavior of the instance, they should be included in this comparison. Also, consider renaming the method to `__eq__` to follow Python's convention for equality comparison, which would allow you to use the `==` operator directly.
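
A minimal sketch of what the `__eq__` variant could look like. `ExampleConfig` and `ExampleConnector` are hypothetical classes written for this illustration, with field names borrowed from the `SQLConnector` diff; this is not the code in the PR.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ExampleConfig:
    """Hypothetical config; field names mirror the SQLConnector diff."""

    dialect: str
    host: str
    port: int
    username: str
    password: str
    table: str


class ExampleConnector:
    def __init__(self, config: ExampleConfig):
        self._config = config

    def _connection_key(self) -> tuple:
        # The table is deliberately left out: same server, different table.
        c = self._config
        return (c.dialect, c.host, c.port, c.username, c.password)

    def __eq__(self, other) -> bool:
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self._connection_key() == other._connection_key()

    def __hash__(self) -> int:
        # Defining __eq__ without __hash__ would make instances unhashable.
        return hash(self._connection_key())


orders = ExampleConnector(
    ExampleConfig("postgresql", "localhost", 5432, "postgres", "pw", "orders")
)
products = ExampleConnector(
    ExampleConfig("postgresql", "localhost", 5432, "postgres", "pw", "products")
)
other_host = ExampleConnector(
    ExampleConfig("postgresql", "db.internal", 5432, "postgres", "pw", "orders")
)

same_server = orders == products
different_server = orders == other_host
```

One design consideration: with this definition `orders == products` is true even though the two connectors point at different tables, which may surprise callers who expect `==` to mean full equality. A separately named method such as `equals` avoids that ambiguity, so the rename is a trade-off rather than a clear win.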