
feat(directSqlPrompt): use connector directly if flag is set #731

Merged
merged 16 commits into from
Nov 7, 2023

Conversation

ArslanSaleem
Collaborator

@ArslanSaleem ArslanSaleem commented Nov 6, 2023

Summary by CodeRabbit

  • New Features

    • Introduced direct SQL query execution functionality.
    • Added new DirectSQLPrompt class for handling SQL prompts.
    • Implemented new methods for comparing configuration properties of instances.
  • Bug Fixes

    • Added safety checks for SQL queries to prevent malicious queries.
  • Tests

    • Added new test cases for the direct SQL functionality, safety checks, and configuration comparison.
  • Documentation

    • Provided examples demonstrating the usage of the library with a CSV file.
  • Refactor

    • Removed lazy_load_connector attribute and added direct_sql attribute in the Config class.
  • Chores

    • Introduced new exceptions for specific error scenarios.

bsab and others added 9 commits October 31, 2023 15:26
* In this commit, I introduced a new configuration parameter in our application settings that allows users to define their preferred data visualization library (matplotlib, seaborn, or plotly).
With this update, I've eliminated the need for the user to specify in every prompt which library to use, thereby simplifying their interaction with the application and increasing its versatility.

* This commit adds a configuration parameter for users to set their preferred data visualization library (matplotlib, seaborn, or plotly), simplifying interactions and enhancing the application's versatility.

* viz_library_type' in test_generate_python_code_prompt.py, resolved failing tests

---------

Co-authored-by: sabatino.severino <qrxqfspfibrth6nxywai2qifza6jmskt222howzew43risnx4kva>
Co-authored-by: Gabriele Venturi <[email protected]>
* fix(chart): charts to save to save_chart_path

* refactor sourcery changes

* 'Refactored by Sourcery'

* refactor chart save code

* fix: minor leftovers

* feat(workspace_env): add workspace env to store cache, temp chart and config

* add error handling and comments

---------

Co-authored-by: Sourcery AI <>
Contributor

coderabbitai bot commented Nov 6, 2023

Important

Auto Review Skipped

Auto reviews are disabled on base/target branches other than the default branch. Please add the base/target branch pattern to the list of additional branches to be reviewed in the settings.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository.

To trigger a single review, invoke the @coderabbitai review command.

Walkthrough

The changes introduced in the codebase primarily revolve around the integration of direct SQL execution functionality in the PandasAI library. This includes the addition of new methods, classes, and exception handling, as well as modifications to existing classes and methods. The changes also involve the creation of new test cases to ensure the proper functioning of the newly added features.

Changes

| File(s) | Change Summary |
|---|---|
| examples/sql_direct_config.py | Demonstrates the usage of PandasAI with a CSV file and PostgreSQL database. |
| pandasai/assets/.../direct_sql_connector.tmpl | Introduces a new function analyze_data() for data analysis. |
| pandasai/connectors/databricks.py, pandasai/connectors/snowflake.py | Adds a new equals method to compare configuration properties of class instances. |
| pandasai/connectors/sql.py | Adds methods to check if a SQL query is safe and to execute a SQL query directly. |
| pandasai/exceptions.py | Introduces new exception classes for specific error scenarios. |
| pandasai/helpers/code_manager.py | Adds a new attribute and property to the CodeExecutionContext class and modifies the execute_code method. |
| pandasai/prompts/direct_sql_prompt.py | Introduces a new class DirectSQLPrompt and adds a method to prepare table data. |
| pandasai/schemas/df_config.py | Removes the lazy_load_connector attribute and adds a new attribute direct_sql. |
| pandasai/smart_dataframe/__init__.py | Adds new methods to compare class instances and to execute a direct SQL query. |
| pandasai/smart_datalake/__init__.py | Adds new methods and attributes to handle direct SQL prompts and modifies existing methods. |
| tests/.../test_sql.py, tests/prompts/test_sql_prompt.py, tests/skills/test_skills.py, tests/test_codemanager.py, tests/test_smartdataframe.py, tests/test_smartdatalake.py | Adds new test cases and modifies existing ones to test the newly added features and methods. |

Poem

🍂 As the leaves fall this November day, 🍁

In the world of code, we hop and play. 🐇💻

SQL queries direct and fast, 🚀

Making data analysis a blast. 💥

With each new function, class, and test, 🧪

We strive to make our code the best. 🏆

So here's to changes, big and small, 🎉

In the spirit of autumn, we celebrate them all! 🍂🎊


Tips

Chat with CodeRabbit Bot (@coderabbitai)

  • If you reply to a review comment from CodeRabbit, the bot will automatically respond.
  • To engage with CodeRabbit bot directly around the specific lines of code in the PR, mention @coderabbitai in your review comment
  • Note: Review comments are made on code diffs or files, not on the PR overview.
  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.

CodeRabbit Commands (invoked as PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger a review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai help to get help.
  • @coderabbitai resolve to resolve all the CodeRabbit review comments.

Note: For conversation with the bot, please use the review comments on code diffs or files.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • The JSON schema for the configuration file is available here.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/coderabbit-overrides.v2.json

Contributor

@coderabbitai coderabbitai bot left a comment

Review Status

Actionable comments generated: 9

Configuration used: CodeRabbit UI

Commits Files that changed from the base of the PR and between 3885c08 and b36f5fb.
Files selected for processing (17)
  • examples/sql_direct_config.py (1 hunks)
  • pandasai/assets/prompt_templates/direct_sql_connector.tmpl (1 hunks)
  • pandasai/connectors/databricks.py (1 hunks)
  • pandasai/connectors/snowflake.py (1 hunks)
  • pandasai/connectors/sql.py (2 hunks)
  • pandasai/exceptions.py (1 hunks)
  • pandasai/helpers/code_manager.py (4 hunks)
  • pandasai/prompts/direct_sql_prompt.py (1 hunks)
  • pandasai/schemas/df_config.py (1 hunks)
  • pandasai/smart_dataframe/__init__.py (1 hunks)
  • pandasai/smart_datalake/__init__.py (11 hunks)
  • tests/connectors/test_sql.py (2 hunks)
  • tests/prompts/test_sql_prompt.py (1 hunks)
  • tests/skills/test_skills.py (2 hunks)
  • tests/test_codemanager.py (3 hunks)
  • tests/test_smartdataframe.py (1 hunks)
  • tests/test_smartdatalake.py (3 hunks)
Files skipped from review due to trivial changes (4)
  • examples/sql_direct_config.py
  • pandasai/schemas/df_config.py
  • pandasai/smart_dataframe/__init__.py
  • tests/test_smartdataframe.py
Additional comments: 36
pandasai/assets/prompt_templates/direct_sql_connector.tmpl (1)
  • 1-39: This hunk is a template for generating a Python function analyze_data(). It provides a structure for the function but does not contain any actual implementation. The function is expected to analyze data using provided dataframes and return a dictionary of the analysis results. The function takes no parameters and has a return type hint of dict. The function is incomplete and requires implementation. The hunk also includes the import of the pandas library.
pandasai/prompts/direct_sql_prompt.py (2)
  • 24-24: Ensure that the template file direct_sql_connector.tmpl exists at the specified path. If the path is relative, it should be relative to the location where the script is run.

  • 33-34: The setup method is setting a variable tables but it's not clear where this variable is used. Ensure that this variable is used somewhere in the class or its parent class.

pandasai/helpers/code_manager.py (4)
  • 28-40: The CodeExecutionContext class has been updated to include a new attribute _can_direct_sql and its corresponding parameter in the constructor. This attribute is used to determine whether direct SQL queries are allowed.

  • 46-48: The _can_direct_sql attribute is set in the constructor of the CodeExecutionContext class.

  • 55-64: > Note: This review was outside of the patch, so it was mapped to the patch with the greatest overlap. Original lines [50-60]

The can_direct_sql property has been added to provide access to the _can_direct_sql attribute.

  • 294-301: In the execute_code method, if context.can_direct_sql is true, it sets the execute_sql_query environment variable to a query execution function and returns the result of analyze_data(). Otherwise, it calls analyze_data() with the original dataframes. This allows for the execution of direct SQL queries when the flag is set.
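The branching described in the last comment can be sketched as follows. This is a minimal, self-contained illustration and not the PandasAI implementation: `run_generated_code`, `fake_execute_sql_query`, and the two generated snippets are hypothetical names, and the sketch assumes the non-direct branch passes the dataframes as a single argument.

```python
def fake_execute_sql_query(sql_query: str) -> list:
    """Hypothetical stand-in for a connector's query function."""
    return [{"query": sql_query}]


def run_generated_code(code: str, dfs: list, can_direct_sql: bool) -> dict:
    """Execute generated code that defines analyze_data(), then call it."""
    environment = {}
    if can_direct_sql:
        # Expose the query function by name so the generated code can call it.
        environment["execute_sql_query"] = fake_execute_sql_query
    exec(code, environment)
    analyze_data = environment["analyze_data"]
    # Direct SQL: call with no arguments. Otherwise: pass the dataframes.
    return analyze_data() if can_direct_sql else analyze_data(dfs)


# Hypothetical generated code for the direct SQL path.
SQL_CODE = """
def analyze_data() -> dict:
    rows = execute_sql_query("SELECT 1")
    return {"type": "number", "value": len(rows)}
"""

# Hypothetical generated code for the dataframe path.
DFS_CODE = """
def analyze_data(dfs) -> dict:
    return {"type": "number", "value": len(dfs)}
"""

print(run_generated_code(SQL_CODE, [], can_direct_sql=True))
print(run_generated_code(DFS_CODE, [[1], [2]], can_direct_sql=False))
```

Because the function is looked up in the same dictionary that `exec` populated, the generated `analyze_data()` resolves `execute_sql_query` as a global only when the flag is set.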
tests/prompts/test_sql_prompt.py (1)
  • 51-113: The test case checks the to_string() method of the DirectSQLPrompt class. It sets various variables and checks if the output string is as expected. The test case seems to be correct and should work as expected.
pandasai/exceptions.py (1)
  • 155-173: class InvalidWorkspacePathError(Exception):
    """
    Raised when the environment variable of workspace exists but the path is invalid.
    Args:
    Exception (Exception): InvalidWorkspacePathError
    """

class InvalidConfigError(Exception):
    """
    Raised when a configuration value is not applicable or invalid.
    Args:
    Exception (Exception): InvalidConfigError
    """

class MaliciousQueryError(Exception):
    """
    Raised when a potentially harmful SQL query is detected.
    Args:
    Exception (Exception): MaliciousQueryError
    """



</blockquote></details>
<details><summary>tests/test_codemanager.py (2)</summary><blockquote>

* 75-79: The `exec_context` fixture has been updated to create an instance of `CodeExecutionContext` with a UUID and a `SkillsManager` object. This change is likely related to the new feature that allows direct SQL queries. Ensure that the `CodeExecutionContext` object is correctly initialized and used in the tests.


* 98-105: The `test_clean_code_remove_builtins` test case has been modified to set the `_can_direct_sql` attribute of `exec_context` to `False`. This change is likely related to the new feature that allows direct SQL queries. Ensure that the `_can_direct_sql` attribute is correctly set and used in the tests.



</blockquote></details>
<details><summary>pandasai/connectors/sql.py (3)</summary><blockquote>

* 5-9: Imports are correctly placed and the new exception `MaliciousQueryError` is imported correctly.



* 362-382: The `equals` method is correctly implemented to compare two instances of the same class based on their configuration properties.



* 399-403: The `execute_direct_sql_query` method correctly uses the `_is_sql_query_safe` method to check if a query is safe before executing it. If the query is not safe, it raises a `MaliciousQueryError`. This is a good practice for preventing SQL injection attacks.
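A check-then-execute flow of the kind described above might look like the sketch below. It is an assumption-laden illustration rather than the code in `pandasai/connectors/sql.py`: the keyword deny-list, the free functions `is_sql_query_safe` and `execute_direct_sql_query`, and the `run_query` callback are all hypothetical.

```python
import re


class MaliciousQueryError(Exception):
    """Raised when a potentially harmful SQL query is detected."""


# Assumed deny-list; a real connector may use a different set of keywords.
UNSAFE_KEYWORDS = ("INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "CREATE", "EXEC")
_UNSAFE_PATTERN = re.compile(
    r"\b(" + "|".join(UNSAFE_KEYWORDS) + r")\b", re.IGNORECASE
)


def is_sql_query_safe(query: str) -> bool:
    """Return False if the query contains a deny-listed keyword as a whole word."""
    return _UNSAFE_PATTERN.search(query) is None


def execute_direct_sql_query(query: str, run_query):
    """Pass the query to run_query only if it passes the safety check."""
    if not is_sql_query_safe(query):
        raise MaliciousQueryError("Malicious query is generated in code")
    return run_query(query)
```

The word-boundary anchors keep column names such as `created_at` from being flagged. A keyword filter like this is only a heuristic; it complements, and does not replace, a read-only database role.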



</blockquote></details>
<details><summary>tests/skills/test_skills.py (2)</summary><blockquote>

* 71-74: The `exec_context` fixture now returns a `CodeExecutionContext` object with a UUID and a `SkillsManager` instance. This change is in line with the new feature that allows direct SQL queries. The `CodeExecutionContext` object is used to manage the execution context of the code, including the skills manager, which is responsible for managing the skills used in the code. This change is necessary to support the new feature.


* 318-320: The `test_code_exec_with_skills_no_use` test method no longer specifies the type of the `exec_context` parameter. This change is likely due to the update in the `exec_context` fixture. The test method now accepts any object as the `exec_context` parameter, which increases flexibility but may also increase the risk of type errors. It would be good to ensure that all calls to this method pass an object of the correct type.



</blockquote></details>
<details><summary>tests/test_smartdatalake.py (1)</summary><blockquote>

* 9-23: The new import statements are correctly placed and follow the PEP 8 style guide. The new classes `DirectSQLPrompt` and `GeneratePythonCodePrompt` are imported correctly.




</blockquote></details>
<details><summary>tests/connectors/test_sql.py (8)</summary><blockquote>

* 105-107: The test for the `fallback_name` property is correct and well implemented.


* 109-112: The test for the `_is_sql_query_safe` method with a safe query is correct and well implemented.


* 114-117: The test for the `_is_sql_query_safe` method with a malicious query is correct and well implemented.


* 119-125: The test for the `execute_direct_sql_query` method with a safe query is correct and well implemented.


* 127-133: The test for the `execute_direct_sql_query` method with a malicious query is correct and well implemented. It correctly checks for the `MaliciousQueryError` exception.


* 135-153: The test for the `equals` method with identical configs is correct and well implemented.


* 155-176: The test for the `equals` method with different configs is correct and well implemented.


* 178-196: The test for the `equals` method with a different connector is correct and well implemented.



</blockquote></details>
<details><summary>pandasai/smart_datalake/__init__.py (11)</summary><blockquote>

* 23-31: The import of `DirectSQLPrompt` is added. Ensure that the module exists and is accessible.


* 40-49: The import of `InvalidConfigError` is added. Ensure that the module exists and is accessible.


* 62-68: The `_can_direct_sql` attribute is added to the `SmartDatalake` class. This attribute is used to check if direct SQL can be used.


* 132-140: The `_validate_direct_sql` method is added to validate the direct SQL configuration. It checks if all dataframes belong to the same SQL connector type and raises an `InvalidConfigError` if they don't.


* 249-281: The `_get_chat_prompt` method is added to determine the chat prompt based on the direct SQL configuration. It returns a tuple with the key and the prompt.


* 306-312: The `_get_prompt` method is modified to accept a new argument `key` and use it to get the custom prompt from the configuration.


* 325-331: The `_get_prompt` method is modified to set the configuration for the prompt.


* 358-364: The `_get_cache_key` method is modified to append "direct_sql" to the cache key if direct SQL is enabled.


* 432-443: The `chat` method is modified to use the `_get_chat_prompt` method to determine the chat prompt.


* 470-478: The `chat` method is modified to pass the `_can_direct_sql` parameter to `CodeExecutionContext`.


* 593-599: The `correct_error` prompt is modified to use the `CorrectErrorPrompt` class as the default prompt.
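The prompt selection and cache-key behaviour listed above can be sketched in isolation. The two key strings are taken from the tests in this PR; the `Config` dataclass and the free functions are hypothetical simplifications, and the real methods return a prompt object alongside the key.

```python
from dataclasses import dataclass


@dataclass
class Config:
    """Hypothetical, reduced stand-in for the library's config object."""

    direct_sql: bool = False


def get_chat_prompt_key(config: Config) -> str:
    """Pick the prompt key according to the direct_sql flag."""
    return "direct_sql_prompt" if config.direct_sql else "generate_python_code"


def get_cache_key(query: str, config: Config) -> str:
    """Append a marker so direct SQL results are cached separately."""
    cache_key = query
    if config.direct_sql:
        cache_key += "direct_sql"
    return cache_key
```

Keeping the marker in the cache key presumably prevents code generated for one mode from being served from the cache when the other mode is active.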



</blockquote></details></blockquote></details>



</details>

Comment on lines 155 to +173
    Args:
        Exception (Exception): InvalidWorkspacePathError
    """


class InvalidConfigError(Exception):
    """
    Raised when config value is not appliable
    Args:
        Exception (Exception): InvalidConfigError
    """


class MaliciousQueryError(Exception):
    """
    Raise error if malicious query is generated
    Args:
        Exception (Excpetion): MaliciousQueryError
    """

The new exceptions InvalidWorkspacePathError, InvalidConfigError, and MaliciousQueryError are well defined and follow the existing pattern of exception definitions in the file. However, the docstrings could be more descriptive to provide more context about when these exceptions are raised.

<!-- This is an auto-generated comment by CodeRabbit -->
---

<!-- suggestion_start -->
<details>
<summary>Commitable suggestion</summary>

> [!IMPORTANT]
> Carefully review the code before committing. Make sure it correctly replaces the highlighted code, has no missing lines and indentation issues.

```suggestion
    Args:
        Exception (Exception): InvalidWorkspacePathError
    """


class InvalidConfigError(Exception):
    """
    Raised when the config value is not applicable.
    
    Args:
        Exception (Exception): InvalidConfigError
    """


class MaliciousQueryError(Exception):
    """
    Raised when a malicious query is generated.
    
    Args:
        Exception (Exception): MaliciousQueryError
    """
```

</details>

Comment on lines +94 to +107
    def equals(self, other):
        if isinstance(other, self.__class__):
            return (
                self._config.dialect,
                self._config.account,
                self._config.username,
                self._config.password,
            ) == (
                other._config.dialect,
                other._config.account,
                other._config.username,
                other._config.password,
            )
        return False

The equals method is comparing sensitive information like passwords. This could potentially lead to security issues if the method is used inappropriately. Consider comparing a hash of the password instead of the password itself. Also, consider adding a docstring to this method to explain its purpose and usage.

    def equals(self, other):
        if isinstance(other, self.__class__):
            return (
                self._config.dialect,
                self._config.account,
                self._config.username,
-               self._config.password,
+               hash(self._config.password),
            ) == (
                other._config.dialect,
                other._config.account,
                other._config.username,
-               other._config.password,
+               hash(other._config.password),
            )
        return False

Suggested change

    def equals(self, other):
        """
        Compare the current object with another object for equality.
        Args:
            other: The object to compare with.
        Returns:
            True if the objects are equal, False otherwise.
        """
        if isinstance(other, self.__class__):
            return (
                self._config.dialect,
                self._config.account,
                self._config.username,
                hash(self._config.password),
            ) == (
                other._config.dialect,
                other._config.account,
                other._config.username,
                hash(other._config.password),
            )
        return False

Comment on lines +1 to +17
""" Prompt to explain code generation by the LLM
The previous conversation we had

<Conversation>
{conversation}
</Conversation>

Based on the last conversation you generated the following code:

<Code>
{code}
</Code>

Explain how you came up with code for non-technical people without
mentioning technical details or mentioning the libraries used?

"""

This docstring seems to be misplaced or irrelevant to the DirectSQLPrompt class. It should be updated to reflect the purpose and usage of the DirectSQLPrompt class.

Comment on lines 26 to 31
    def _prepare_tables_data(self, tables):
        tables_join = []
        for table in tables:
            table = f"<table>{table.head_csv}</table>"
            tables_join.append(table)
        return "\n".join(tables_join)

This method is joining table data with <table> tags. If the head_csv attribute of the table object is a string representation of CSV data, this could lead to malformed HTML if the CSV data contains special characters. Consider escaping special characters or using a library to generate the HTML.
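If escaping were wanted, the standard library's `html.escape` would be one way to do it. The sketch below assumes table objects that expose a `head_csv` string; `FakeTable` and the free function `prepare_tables_data` are hypothetical stand-ins, not the class under review.

```python
from html import escape


class FakeTable:
    """Hypothetical stand-in exposing a head_csv string."""

    def __init__(self, head_csv: str):
        self.head_csv = head_csv


def prepare_tables_data(tables) -> str:
    """Wrap each table sample in <table> tags, escaping &, <, > and quotes."""
    return "\n".join(f"<table>{escape(t.head_csv)}</table>" for t in tables)


print(prepare_tables_data([FakeTable("name,note\nx,<b>&")]))
```

Since the joined string is sent to an LLM as prompt text rather than rendered as HTML, whether escaping helps or hurts here is a judgment call for the maintainers.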

Comment on lines 1 to 113
from pandasai.helpers.output_types import (
    output_type_factory,
    DefaultOutputType,
    output_types_map,
)


class TestDirectSqlPrompt:
    """Unit tests for the correct error prompt class"""

    @pytest.mark.parametrize(
        "save_charts_path,output_type_hint,viz_library_type_hint",
        [
            (
                "exports/charts",
                DefaultOutputType().template_hint,
                MatplotlibVizLibraryType().template_hint,
            ),
            (
                "custom/dir/for/charts",
                DefaultOutputType().template_hint,
                MatplotlibVizLibraryType().template_hint,
            ),
            *[
                (
                    "exports/charts",
                    output_type_factory(type_).template_hint,
                    viz_lib_type_factory(viz_type_).template_hint,
                )
                for type_ in output_types_map
                for viz_type_ in viz_lib_map
            ],
        ],
    )
    def test_direct_sql_prompt_with_params(
        self, save_charts_path, output_type_hint, viz_library_type_hint
    ):
        """Test that the __str__ method is implemented"""

        llm = FakeLLM("plt.show()")
        dfs = [
            SmartDataframe(
                pd.DataFrame({}),
                config={"llm": llm},
            )
        ]

        prompt = DirectSQLPrompt(tables=dfs)
        prompt.set_var("dfs", dfs)
        prompt.set_var("conversation", "What is the correct code?")
        prompt.set_var("output_type_hint", output_type_hint)
        prompt.set_var("save_charts_path", save_charts_path)
        prompt.set_var("viz_library_type", viz_library_type_hint)
        prompt_content = prompt.to_string()
        if sys.platform.startswith("win"):
            prompt_content = prompt_content.replace("\r\n", "\n")

        assert (
            prompt_content
== f'''You are provided with the following samples of sql tables data:

<Tables>
<table>
</table>
<Tables>

<conversation>
What is the correct code?
</conversation>

You are provided with following function that executes the sql query,
<Function>
def execute_sql_query(sql_query: str) -> pd.Dataframe
"""his method connect to the database executes the sql query and returns the dataframe"""
</Function>

This is the initial python function. Do not change the params.

```python
# TODO import all the dependencies required
import pandas as pd

def analyze_data() -> dict:
"""
Analyze the data, using the provided dataframes (`dfs`).
1. Prepare: generate sql query to get data for analysis (grouping, filtering, aggregating, etc.)
2. Process: execute the query using execute method available to you which returns dataframe
3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.)
{viz_library_type_hint}
At the end, return a dictionary of:
{output_type_hint}
"""
```

Take a deep breath and reason step-by-step. Act as a senior data analyst.
In the answer, you must never write the "technical" names of the tables.
Based on the last message in the conversation:

- return the updated analyze_data function wrapped within `python `''' # noqa: E501
)

The test case seems to be well written and covers a variety of scenarios by using parameterized inputs. However, the test case only checks the to_string() method of the DirectSQLPrompt class. It would be beneficial to add more test cases to cover other methods of the DirectSQLPrompt class to ensure full coverage.

Comment on lines 63 to +82
f"host={self._config.host} port={self._config.port} "
f"database={self._config.database} httpPath={str(self._config.httpPath)}"
)

    def equals(self, other):
        if isinstance(other, self.__class__):
            return (
                self._config.dialect,
                self._config.token,
                self._config.host,
                self._config.port,
                self._config.httpPath,
            ) == (
                other._config.dialect,
                other._config.token,
                other._config.host,
                other._config.port,
                other._config.httpPath,
            )
        return False

The equals method is a good addition for comparing two instances of the DatabricksConnector class. However, it's important to note that this method only checks for equality based on a subset of the instance's properties. If there are other properties that could affect the behavior of the instance, they should be included in this comparison. Also, consider renaming the method to __eq__ to follow Python's convention for equality comparison, which would allow you to use the == operator directly.

-    def equals(self, other):
+    def __eq__(self, other):

Suggested change

            f"host={self._config.host} port={self._config.port} "
            f"database={self._config.database} httpPath={str(self._config.httpPath)}"
        )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return (
                self._config.dialect,
                self._config.token,
                self._config.host,
                self._config.port,
                self._config.httpPath,
            ) == (
                other._config.dialect,
                other._config.token,
                other._config.host,
                other._config.port,
                other._config.httpPath,
            )
        return False
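One consequence of the rename is worth spelling out: in Python, a class that defines `__eq__` without also defining `__hash__` becomes unhashable, so instances can no longer be used in sets or as dictionary keys. The sketch below shows one way to keep both in step. `FakeConnector`, its `_key` helper, and the choice of compared fields are hypothetical and not taken from the PR.

```python
from types import SimpleNamespace


class FakeConnector:
    """Hypothetical connector comparing a fixed subset of config fields."""

    def __init__(self, dialect: str, host: str, port: int):
        self._config = SimpleNamespace(dialect=dialect, host=host, port=port)

    def _key(self) -> tuple:
        # Single source of truth for the fields that define equality.
        return (self._config.dialect, self._config.host, self._config.port)

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            # Let Python try the reflected comparison, then fall back to False.
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        # Defined explicitly because overriding __eq__ alone removes hashing.
        return hash(self._key())


a = FakeConnector("databricks", "example-host", 443)
b = FakeConnector("databricks", "example-host", 443)
c = FakeConnector("databricks", "other-host", 443)
print(a == b, a == c, len({a, b}))
```

Keeping a separate `equals` method, as the PR does, avoids this side effect at the cost of not supporting the `==` operator.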

Comment on lines +75 to +112
    @pytest.fixture
    @patch("pandasai.connectors.sql.create_engine", autospec=True)
    def sql_connector(self, create_engine):
        # Define your ConnectorConfig instance here
        self.config = SQLConnectorConfig(
            dialect="mysql",
            driver="pymysql",
            username="your_username",
            password="your_password",
            host="your_host",
            port=443,
            database="your_database",
            table="your_table",
            where=[["column_name", "=", "value"]],
        ).dict()

        # Create an instance of SQLConnector
        return SQLConnector(self.config)

    @pytest.fixture
    @patch("pandasai.connectors.sql.create_engine", autospec=True)
    def pgsql_connector(self, create_engine):
        # Define your ConnectorConfig instance here
        self.config = SQLConnectorConfig(
            dialect="mysql",
            driver="pymysql",
            username="your_username",
            password="your_password",
            host="your_host",
            port=443,
            database="your_database",
            table="your_table",
            where=[["column_name", "=", "value"]],
        ).dict()

        # Create an instance of SQLConnector
        return PostgreSQLConnector(self.config)


The new fixtures sql_connector and pgsql_connector are correctly defined. However, the credentials for the SQL databases are hardcoded. This could be a security issue if this code is pushed to a public repository. Consider using environment variables or a secure method to store these credentials.

-            username="your_username",
-            password="your_password",
-            host="your_host",
+            username=os.getenv("DB_USERNAME"),
+            password=os.getenv("DB_PASSWORD"),
+            host=os.getenv("DB_HOST"),

Suggested change

    @pytest.fixture
    @patch("pandasai.connectors.sql.create_engine", autospec=True)
    def sql_connector(self, create_engine):
        # Define your ConnectorConfig instance here
        self.config = SQLConnectorConfig(
            dialect="mysql",
            driver="pymysql",
            username=os.getenv("DB_USERNAME"),
            password=os.getenv("DB_PASSWORD"),
            host=os.getenv("DB_HOST"),
            port=443,
            database="your_database",
            table="your_table",
            where=[["column_name", "=", "value"]],
        ).dict()
        # Create an instance of SQLConnector
        return SQLConnector(self.config)

    @pytest.fixture
    @patch("pandasai.connectors.sql.create_engine", autospec=True)
    def pgsql_connector(self, create_engine):
        # Define your ConnectorConfig instance here
        self.config = SQLConnectorConfig(
            dialect="mysql",
            driver="pymysql",
            username=os.getenv("DB_USERNAME"),
            password=os.getenv("DB_PASSWORD"),
            host=os.getenv("DB_HOST"),
            port=443,
            database="your_database",
            table="your_table",
            where=[["column_name", "=", "value"]],
        ).dict()
        # Create an instance of SQLConnector
        return PostgreSQLConnector(self.config)
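A caveat on the suggestion: `os.getenv` returns `None` when a variable is unset, which could make a config that requires string values fail in environments where `DB_USERNAME`, `DB_PASSWORD`, and `DB_HOST` are not exported. A small helper with placeholder defaults is one way around that. The helper name `db_credentials` and the default values are hypothetical.

```python
import os


def db_credentials(env=None) -> dict:
    """Read test credentials from the environment, with placeholder defaults."""
    env = os.environ if env is None else env
    return {
        # Defaults are dummy values, suitable only for mocked connections.
        "username": env.get("DB_USERNAME", "test_user"),
        "password": env.get("DB_PASSWORD", "test_password"),
        "host": env.get("DB_HOST", "localhost"),
    }


print(db_credentials({"DB_HOST": "db.internal"}))
```

In fixtures like the ones above, where `create_engine` is patched and the values are placeholders, the result could be unpacked into the config constructor with `**db_credentials()`.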

Comment on lines 273 to +331
        smart_datalake.chat("How many countries are in the dataframe?")
        assert smart_datalake.last_answer == "Custom answer"
        assert smart_datalake.last_reasoning == "Custom reasoning"

    def test_get_chat_prompt(self, smart_datalake: SmartDatalake):
        # Test case 1: direct_sql is True
        smart_datalake._config.direct_sql = True
        gen_key, gen_prompt = smart_datalake._get_chat_prompt()
        expected_key = "direct_sql_prompt"
        assert gen_key == expected_key
        assert isinstance(gen_prompt, DirectSQLPrompt)

        # Test case 2: direct_sql is False
        smart_datalake._config.direct_sql = False
        gen_key, gen_prompt = smart_datalake._get_chat_prompt()
        expected_key = "generate_python_code"
        assert gen_key == expected_key
        assert isinstance(gen_prompt, GeneratePythonCodePrompt)

    def test_validate_true_direct_sql_with_non_connector(self, llm, sample_df):
        # raise exception with non connector
        SmartDatalake(
            [sample_df],
            config={"llm": llm, "enable_cache": False, "direct_sql": True},
        )

    def test_validate_direct_sql_with_connector(self, llm, sql_connector):
        # not exception is raised using single connector
        SmartDatalake(
            [sql_connector],
            config={"llm": llm, "enable_cache": False, "direct_sql": True},
        )

    def test_validate_false_direct_sql_with_connector(self, llm, sql_connector):
        # not exception is raised using single connector
        SmartDatalake(
            [sql_connector],
            config={"llm": llm, "enable_cache": False, "direct_sql": False},
        )

    def test_validate_false_direct_sql_with_two_different_connector(
        self, llm, sql_connector, pgsql_connector
    ):
        # not exception is raised using single connector
        SmartDatalake(
            [sql_connector, pgsql_connector],
            config={"llm": llm, "enable_cache": False, "direct_sql": False},
        )

    def test_validate_true_direct_sql_with_two_different_connector(
        self, llm, sql_connector, pgsql_connector
    ):
        # not exception is raised using single connector
        # raise exception when two different connector
        with pytest.raises(InvalidConfigError):
            SmartDatalake(
                [sql_connector, pgsql_connector],
                config={"llm": llm, "enable_cache": False, "direct_sql": True},
            )
Contributor

The new test methods are correctly defined and cover the new functionality added to the SmartDatalake class. The tests cover both positive and negative scenarios and check for correct exceptions. However, the comments for the test cases are not very descriptive. Consider adding more detailed comments to explain what each test case is checking.

Contributor

@coderabbitai coderabbitai bot left a comment

Review Status

Actionable comments generated: 2

Configuration used: CodeRabbit UI

Commits: files that changed from the base of the PR and between b36f5fb and 36da80c.
Files selected for processing (2)
  • pandasai/connectors/sql.py (2 hunks)
  • pandasai/smart_datalake/__init__.py (11 hunks)
Additional comments: 11
pandasai/connectors/sql.py (1)
  • 5-9: The import of MaliciousQueryError is necessary for the new feature of executing direct SQL queries. This exception will be raised if a potentially harmful SQL query is detected.
pandasai/smart_datalake/__init__.py (10)
  • 23-31: The import statement for DirectSQLPrompt has been added. Ensure that the DirectSQLPrompt class is correctly implemented and that it is in the correct location for this import statement to work.

  • 40-49: The CodeExecutionContext import has been added. Ensure that the CodeExecutionContext class is correctly implemented and that it is in the correct location for this import statement to work.

  • 62-68: The _can_direct_sql attribute has been added to the SmartDatalake class. This attribute is used to determine whether direct SQL queries are allowed. Ensure that this attribute is correctly used throughout the class.

  • 132-140: The _validate_direct_sql method has been added to validate the direct SQL configuration. This method checks if all dataframes belong to the same SQL connector type and raises an InvalidConfigError if they do not. Ensure that this method is correctly used throughout the class.

  • 249-282: The _get_chat_prompt method has been added to determine the chat prompt based on the direct SQL configuration. This method returns a tuple containing the key for the prompt and the prompt itself. Ensure that this method is correctly used throughout the class.

  • 306-311: The _get_prompt method has been modified to include the _can_direct_sql flag. This flag is passed to the CodeExecutionContext constructor. Ensure that this flag is correctly used in the CodeExecutionContext class.

  • 325-331: The _get_prompt method has been modified to include the _can_direct_sql flag. This flag is passed to the CodeExecutionContext constructor. Ensure that this flag is correctly used in the CodeExecutionContext class.

  • 432-443: The chat method has been modified to use the _get_chat_prompt method and pass the tables to the prompt. Ensure that this method is correctly used throughout the class.

  • 470-478: The chat method has been modified to pass the _can_direct_sql flag to the CodeExecutionContext constructor. Ensure that this flag is correctly used in the CodeExecutionContext class.

  • 593-599: The _retry_run_code method has been modified to include the _can_direct_sql flag. This flag is passed to the CodeExecutionContext constructor. Ensure that this flag is correctly used in the CodeExecutionContext class.
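
To make the validation and prompt-selection behavior described in the comments above easier to follow, here is a minimal, self-contained sketch. It is not the code from this PR: `FakeConnector`, `can_use_direct_sql`, and `chat_prompt_key` are hypothetical names, the local `InvalidConfigError` is a stand-in for the library's exception, and the handling of non-connector sources (quietly disabling direct SQL) is an assumption. Only the two prompt keys are taken from the tests in this PR.

```python
from dataclasses import dataclass


class InvalidConfigError(Exception):
    """Stand-in for the library's exception of the same name."""


@dataclass(frozen=True)
class FakeConnector:
    # Hypothetical connector; equality compares dialect and database only.
    dialect: str
    database: str


def can_use_direct_sql(sources, direct_sql: bool) -> bool:
    """Decide whether direct SQL may be used for the given sources."""
    if not direct_sql or not sources:
        return False
    if not all(isinstance(s, FakeConnector) for s in sources):
        # Assumption: a non-connector source simply disables direct SQL.
        return False
    if any(s != sources[0] for s in sources):
        # Mixed datasources cannot be served by a single SQL query.
        raise InvalidConfigError(
            "direct_sql requires all connectors to share one datasource"
        )
    return True


def chat_prompt_key(direct_sql_enabled: bool) -> str:
    """Pick the prompt key; both keys appear in the tests of this PR."""
    return "direct_sql_prompt" if direct_sql_enabled else "generate_python_code"
```

With two equal connectors and the flag set, `can_use_direct_sql` returns `True`; with two different connectors and the flag set, it raises `InvalidConfigError`, which mirrors the last test case in the diff.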

Comment on lines +384 to +397
    def _is_sql_query_safe(self, query: str):
        infected_keywords = [
            r"\bINSERT\b",
            r"\bUPDATE\b",
            r"\bDELETE\b",
            r"\bDROP\b",
            r"\bEXEC\b",
            r"\bALTER\b",
            r"\bCREATE\b",
        ]

        return not any(
            re.search(keyword, query, re.IGNORECASE) for keyword in infected_keywords
        )
Contributor

@coderabbitai coderabbitai bot Nov 6, 2023

The _is_sql_query_safe method checks for potentially harmful SQL keywords in a query. This is a good security practice to prevent SQL injection attacks. However, it's important to note that this method only checks for a limited set of keywords and may not catch all possible malicious queries. Consider using a more comprehensive SQL injection prevention technique or library.
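
To illustrate the limitation, the following standalone sketch copies the keyword list from the diff into a free function (the function name and the sample queries are illustrative, not part of the PR) and runs it against a few inputs. It shows both kinds of failure: statements that are not on the list pass, while a harmless query that merely contains a listed word inside a string literal is rejected.

```python
import re

# Same patterns as in the diff above, copied into a module-level list.
BLOCKED_PATTERNS = [
    r"\bINSERT\b",
    r"\bUPDATE\b",
    r"\bDELETE\b",
    r"\bDROP\b",
    r"\bEXEC\b",
    r"\bALTER\b",
    r"\bCREATE\b",
]


def is_sql_query_safe(query: str) -> bool:
    """Return False if any blocked keyword appears as a whole word."""
    return not any(
        re.search(pattern, query, re.IGNORECASE) for pattern in BLOCKED_PATTERNS
    )


# A plain read query is accepted, a DROP is rejected.
print(is_sql_query_safe("SELECT COUNT(*) FROM countries"))  # True
print(is_sql_query_safe("DROP TABLE countries"))  # False

# Destructive statements that are not on the list are accepted.
print(is_sql_query_safe("TRUNCATE TABLE countries"))  # True
print(is_sql_query_safe("GRANT ALL ON countries TO PUBLIC"))  # True

# A harmless query is rejected because a string literal contains "update".
print(is_sql_query_safe("SELECT * FROM notes WHERE body = 'please update me'"))  # False
```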

Collaborator Author

@coderabbitai can you suggest more?

Contributor

@ArslanSaleem Sure, here are a few more suggestions to improve SQL query safety:

  1. Parameterized Queries: This is one of the most effective ways to prevent SQL injection. Instead of constructing SQL queries with string concatenation, use parameterized queries. This ensures that user-supplied input is always treated as literal values and not part of the SQL command.

  2. ORM (Object-Relational Mapping) Libraries: Using an ORM library can abstract the SQL layer, reducing the risk of SQL injection attacks. These libraries generate SQL queries on your behalf and are generally designed to avoid SQL injection.

  3. Least Privilege Principle: Limit the permissions of the database account used by the application. Don't use a database root account whose compromise could lead to full database access.

  4. Regular Expressions: While the current method uses regular expressions to check for malicious keywords, it could be expanded to include more patterns. However, this method should not be relied upon solely as it can be bypassed with various techniques.

  5. Web Application Firewalls (WAFs): These can help to filter out malicious data and provide an additional layer of security.

  6. Input Validation: Validate input data thoroughly. Use both whitelist (only allow known good patterns) and blacklist (block known bad patterns) validation.

Remember, no single technique is foolproof. A combination of these methods will provide a more robust defense against SQL injection attacks.
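
As a rough illustration of suggestions 1 and 3, the sketch below uses Python's standard `sqlite3` module rather than the connectors in this PR. The table, the helper names, and the sample data are all made up for the example. It binds user input as a parameter instead of concatenating it into the SQL text, and it opens the database in read-only mode so that a write statement fails at the database level even if it slips past any keyword check.

```python
import sqlite3
import tempfile
from pathlib import Path


def build_demo_db(path: Path) -> None:
    """Create a small throwaway database (illustrative data only)."""
    conn = sqlite3.connect(path)
    conn.execute("CREATE TABLE countries (name TEXT, continent TEXT)")
    conn.executemany(
        "INSERT INTO countries VALUES (?, ?)",
        [("Italy", "Europe"), ("Japan", "Asia"), ("Spain", "Europe")],
    )
    conn.commit()
    conn.close()


def open_read_only(path: Path) -> sqlite3.Connection:
    """Open the database read-only via an SQLite URI (least privilege)."""
    return sqlite3.connect(path.resolve().as_uri() + "?mode=ro", uri=True)


def count_by_continent(conn: sqlite3.Connection, continent: str) -> int:
    """Parameterized query: the value is bound, never spliced into the SQL."""
    row = conn.execute(
        "SELECT COUNT(*) FROM countries WHERE continent = ?", (continent,)
    ).fetchone()
    return row[0]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        db_path = Path(tmp) / "demo.db"
        build_demo_db(db_path)
        conn = open_read_only(db_path)
        print(count_by_continent(conn, "Europe"))
        # The injection attempt is treated as a literal value and matches nothing.
        print(count_by_continent(conn, "Europe' OR '1'='1"))
        try:
            conn.execute("DELETE FROM countries")
        except sqlite3.OperationalError as exc:
            print("write rejected:", exc)
        conn.close()
```

Note that parameter binding protects values supplied by users; it does not help when the whole statement is generated, as in the direct SQL feature, which is why the read-only connection is the more relevant half of this sketch.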

Comment on lines +399 to +403
    def execute_direct_sql_query(self, sql_query):
        if not self._is_sql_query_safe(sql_query):
            raise MaliciousQueryError("Malicious query is generated in code")

        return pd.read_sql(sql_query, self._connection)
Contributor

The execute_direct_sql_query method executes a SQL query directly. It first checks if the query is safe using the _is_sql_query_safe method and raises a MaliciousQueryError if the query is determined to be malicious. This is a good practice for security. However, as mentioned above, the _is_sql_query_safe method may not catch all possible malicious queries.
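
One possible complement to the keyword blocklist is an allowlist that accepts only a single `SELECT` statement. The sketch below is an assumption about how such a check could look, not something proposed in this PR; `is_single_select` is a hypothetical helper. It is deliberately conservative (any query with an inner semicolon is rejected, even inside a string literal), and it is still not sufficient on its own, since some dialects allow `SELECT ... INTO` to create tables, so it is best paired with a read-only database account.

```python
import re


def is_single_select(query: str) -> bool:
    """Accept only one statement that starts with the SELECT keyword."""
    # Drop surrounding whitespace and any trailing semicolons.
    stripped = query.strip().rstrip(";").strip()
    if not stripped:
        return False
    if ";" in stripped:
        # A remaining semicolon suggests stacked statements; reject outright.
        return False
    return re.match(r"SELECT\b", stripped, re.IGNORECASE) is not None
```

For example, `is_single_select("SELECT COUNT(*) AS total FROM countries;")` is accepted, while `"SELECT 1; DROP TABLE countries"` and `"TRUNCATE TABLE countries"` are both rejected.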

Collaborator

@gventuri gventuri left a comment

Hey @ArslanSaleem, tried it out, but it seems to use the wrong table name.

Example: "SELECT COUNT(*) AS total FROM table2"

The table is not called table2, and I suppose the problem is we don't pass the table name at all, can you confirm?

@gventuri gventuri changed the base branch from main to release/v1.5 November 6, 2023 22:25

codecov-commenter commented Nov 7, 2023

Codecov Report

❗ No coverage uploaded for pull request base (release/v1.5@3fa5625).
The diff coverage is n/a.

@@               Coverage Diff               @@
##             release/v1.5     #731   +/-   ##
===============================================
  Coverage                ?   85.19%           
===============================================
  Files                   ?       86           
  Lines                   ?     3817           
  Branches                ?        0           
===============================================
  Hits                    ?     3252           
  Misses                  ?      565           
  Partials                ?        0           

@gventuri gventuri merged commit e3c6b79 into release/v1.5 Nov 7, 2023
9 checks passed
4 participants