feat(directSqlPrompt): use connector directly if flag is set #731
* In this commit, I introduced a new configuration parameter in our application settings that allows users to define their preferred data visualization library (matplotlib, seaborn, or plotly). With this update, I've eliminated the need for the user to specify in every prompt which library to use, thereby simplifying their interaction with the application and increasing its versatility. * This commit adds a configuration parameter for users to set their preferred data visualization library (matplotlib, seaborn, or plotly), simplifying interactions and enhancing the application's versatility. * viz_library_type' in test_generate_python_code_prompt.py, resolved failing tests --------- Co-authored-by: sabatino.severino <qrxqfspfibrth6nxywai2qifza6jmskt222howzew43risnx4kva> Co-authored-by: Gabriele Venturi <[email protected]>
* fix(chart): charts to save to save_chart_path * refactor sourcery changes * 'Refactored by Sourcery' * refactor chart save code * fix: minor leftovers * feat(workspace_env): add workspace env to store cache, temp chart and config * add error handling and comments --------- Co-authored-by: Sourcery AI <>
Important
Auto Review Skipped
Auto reviews are disabled on base/target branches other than the default branch. Please add the base/target branch pattern to the list of additional branches to be reviewed in the settings. Please check the settings in the CodeRabbit UI or the
To trigger a single review, invoke the
Walkthrough
The changes introduced in the codebase primarily revolve around the integration of direct SQL execution functionality in the PandasAI library. This includes the addition of new methods, classes, and exception handling, as well as modifications to existing classes and methods. The changes also involve the creation of new test cases to ensure the proper functioning of the newly added features.
Changes
Review Status
Actionable comments generated: 9
Configuration used: CodeRabbit UI
Files selected for processing (17)
- examples/sql_direct_config.py (1 hunks)
- pandasai/assets/prompt_templates/direct_sql_connector.tmpl (1 hunks)
- pandasai/connectors/databricks.py (1 hunks)
- pandasai/connectors/snowflake.py (1 hunks)
- pandasai/connectors/sql.py (2 hunks)
- pandasai/exceptions.py (1 hunks)
- pandasai/helpers/code_manager.py (4 hunks)
- pandasai/prompts/direct_sql_prompt.py (1 hunks)
- pandasai/schemas/df_config.py (1 hunks)
- pandasai/smart_dataframe/__init__.py (1 hunks)
- pandasai/smart_datalake/__init__.py (11 hunks)
- tests/connectors/test_sql.py (2 hunks)
- tests/prompts/test_sql_prompt.py (1 hunks)
- tests/skills/test_skills.py (2 hunks)
- tests/test_codemanager.py (3 hunks)
- tests/test_smartdataframe.py (1 hunks)
- tests/test_smartdatalake.py (3 hunks)
Files skipped from review due to trivial changes (4)
- examples/sql_direct_config.py
- pandasai/schemas/df_config.py
- pandasai/smart_dataframe/__init__.py
- tests/test_smartdataframe.py
Additional comments: 36
pandasai/assets/prompt_templates/direct_sql_connector.tmpl (1)
- 1-39: This hunk is a template for generating a Python function `analyze_data()`. It provides a structure for the function but does not contain any actual implementation. The function is expected to analyze data using provided dataframes and return a dictionary of the analysis results. The function takes no parameters and has a return type hint of `dict`. The function is incomplete and requires implementation. The hunk also includes the import of the `pandas` library.
pandasai/prompts/direct_sql_prompt.py (2)
- 24-24: Ensure that the template file `direct_sql_connector.tmpl` exists at the specified path. If the path is relative, it should be relative to the location where the script is run.
- 33-34: The `setup` method is setting a variable `tables` but it's not clear where this variable is used. Ensure that this variable is used somewhere in the class or its parent class.
pandasai/helpers/code_manager.py (4)
- 28-40: The `CodeExecutionContext` class has been updated to include a new attribute `_can_direct_sql` and its corresponding parameter in the constructor. This attribute is used to determine whether direct SQL queries are allowed.
- 46-48: The `_can_direct_sql` attribute is set in the constructor of the `CodeExecutionContext` class.
- 55-64: > Note: This review was outside of the patch, so it was mapped to the patch with the greatest overlap. Original lines [50-60]
  The `can_direct_sql` property has been added to provide access to the `_can_direct_sql` attribute.
- 294-301: In the `execute_code` method, if `context.can_direct_sql` is true, it sets the `execute_sql_query` environment variable to a query execution function and returns the result of `analyze_data()`. Otherwise, it calls `analyze_data()` with the original dataframes. This allows for the execution of direct SQL queries when the flag is set.
tests/prompts/test_sql_prompt.py (1)
- 51-113: The test case checks the `to_string()` method of the `DirectSQLPrompt` class. It sets various variables and checks if the output string is as expected. The test case seems to be correct and should work as expected.
pandasai/exceptions.py (1)
- 155-173:
class InvalidWorkspacePathError(Exception):
    """
    Raised when the environment variable of workspace exists but the path is invalid.
    Args:
        Exception (Exception): InvalidWorkspacePathError
    """

class InvalidConfigError(Exception):
    """
    Raised when a configuration value is not applicable or invalid.
    Args:
        Exception (Exception): InvalidConfigError
    """

class MaliciousQueryError(Exception):
    """
    Raised when a potentially harmful SQL query is detected.
    Args:
        Exception (Exception): MaliciousQueryError
    """

tests/test_codemanager.py (2)
- 75-79: The `exec_context` fixture has been updated to create an instance of `CodeExecutionContext` with a UUID and a `SkillsManager` object. This change is likely related to the new feature that allows direct SQL queries. Ensure that the `CodeExecutionContext` object is correctly initialized and used in the tests.
- 98-105: The `test_clean_code_remove_builtins` test case has been modified to set the `_can_direct_sql` attribute of `exec_context` to `False`. This change is likely related to the new feature that allows direct SQL queries. Ensure that the `_can_direct_sql` attribute is correctly set and used in the tests.
pandasai/connectors/sql.py (3)
- 5-9: Imports are correctly placed and the new exception `MaliciousQueryError` is imported correctly.
- 362-382: The `equals` method is correctly implemented to compare two instances of the same class based on their configuration properties.
- 399-403: The `execute_direct_sql_query` method correctly uses the `_is_sql_query_safe` method to check if a query is safe before executing it. If the query is not safe, it raises a `MaliciousQueryError`. This is a good practice for preventing SQL injection attacks.
tests/skills/test_skills.py (2)
- 71-74: The `exec_context` fixture now returns a `CodeExecutionContext` object with a UUID and a `SkillsManager` instance. This change is in line with the new feature that allows direct SQL queries. The `CodeExecutionContext` object is used to manage the execution context of the code, including the skills manager, which is responsible for managing the skills used in the code. This change is necessary to support the new feature.
- 318-320: The `test_code_exec_with_skills_no_use` test method no longer specifies the type of the `exec_context` parameter. This change is likely due to the update in the `exec_context` fixture. The test method now accepts any object as the `exec_context` parameter, which increases flexibility but may also increase the risk of type errors. It would be good to ensure that all calls to this method pass an object of the correct type.
tests/test_smartdatalake.py (1)
- 9-23: The new import statements are correctly placed and follow the PEP 8 style guide. The new classes `DirectSQLPrompt` and `GeneratePythonCodePrompt` are imported correctly.
tests/connectors/test_sql.py (8)
- 105-107: The test for the `fallback_name` property is correct and well implemented.
- 109-112: The test for the `_is_sql_query_safe` method with a safe query is correct and well implemented.
- 114-117: The test for the `_is_sql_query_safe` method with a malicious query is correct and well implemented.
- 119-125: The test for the `execute_direct_sql_query` method with a safe query is correct and well implemented.
- 127-133: The test for the `execute_direct_sql_query` method with a malicious query is correct and well implemented. It correctly checks for the `MaliciousQueryError` exception.
- 135-153: The test for the `equals` method with identical configs is correct and well implemented.
- 155-176: The test for the `equals` method with different configs is correct and well implemented.
- 178-196: The test for the `equals` method with a different connector is correct and well implemented.
pandasai/smart_datalake/__init__.py (11)
- 23-31: The import of `DirectSQLPrompt` is added. Ensure that the module exists and is accessible.
- 40-49: The import of `InvalidConfigError` is added. Ensure that the module exists and is accessible.
- 62-68: The `_can_direct_sql` attribute is added to the `SmartDatalake` class. This attribute is used to check if direct SQL can be used.
- 132-140: The `_validate_direct_sql` method is added to validate the direct SQL configuration. It checks if all dataframes belong to the same SQL connector type and raises an `InvalidConfigError` if they don't.
- 249-281: The `_get_chat_prompt` method is added to determine the chat prompt based on the direct SQL configuration. It returns a tuple with the key and the prompt.
- 306-312: The `_get_prompt` method is modified to accept a new argument `key` and use it to get the custom prompt from the configuration.
- 325-331: The `_get_prompt` method is modified to set the configuration for the prompt.
- 358-364: The `_get_cache_key` method is modified to append "direct_sql" to the cache key if direct SQL is enabled.
- 432-443: The `chat` method is modified to use the `_get_chat_prompt` method to determine the chat prompt.
- 470-478: The `chat` method is modified to pass the `_can_direct_sql` parameter to `CodeExecutionContext`.
- 593-599: The `correct_error` prompt is modified to use the `CorrectErrorPrompt` class as the default prompt.
    Args:
        Exception (Exception): InvalidWorkspacePathError
    """


class InvalidConfigError(Exception):
    """
    Raised when config value is not appliable
    Args:
        Exception (Exception): InvalidConfigError
    """


class MaliciousQueryError(Exception):
    """
    Raise error if malicious query is generated
    Args:
        Exception (Excpetion): MaliciousQueryError
    """
The new exceptions `InvalidWorkspacePathError`, `InvalidConfigError`, and `MaliciousQueryError` are well defined and follow the existing pattern of exception definitions in the file. However, the docstrings could be more descriptive to provide more context about when these exceptions are raised.
Committable suggestion
> [!IMPORTANT]
> Carefully review the code before committing. Make sure it correctly replaces the highlighted code, has no missing lines and indentation issues.
```suggestion
    Args:
        Exception (Exception): InvalidWorkspacePathError
    """


class InvalidConfigError(Exception):
    """
    Raised when the config value is not applicable.
    Args:
        Exception (Exception): InvalidConfigError
    """


class MaliciousQueryError(Exception):
    """
    Raised when a malicious query is generated.
    Args:
        Exception (Exception): MaliciousQueryError
    """
```
def equals(self, other):
    if isinstance(other, self.__class__):
        return (
            self._config.dialect,
            self._config.account,
            self._config.username,
            self._config.password,
        ) == (
            other._config.dialect,
            other._config.account,
            other._config.username,
            other._config.password,
        )
    return False
The `equals` method is comparing sensitive information like passwords. This could potentially lead to security issues if the method is used inappropriately. Consider comparing a hash of the password instead of the password itself. Also, consider adding a docstring to this method to explain its purpose and usage.
def equals(self, other):
if isinstance(other, self.__class__):
return (
self._config.dialect,
self._config.account,
self._config.username,
- self._config.password,
+ hash(self._config.password),
) == (
other._config.dialect,
other._config.account,
other._config.username,
- other._config.password,
+ hash(other._config.password),
)
return False
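One caveat on the `hash(...)` suggestion: Python's built-in `hash()` is not a cryptographic hash, and for strings it is randomized per process by default, so it mainly changes how the comparison is written. As an alternative technique (not what the PR or the bot proposes), the secret field could be compared with `hmac.compare_digest` from the standard library. The sketch below assumes a hypothetical `FakeConfig` with the same four fields as the hunk; it is not the connector's real config class.

```python
import hmac
from dataclasses import dataclass


@dataclass
class FakeConfig:
    """Hypothetical config holding the same fields the hunk compares."""

    dialect: str
    account: str
    username: str
    password: str


def same_connection(a: FakeConfig, b: FakeConfig) -> bool:
    """Return True when both configs describe the same connection."""
    plain_fields_match = (a.dialect, a.account, a.username) == (
        b.dialect,
        b.account,
        b.username,
    )
    # compare_digest avoids short-circuiting on the first differing byte
    secrets_match = hmac.compare_digest(
        a.password.encode("utf-8"), b.password.encode("utf-8")
    )
    return plain_fields_match and secrets_match


first = FakeConfig("snowflake", "acct", "user", "s3cret")
second = FakeConfig("snowflake", "acct", "user", "s3cret")
third = FakeConfig("snowflake", "acct", "user", "other")
print(same_connection(first, second), same_connection(first, third))
```

Whether this matters depends on who can call the comparison; for an in-process configuration check, the plain tuple comparison in the hunk may be sufficient.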
Committable suggestion
[!IMPORTANT]
Carefully review the code before committing. Make sure it correctly replaces the highlighted code, has no missing lines and indentation issues.
def equals(self, other):
    """
    Compare the current object with another object for equality.
    Args:
        other: The object to compare with.
    Returns:
        True if the objects are equal, False otherwise.
    """
    if isinstance(other, self.__class__):
        return (
            self._config.dialect,
            self._config.account,
            self._config.username,
            hash(self._config.password),
        ) == (
            other._config.dialect,
            other._config.account,
            other._config.username,
            hash(other._config.password),
        )
    return False
""" Prompt to explain code generation by the LLM
The previous conversation we had

<Conversation>
{conversation}
</Conversation>

Based on the last conversation you generated the following code:

<Code>
{code}
</Code>

Explain how you came up with code for non-technical people without
mentioning technical details or mentioning the libraries used?

"""
This docstring seems to be misplaced or irrelevant to the `DirectSQLPrompt` class. It should be updated to reflect the purpose and usage of the `DirectSQLPrompt` class.
def _prepare_tables_data(self, tables):
    tables_join = []
    for table in tables:
        table = f"<table>{table.head_csv}</table>"
        tables_join.append(table)
    return "\n".join(tables_join)
This method is joining table data with `<table>` tags. If the `head_csv` attribute of the `table` object is a string representation of CSV data, this could lead to malformed HTML if the CSV data contains special characters. Consider escaping special characters or using a library to generate the HTML.
tests/prompts/test_sql_prompt.py
Outdated
from pandasai.helpers.output_types import (
    output_type_factory,
    DefaultOutputType,
    output_types_map,
)


class TestDirectSqlPrompt:
    """Unit tests for the correct error prompt class"""

    @pytest.mark.parametrize(
        "save_charts_path,output_type_hint,viz_library_type_hint",
        [
            (
                "exports/charts",
                DefaultOutputType().template_hint,
                MatplotlibVizLibraryType().template_hint,
            ),
            (
                "custom/dir/for/charts",
                DefaultOutputType().template_hint,
                MatplotlibVizLibraryType().template_hint,
            ),
            *[
                (
                    "exports/charts",
                    output_type_factory(type_).template_hint,
                    viz_lib_type_factory(viz_type_).template_hint,
                )
                for type_ in output_types_map
                for viz_type_ in viz_lib_map
            ],
        ],
    )
    def test_direct_sql_prompt_with_params(
        self, save_charts_path, output_type_hint, viz_library_type_hint
    ):
        """Test that the __str__ method is implemented"""

        llm = FakeLLM("plt.show()")
        dfs = [
            SmartDataframe(
                pd.DataFrame({}),
                config={"llm": llm},
            )
        ]

        prompt = DirectSQLPrompt(tables=dfs)
        prompt.set_var("dfs", dfs)
        prompt.set_var("conversation", "What is the correct code?")
        prompt.set_var("output_type_hint", output_type_hint)
        prompt.set_var("save_charts_path", save_charts_path)
        prompt.set_var("viz_library_type", viz_library_type_hint)
        prompt_content = prompt.to_string()
        if sys.platform.startswith("win"):
            prompt_content = prompt_content.replace("\r\n", "\n")

        assert (
            prompt_content
            == f'''You are provided with the following samples of sql tables data:

<Tables>
<table>
</table>
<Tables>

<conversation>
What is the correct code?
</conversation>

You are provided with following function that executes the sql query,
<Function>
def execute_sql_query(sql_query: str) -> pd.Dataframe
    """his method connect to the database executes the sql query and returns the dataframe"""
</Function>

This is the initial python function. Do not change the params.

```python
# TODO import all the dependencies required
import pandas as pd

def analyze_data() -> dict:
    """
    Analyze the data, using the provided dataframes (`dfs`).
    1. Prepare: generate sql query to get data for analysis (grouping, filtering, aggregating, etc.)
    2. Process: execute the query using execute method available to you which returns dataframe
    3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.)
    {viz_library_type_hint}
    At the end, return a dictionary of:
    {output_type_hint}
    """
```

Take a deep breath and reason step-by-step. Act as a senior data analyst.
In the answer, you must never write the "technical" names of the tables.
Based on the last message in the conversation:

- return the updated analyze_data function wrapped within `python `'''  # noqa: E501
        )
The test case seems to be well written and covers a variety of scenarios by using parameterized inputs. However, the test case only checks the `to_string()` method of the `DirectSQLPrompt` class. It would be beneficial to add more test cases to cover other methods of the `DirectSQLPrompt` class to ensure full coverage.
            f"host={self._config.host} port={self._config.port} "
            f"database={self._config.database} httpPath={str(self._config.httpPath)}"
        )

    def equals(self, other):
        if isinstance(other, self.__class__):
            return (
                self._config.dialect,
                self._config.token,
                self._config.host,
                self._config.port,
                self._config.httpPath,
            ) == (
                other._config.dialect,
                other._config.token,
                other._config.host,
                other._config.port,
                other._config.httpPath,
            )
        return False
The `equals` method is a good addition for comparing two instances of the `DatabricksConnector` class. However, it's important to note that this method only checks for equality based on a subset of the instance's properties. If there are other properties that could affect the behavior of the instance, they should be included in this comparison. Also, consider renaming the method to `__eq__` to follow Python's convention for equality comparison, which would allow you to use the `==` operator directly.
- def equals(self, other):
+ def __eq__(self, other):
Committable suggestion
[!IMPORTANT]
Carefully review the code before committing. Make sure it correctly replaces the highlighted code, has no missing lines and indentation issues.
            f"host={self._config.host} port={self._config.port} "
            f"database={self._config.database} httpPath={str(self._config.httpPath)}"
        )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return (
                self._config.dialect,
                self._config.token,
                self._config.host,
                self._config.port,
                self._config.httpPath,
            ) == (
                other._config.dialect,
                other._config.token,
                other._config.host,
                other._config.port,
                other._config.httpPath,
            )
        return False
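One consequence of the rename worth keeping in mind: a class that defines `__eq__` without also defining `__hash__` becomes unhashable, so instances can no longer be put in sets or used as dict keys. A minimal sketch of a paired `__eq__`/`__hash__`, using a hypothetical `FakeConnector` reduced to three illustrative fields (not the real `DatabricksConnector`):

```python
class FakeConnector:
    """Hypothetical connector reduced to the fields used for equality."""

    def __init__(self, dialect, host, port):
        self.dialect = dialect
        self.host = host
        self.port = port

    def _key(self):
        # Single source of truth for both equality and hashing
        return (self.dialect, self.host, self.port)

    def __eq__(self, other):
        if not isinstance(other, FakeConnector):
            # Let Python try the reflected comparison, then fall back to identity
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        return hash(self._key())


a = FakeConnector("databricks", "example-host", 443)
b = FakeConnector("databricks", "example-host", 443)
c = FakeConnector("databricks", "other-host", 443)
print(a == b, a == c, len({a, b, c}))
```

Returning `NotImplemented` for foreign types, rather than `False`, is the conventional choice; the result of `a == "text"` is still `False`.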
@pytest.fixture
@patch("pandasai.connectors.sql.create_engine", autospec=True)
def sql_connector(self, create_engine):
    # Define your ConnectorConfig instance here
    self.config = SQLConnectorConfig(
        dialect="mysql",
        driver="pymysql",
        username="your_username",
        password="your_password",
        host="your_host",
        port=443,
        database="your_database",
        table="your_table",
        where=[["column_name", "=", "value"]],
    ).dict()

    # Create an instance of SQLConnector
    return SQLConnector(self.config)

@pytest.fixture
@patch("pandasai.connectors.sql.create_engine", autospec=True)
def pgsql_connector(self, create_engine):
    # Define your ConnectorConfig instance here
    self.config = SQLConnectorConfig(
        dialect="mysql",
        driver="pymysql",
        username="your_username",
        password="your_password",
        host="your_host",
        port=443,
        database="your_database",
        table="your_table",
        where=[["column_name", "=", "value"]],
    ).dict()

    # Create an instance of SQLConnector
    return PostgreSQLConnector(self.config)
The new fixtures `sql_connector` and `pgsql_connector` are correctly defined. However, the credentials for the SQL databases are hardcoded. This could be a security issue if this code is pushed to a public repository. Consider using environment variables or a secure method to store these credentials.
- username="your_username",
- password="your_password",
- host="your_host",
+ username=os.getenv("DB_USERNAME"),
+ password=os.getenv("DB_PASSWORD"),
+ host=os.getenv("DB_HOST"),
Committable suggestion
[!IMPORTANT]
Carefully review the code before committing. Make sure it correctly replaces the highlighted code, has no missing lines and indentation issues.
@pytest.fixture
@patch("pandasai.connectors.sql.create_engine", autospec=True)
def sql_connector(self, create_engine):
    # Define your ConnectorConfig instance here
    self.config = SQLConnectorConfig(
        dialect="mysql",
        driver="pymysql",
        username=os.getenv("DB_USERNAME"),
        password=os.getenv("DB_PASSWORD"),
        host=os.getenv("DB_HOST"),
        port=443,
        database="your_database",
        table="your_table",
        where=[["column_name", "=", "value"]],
    ).dict()

    # Create an instance of SQLConnector
    return SQLConnector(self.config)

@pytest.fixture
@patch("pandasai.connectors.sql.create_engine", autospec=True)
def pgsql_connector(self, create_engine):
    # Define your ConnectorConfig instance here
    self.config = SQLConnectorConfig(
        dialect="mysql",
        driver="pymysql",
        username=os.getenv("DB_USERNAME"),
        password=os.getenv("DB_PASSWORD"),
        host=os.getenv("DB_HOST"),
        port=443,
        database="your_database",
        table="your_table",
        where=[["column_name", "=", "value"]],
    ).dict()

    # Create an instance of SQLConnector
    return PostgreSQLConnector(self.config)
    smart_datalake.chat("How many countries are in the dataframe?")
    assert smart_datalake.last_answer == "Custom answer"
    assert smart_datalake.last_reasoning == "Custom reasoning"

def test_get_chat_prompt(self, smart_datalake: SmartDatalake):
    # Test case 1: direct_sql is True
    smart_datalake._config.direct_sql = True
    gen_key, gen_prompt = smart_datalake._get_chat_prompt()
    expected_key = "direct_sql_prompt"
    assert gen_key == expected_key
    assert isinstance(gen_prompt, DirectSQLPrompt)

    # Test case 2: direct_sql is False
    smart_datalake._config.direct_sql = False
    gen_key, gen_prompt = smart_datalake._get_chat_prompt()
    expected_key = "generate_python_code"
    assert gen_key == expected_key
    assert isinstance(gen_prompt, GeneratePythonCodePrompt)

def test_validate_true_direct_sql_with_non_connector(self, llm, sample_df):
    # raise exception with non connector
    SmartDatalake(
        [sample_df],
        config={"llm": llm, "enable_cache": False, "direct_sql": True},
    )

def test_validate_direct_sql_with_connector(self, llm, sql_connector):
    # not exception is raised using single connector
    SmartDatalake(
        [sql_connector],
        config={"llm": llm, "enable_cache": False, "direct_sql": True},
    )

def test_validate_false_direct_sql_with_connector(self, llm, sql_connector):
    # not exception is raised using single connector
    SmartDatalake(
        [sql_connector],
        config={"llm": llm, "enable_cache": False, "direct_sql": False},
    )

def test_validate_false_direct_sql_with_two_different_connector(
    self, llm, sql_connector, pgsql_connector
):
    # not exception is raised using single connector
    SmartDatalake(
        [sql_connector, pgsql_connector],
        config={"llm": llm, "enable_cache": False, "direct_sql": False},
    )

def test_validate_true_direct_sql_with_two_different_connector(
    self, llm, sql_connector, pgsql_connector
):
    # not exception is raised using single connector
    # raise exception when two different connector
    with pytest.raises(InvalidConfigError):
        SmartDatalake(
            [sql_connector, pgsql_connector],
            config={"llm": llm, "enable_cache": False, "direct_sql": True},
        )
The new test methods are correctly defined and cover the new functionality added to the `SmartDatalake` class. The tests cover both positive and negative scenarios and check for correct exceptions. However, the comments for the test cases are not very descriptive. Consider adding more detailed comments to explain what each test case is checking.
Review Status
Actionable comments generated: 2
Configuration used: CodeRabbit UI
Files selected for processing (2)
- pandasai/connectors/sql.py (2 hunks)
- pandasai/smart_datalake/__init__.py (11 hunks)
Additional comments: 11
pandasai/connectors/sql.py (1)
- 5-9: The import of
MaliciousQueryError
is necessary for the new feature of executing direct SQL queries. This exception will be raised if a potentially harmful SQL query is detected.
pandasai/smart_datalake/__init__.py (10)
23-31: The import statement for DirectSQLPrompt has been added. Ensure that the DirectSQLPrompt class is correctly implemented and that it is in the correct location for this import statement to work.
40-49: The CodeExecutionContext import has been added. Ensure that the CodeExecutionContext class is correctly implemented and that it is in the correct location for this import statement to work.
62-68: The _can_direct_sql attribute has been added to the SmartDatalake class. This attribute is used to determine whether direct SQL queries are allowed. Ensure that this attribute is correctly used throughout the class.
132-140: The _validate_direct_sql method has been added to validate the direct SQL configuration. This method checks if all dataframes belong to the same SQL connector type and raises an InvalidConfigError if they do not. Ensure that this method is correctly used throughout the class.
249-282: The _get_chat_prompt method has been added to determine the chat prompt based on the direct SQL configuration. This method returns a tuple containing the key for the prompt and the prompt itself. Ensure that this method is correctly used throughout the class.
306-311: The _get_prompt method has been modified to include the _can_direct_sql flag. This flag is passed to the CodeExecutionContext constructor. Ensure that this flag is correctly used in the CodeExecutionContext class.
325-331: The _get_prompt method has been modified to include the _can_direct_sql flag. This flag is passed to the CodeExecutionContext constructor. Ensure that this flag is correctly used in the CodeExecutionContext class.
432-443: The chat method has been modified to use the _get_chat_prompt method and pass the tables to the prompt. Ensure that this method is correctly used throughout the class.
470-478: The chat method has been modified to pass the _can_direct_sql flag to the CodeExecutionContext constructor. Ensure that this flag is correctly used in the CodeExecutionContext class.
593-599: The _retry_run_code method has been modified to include the _can_direct_sql flag. This flag is passed to the CodeExecutionContext constructor. Ensure that this flag is correctly used in the CodeExecutionContext class.
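The validation described for _validate_direct_sql (all dataframes must share one SQL connector type, otherwise an InvalidConfigError is raised) could look roughly like the sketch below. This is not the code from the PR: FakeConnector, its kind field, and the standalone validate_direct_sql function are hypothetical stand-ins, and InvalidConfigError is redefined locally so the example runs on its own.

```python
from dataclasses import dataclass


class InvalidConfigError(Exception):
    """Local stand-in for the exception named in the review (assumption)."""


@dataclass
class FakeConnector:
    """Hypothetical stand-in for the connector behind a dataframe."""
    kind: str  # e.g. "postgres" or "mysql"; "" means a plain in-memory dataframe


def validate_direct_sql(connectors, direct_sql: bool) -> bool:
    """Return True when direct SQL may be used, False when the flag is off."""
    if not direct_sql:
        return False
    kinds = {c.kind for c in connectors}
    # Direct SQL only makes sense when every source is the same SQL backend.
    if len(kinds) != 1 or "" in kinds:
        raise InvalidConfigError(
            "direct_sql requires all dataframes to share one SQL connector type"
        )
    return True


same_backend = [FakeConnector("postgres"), FakeConnector("postgres")]
mixed_backend = [FakeConnector("postgres"), FakeConnector("mysql")]

allowed = validate_direct_sql(same_backend, direct_sql=True)
flag_off = validate_direct_sql(mixed_backend, direct_sql=False)
```

With the flag off the mixed list is accepted (the check is skipped); with the flag on it would raise.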
def _is_sql_query_safe(self, query: str):
    infected_keywords = [
        r"\bINSERT\b",
        r"\bUPDATE\b",
        r"\bDELETE\b",
        r"\bDROP\b",
        r"\bEXEC\b",
        r"\bALTER\b",
        r"\bCREATE\b",
    ]

    return not any(
        re.search(keyword, query, re.IGNORECASE) for keyword in infected_keywords
    )
The _is_sql_query_safe method checks for potentially harmful SQL keywords in a query. This is a good security practice to prevent SQL injection attacks. However, it's important to note that this method only checks for a limited set of keywords and may not catch all possible malicious queries. Consider using a more comprehensive SQL injection prevention technique or library.
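To make the limitation concrete, here is a small sketch that copies the keyword check from the diff into a standalone function (the name is_sql_query_safe and the module-level BLOCKED_KEYWORDS list are illustrative, not part of the PR) and runs it on three sample queries chosen for this example.

```python
import re

# Same patterns as in the diff above.
BLOCKED_KEYWORDS = [
    r"\bINSERT\b", r"\bUPDATE\b", r"\bDELETE\b", r"\bDROP\b",
    r"\bEXEC\b", r"\bALTER\b", r"\bCREATE\b",
]


def is_sql_query_safe(query: str) -> bool:
    """Standalone copy of the keyword check, for experimentation only."""
    return not any(re.search(k, query, re.IGNORECASE) for k in BLOCKED_KEYWORDS)


# Caught: a blocked keyword appears as a whole word.
drop_is_safe = is_sql_query_safe("DROP TABLE users")

# Missed: TRUNCATE is destructive but is not on the list.
truncate_is_safe = is_sql_query_safe("TRUNCATE TABLE users")

# False positive: "delete" only appears inside a string literal.
literal_is_safe = is_sql_query_safe(
    "SELECT * FROM logs WHERE action = 'delete'"
)
```

The second and third cases show both failure modes of a keyword list: a destructive statement passes, and a harmless read is rejected.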
@coderabbitai can you suggest more?
@ArslanSaleem Sure, here are a few more suggestions to improve SQL query safety:
- Parameterized Queries: This is one of the most effective ways to prevent SQL injection. Instead of constructing SQL queries with string concatenation, use parameterized queries. This ensures that user-supplied input is always treated as literal values and not part of the SQL command.
- ORM (Object-Relational Mapping) Libraries: Using an ORM library can abstract the SQL layer, reducing the risk of SQL injection attacks. These libraries generate SQL queries on your behalf and are generally designed to avoid SQL injection.
- Least Privilege Principle: Limit the permissions of the database account used by the application. Don't use a database root account whose compromise could lead to full database access.
- Regular Expressions: While the current method uses regular expressions to check for malicious keywords, it could be expanded to include more patterns. However, this method should not be relied upon solely as it can be bypassed with various techniques.
- Web Application Firewalls (WAFs): These can help to filter out malicious data and provide an additional layer of security.
- Input Validation: Validate input data thoroughly. Use both whitelist (only allow known good patterns) and blacklist (block known bad patterns) validation.
Remember, no single technique is foolproof. A combination of these methods will provide a more robust defense against SQL injection attacks.
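Two of the suggestions above, parameterized queries and least privilege, can be illustrated with the standard-library sqlite3 module. This is only a sketch: the users table and the sample input are invented for the example, and SQLite's PRAGMA query_only stands in for what would normally be a read-only database account on a server database.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
conn.executemany("INSERT INTO users (name) VALUES (?)", [("alice",), ("bob",)])
conn.commit()

# Parameterized query: the driver binds the value, so it is never parsed as SQL.
hostile_input = "alice' OR '1'='1"
rows = conn.execute(
    "SELECT name FROM users WHERE name = ?", (hostile_input,)
).fetchall()
# No user is literally named "alice' OR '1'='1", so nothing matches.

# Least privilege, SQLite flavour: refuse writes on this connection.
conn.execute("PRAGMA query_only = ON")
try:
    conn.execute("DELETE FROM users")
    write_blocked = False
except sqlite3.Error:
    write_blocked = True

# Reads still work, and both rows are still present.
remaining = conn.execute("SELECT COUNT(*) FROM users").fetchone()[0]
```

Restricting the connection itself means a destructive statement fails even if it slips past a keyword filter.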
def execute_direct_sql_query(self, sql_query):
    if not self._is_sql_query_safe(sql_query):
        raise MaliciousQueryError("Malicious query is generated in code")

    return pd.read_sql(sql_query, self._connection)
The execute_direct_sql_query method executes a SQL query directly. It first checks if the query is safe using the _is_sql_query_safe method and raises a MaliciousQueryError if the query is determined to be malicious. This is a good practice for security. However, as mentioned above, the _is_sql_query_safe method may not catch all possible malicious queries.
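The check-then-execute flow can be tried in isolation with the sketch below. It is an approximation, not the connector code: it uses a plain sqlite3 connection and cursor instead of pd.read_sql and self._connection, MaliciousQueryError is redefined locally, and the orders table is invented for the example.

```python
import re
import sqlite3


class MaliciousQueryError(Exception):
    """Local stand-in for the exception raised in the diff (assumption)."""


# The same keywords as the diff, combined into one compiled pattern.
WRITE_KEYWORDS = re.compile(
    r"\b(INSERT|UPDATE|DELETE|DROP|EXEC|ALTER|CREATE)\b", re.IGNORECASE
)


def execute_direct_sql_query(conn: sqlite3.Connection, sql_query: str):
    """Reject queries containing a blocked keyword, otherwise run them."""
    if WRITE_KEYWORDS.search(sql_query):
        raise MaliciousQueryError("Malicious query is generated in code")
    return conn.execute(sql_query).fetchall()


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE orders (id INTEGER, amount REAL)")
conn.executemany("INSERT INTO orders VALUES (?, ?)", [(1, 9.5), (2, 20.0)])

# A read-only query passes the guard and is executed.
total = execute_direct_sql_query(conn, "SELECT COUNT(*) AS total FROM orders")

# A destructive query is rejected before it reaches the database.
try:
    execute_direct_sql_query(conn, "DROP TABLE orders")
    rejected = False
except MaliciousQueryError:
    rejected = True
```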
Hey @ArslanSaleem, tried it out, but it seems to use the wrong table name.
Example: "SELECT COUNT(*) AS total FROM table2"
The table is not called table2, and I suppose the problem is we don't pass the table name at all, can you confirm?
…nnector feat(directSqlPrompt): use connector directly if flag is set (Sourcery refactored)
Codecov Report
@@ Coverage Diff @@
## release/v1.5 #731 +/- ##
===============================================
Coverage ? 85.19%
===============================================
Files ? 86
Lines ? 3817
Branches ? 0
===============================================
Hits ? 3252
Misses ? 565
Partials ? 0
…enturi/pandas-ai into feat/use_direct_sql_connector
Summary by CodeRabbit
New Features
- DirectSQLPrompt class for handling SQL prompts.
Bug Fixes
Tests
Documentation
Refactor
- lazy_load_connector attribute and added direct_sql attribute in the Config class.
Chores