feat(directSqlPrompt): use connector directly if flag is set #731

Merged
merged 16 commits into from
Nov 7, 2023
59 changes: 59 additions & 0 deletions examples/sql_direct_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Example of using PandasAI with direct SQL on a PostgreSQL database."""

from pandasai import SmartDatalake
from pandasai.llm import OpenAI
from pandasai.connectors import PostgreSQLConnector
from pandasai.smart_dataframe import SmartDataframe


# With a PostgreSQL database
order = PostgreSQLConnector(
    config={
        "host": "localhost",
        "port": 5432,
        "database": "testdb",
        "username": "postgres",
        "password": "123456",
        "table": "orders",
    }
)

order_details = PostgreSQLConnector(
    config={
        "host": "localhost",
        "port": 5432,
        "database": "testdb",
        "username": "postgres",
        "password": "123456",
        "table": "order_details",
    }
)

products = PostgreSQLConnector(
    config={
        "host": "localhost",
        "port": 5432,
        "database": "testdb",
        "username": "postgres",
        "password": "123456",
        "table": "products",
    }
)


llm = OpenAI("OPEN_API_KEY")


order_details_smart_df = SmartDataframe(
    order_details,
    config={"llm": llm, "direct_sql": True},
    description="Contain user order details",
)


df = SmartDatalake(
    [order_details_smart_df, order, products],
    config={"llm": llm, "direct_sql": True},
)
response = df.chat("return orders with count of distinct products")
print(response)
3 changes: 2 additions & 1 deletion mkdocs.yml
Expand Up @@ -42,7 +42,8 @@ nav:
- Documents Building: building_docs.md
- License: license.md
extra:
-  version: "1.4.3"
+  version: "1.4.4"

plugins:
- search
- mkdocstrings:
Expand Down
5 changes: 5 additions & 0 deletions pandasai/assets/prompt_templates/default_instructions.tmpl
@@ -0,0 +1,5 @@
Analyze the data, using the provided dataframes (`dfs`).
1. Prepare: Preprocessing and cleaning data if necessary
2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.)
3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.)
{viz_library_type}
39 changes: 39 additions & 0 deletions pandasai/assets/prompt_templates/direct_sql_connector.tmpl
@@ -0,0 +1,39 @@
You are provided with the following samples of sql tables data:

<Tables>
{tables}
</Tables>

<conversation>
{conversation}
</conversation>

You are provided with the following function that executes the sql query:
<Function>
def execute_sql_query(sql_query: str) -> pd.DataFrame
"""This method connects to the database, executes the sql query and returns the dataframe"""
</Function>

This is the initial python function. Do not change the params.

```python
# TODO import all the dependencies required
import pandas as pd

def analyze_data() -> dict:
    """
    Analyze the data, using the provided dataframes (`dfs`).
    1. Prepare: generate sql query to get data for analysis (grouping, filtering, aggregating, etc.)
    2. Process: execute the query using execute method available to you which returns dataframe
    3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.)
    {viz_library_type}
    At the end, return a dictionary of:
    {output_type_hint}
    """
```

Take a deep breath and reason step-by-step. Act as a senior data analyst.
In the answer, you must never write the "technical" names of the tables.
Based on the last message in the conversation:

- return the updated analyze_data function wrapped within ```python ```
2 changes: 0 additions & 2 deletions pandasai/assets/prompt_templates/generate_python_code.tmpl
Expand Up @@ -6,8 +6,6 @@ You are provided with the following pandas DataFrames:
{conversation}
</conversation>

-{viz_library_type}
-
This is the initial python function. Do not change the params. Given the context, use the right dataframes.
```python
{current_code}
Expand Down
1 change: 1 addition & 0 deletions pandasai/assets/prompt_templates/viz_library.tmpl
@@ -0,0 +1 @@
If the user requests to create a chart, utilize the Python {library} library to generate high-quality graphics that will be saved directly to a file.
17 changes: 17 additions & 0 deletions pandasai/connectors/databricks.py
Expand Up @@ -63,3 +63,20 @@ def __repr__(self):
f"host={self._config.host} port={self._config.port} "
f"database={self._config.database} httpPath={str(self._config.httpPath)}"
)

    def equals(self, other):
        if isinstance(other, self.__class__):
            return (
                self._config.dialect,
                self._config.token,
                self._config.host,
                self._config.port,
                self._config.httpPath,
            ) == (
                other._config.dialect,
                other._config.token,
                other._config.host,
                other._config.port,
                other._config.httpPath,
            )
        return False
Comment on lines 63 to +82
Contributor

The equals method is a good addition for comparing two instances of the DatabricksConnector class. However, it's important to note that this method only checks for equality based on a subset of the instance's properties. If there are other properties that could affect the behavior of the instance, they should be included in this comparison. Also, consider renaming the method to __eq__ to follow Python's convention for equality comparison, which would allow you to use the == operator directly.

-    def equals(self, other):
+    def __eq__(self, other):

Committable suggestion

[!IMPORTANT]
Carefully review the code before committing. Make sure it correctly replaces the highlighted code and has no missing lines or indentation issues.

Suggested change
             f"host={self._config.host} port={self._config.port} "
             f"database={self._config.database} httpPath={str(self._config.httpPath)}"
         )
-    def equals(self, other):
+    def __eq__(self, other):
         if isinstance(other, self.__class__):
             return (
                 self._config.dialect,
                 self._config.token,
                 self._config.host,
                 self._config.port,
                 self._config.httpPath,
             ) == (
                 other._config.dialect,
                 other._config.token,
                 other._config.host,
                 other._config.port,
                 other._config.httpPath,
             )
         return False
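
To make the trade-off behind the `__eq__` suggestion concrete, here is a minimal sketch. `DemoConfig` and `DemoConnector` are hypothetical stand-ins, not PandasAI classes, and the compared fields are chosen only for illustration. It shows that `__eq__` enables the `==` operator, and that a matching `__hash__` is needed if instances should stay usable in sets or as dict keys, because defining `__eq__` alone makes a class unhashable.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DemoConfig:
    # Hypothetical stand-in for a connector config.
    dialect: str
    host: str
    port: int


class DemoConnector:
    """Hypothetical connector used only to illustrate __eq__ / __hash__."""

    def __init__(self, config: DemoConfig):
        self._config = config

    def _key(self):
        # The tuple of fields that define "same connection" in this sketch.
        return (self._config.dialect, self._config.host, self._config.port)

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self._key() == other._key()
        # Let Python try the reflected comparison before falling back.
        return NotImplemented

    def __hash__(self):
        # Required because defining __eq__ alone sets __hash__ to None.
        return hash(self._key())


a = DemoConnector(DemoConfig("databricks", "localhost", 443))
b = DemoConnector(DemoConfig("databricks", "localhost", 443))
c = DemoConnector(DemoConfig("databricks", "example.org", 443))

same = a == b
different = a != c
unique_count = len({a, b, c})
```

Whether a rename is appropriate for the real connectors depends on how `equals` is called elsewhere in the code base, which this page does not show.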

15 changes: 15 additions & 0 deletions pandasai/connectors/snowflake.py
Expand Up @@ -90,3 +90,18 @@ def __repr__(self):
f"database={self._config.database} schema={str(self._config.dbSchema)} "
f"table={self._config.table}>"
)

    def equals(self, other):
        if isinstance(other, self.__class__):
            return (
                self._config.dialect,
                self._config.account,
                self._config.username,
                self._config.password,
            ) == (
                other._config.dialect,
                other._config.account,
                other._config.username,
                other._config.password,
            )
        return False
Comment on lines +94 to +107
Contributor

The equals method is comparing sensitive information like passwords. This could potentially lead to security issues if the method is used inappropriately. Consider comparing a hash of the password instead of the password itself. Also, consider adding a docstring to this method to explain its purpose and usage.

    def equals(self, other):
        if isinstance(other, self.__class__):
            return (
                self._config.dialect,
                self._config.account,
                self._config.username,
-               self._config.password,
+               hash(self._config.password),
            ) == (
                other._config.dialect,
                other._config.account,
                other._config.username,
-               other._config.password,
+               hash(other._config.password),
            )
        return False

Committable suggestion

[!IMPORTANT]
Carefully review the code before committing. Make sure it correctly replaces the highlighted code and has no missing lines or indentation issues.

Suggested change
     def equals(self, other):
+        """
+        Compare the current object with another object for equality.
+        Args:
+            other: The object to compare with.
+        Returns:
+            True if the objects are equal, False otherwise.
+        """
         if isinstance(other, self.__class__):
             return (
                 self._config.dialect,
                 self._config.account,
                 self._config.username,
-                self._config.password,
+                hash(self._config.password),
             ) == (
                 other._config.dialect,
                 other._config.account,
                 other._config.username,
-                other._config.password,
+                hash(other._config.password),
             )
         return False
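
For comparison with the `hash(...)` suggestion, the sketch below uses `hmac.compare_digest` from the standard library, which is designed to compare secrets without leaking timing information. This is a swapped-in technique, not what the PR or the reviewer proposed; the `same_credentials` helper and the dict layout are hypothetical. Keep in mind that Python's built-in `hash()` on strings is salted per process by default and is not intended as a security mechanism.

```python
import hmac


def same_credentials(a: dict, b: dict) -> bool:
    """Hypothetical helper: compare two credential dicts field by field."""
    non_secret_equal = (a["account"], a["username"]) == (b["account"], b["username"])
    # compare_digest takes two str (ASCII only) or two bytes-like objects;
    # encoding to bytes avoids the ASCII restriction.
    secret_equal = hmac.compare_digest(
        a["password"].encode("utf-8"), b["password"].encode("utf-8")
    )
    return non_secret_equal and secret_equal


left = {"account": "acme", "username": "reader", "password": "s3cret"}
right = {"account": "acme", "username": "reader", "password": "s3cret"}
other = {"account": "acme", "username": "reader", "password": "different"}

match = same_credentials(left, right)
mismatch = same_credentials(left, other)
```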

42 changes: 42 additions & 0 deletions pandasai/connectors/sql.py
Expand Up @@ -5,6 +5,8 @@
import re
import os
import pandas as pd

from pandasai.exceptions import MaliciousQueryError
from .base import BaseConnector, SQLConnectorConfig, SqliteConnectorConfig
from .base import BaseConnectorConfig
from sqlalchemy import create_engine, text, select, asc
Expand Down Expand Up @@ -360,6 +362,46 @@ def column_hash(self):
    def fallback_name(self):
        return self._config.table

    def equals(self, other):
        if isinstance(other, self.__class__):
            return (
                self._config.dialect,
                self._config.driver,
                self._config.host,
                self._config.port,
                self._config.username,
                self._config.password,
            ) == (
                other._config.dialect,
                other._config.driver,
                other._config.host,
                other._config.port,
                other._config.username,
                other._config.password,
            )
        return False

    def _is_sql_query_safe(self, query: str):
        infected_keywords = [
            r"\bINSERT\b",
            r"\bUPDATE\b",
            r"\bDELETE\b",
            r"\bDROP\b",
            r"\bEXEC\b",
            r"\bALTER\b",
            r"\bCREATE\b",
        ]

        return not any(
            re.search(keyword, query, re.IGNORECASE) for keyword in infected_keywords
        )
Comment on lines +384 to +397
Contributor

@coderabbitai coderabbitai bot Nov 6, 2023

The _is_sql_query_safe method checks for potentially harmful SQL keywords in a query. This is a good security practice to prevent SQL injection attacks. However, it's important to note that this method only checks for a limited set of keywords and may not catch all possible malicious queries. Consider using a more comprehensive SQL injection prevention technique or library.

Collaborator Author

@coderabbitai can you suggest more?

Contributor

@ArslanSaleem Sure, here are a few more suggestions to improve SQL query safety:

  1. Parameterized Queries: This is one of the most effective ways to prevent SQL injection. Instead of constructing SQL queries with string concatenation, use parameterized queries. This ensures that user-supplied input is always treated as literal values and not part of the SQL command.

  2. ORM (Object-Relational Mapping) Libraries: Using an ORM library can abstract the SQL layer, reducing the risk of SQL injection attacks. These libraries generate SQL queries on your behalf and are generally designed to avoid SQL injection.

  3. Least Privilege Principle: Limit the permissions of the database account used by the application. Don't use a database root account whose compromise could lead to full database access.

  4. Regular Expressions: While the current method uses regular expressions to check for malicious keywords, it could be expanded to include more patterns. However, this method should not be relied upon solely as it can be bypassed with various techniques.

  5. Web Application Firewalls (WAFs): These can help to filter out malicious data and provide an additional layer of security.

  6. Input Validation: Validate input data thoroughly. Use both whitelist (only allow known good patterns) and blacklist (block known bad patterns) validation.

Remember, no single technique is foolproof. A combination of these methods will provide a more robust defense against SQL injection attacks.


    def execute_direct_sql_query(self, sql_query):
        if not self._is_sql_query_safe(sql_query):
            raise MaliciousQueryError("Malicious query is generated in code")

        return pd.read_sql(sql_query, self._connection)
Comment on lines +399 to +403
Contributor

The execute_direct_sql_query method executes a SQL query directly. It first checks if the query is safe using the _is_sql_query_safe method and raises a MaliciousQueryError if the query is determined to be malicious. This is a good practice for security. However, as mentioned above, the _is_sql_query_safe method may not catch all possible malicious queries.
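
The limits of the keyword check can be seen by running it in isolation. The sketch below is a standalone copy of the regular-expression list from the diff (the module-level names `BLOCKED_KEYWORDS` and `is_sql_query_safe` are chosen here for illustration) applied to a few sample queries that are made up for this example.

```python
import re

# Same keyword patterns as in the diff above.
BLOCKED_KEYWORDS = [
    r"\bINSERT\b",
    r"\bUPDATE\b",
    r"\bDELETE\b",
    r"\bDROP\b",
    r"\bEXEC\b",
    r"\bALTER\b",
    r"\bCREATE\b",
]


def is_sql_query_safe(query: str) -> bool:
    """Standalone copy of the check, for experimentation only."""
    return not any(re.search(k, query, re.IGNORECASE) for k in BLOCKED_KEYWORDS)


# A plain read passes, and a destructive statement is rejected.
plain_select = is_sql_query_safe("SELECT id, total FROM orders")
drop_table = is_sql_query_safe("DROP TABLE orders")

# TRUNCATE is not on the list, so it passes even though it removes every row.
truncate = is_sql_query_safe("TRUNCATE TABLE orders")

# A harmless read is rejected because a string literal contains a keyword.
literal = is_sql_query_safe("SELECT * FROM notes WHERE body = 'please update me'")

# Underscores are word characters, so \b does not match inside this name.
column = is_sql_query_safe("SELECT last_update FROM orders")
```

The check therefore produces both false negatives (`TRUNCATE`) and false positives (keywords inside literals), which is why it works best as one layer among several rather than as the only safeguard.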



class SqliteConnector(SQLConnector):
"""
Expand Down
17 changes: 17 additions & 0 deletions pandasai/exceptions.py
Expand Up @@ -155,10 +155,27 @@ class UnSupportedLogicUnit(Exception):
        Exception (Exception): UnSupportedLogicUnit
    """


class InvalidWorkspacePathError(Exception):
    """
    Raised when the environment variable of workspace exists but path is invalid

    Args:
        Exception (Exception): InvalidWorkspacePathError
    """


class InvalidConfigError(Exception):
    """
    Raised when config value is not applicable
    Args:
        Exception (Exception): InvalidConfigError
    """


class MaliciousQueryError(Exception):
    """
    Raise error if malicious query is generated
    Args:
        Exception (Exception): MaliciousQueryError
    """
Comment on lines 163 to +181
Contributor

The new exceptions InvalidWorkspacePathError, InvalidConfigError, and MaliciousQueryError are well defined and follow the existing pattern of exception definitions in the file. However, the docstrings could be more descriptive to provide more context about when these exceptions are raised.

Committable suggestion

[!IMPORTANT]
Carefully review the code before committing. Make sure it correctly replaces the highlighted code and has no missing lines or indentation issues.

Suggested change
    Args:
        Exception (Exception): InvalidWorkspacePathError
    """


class InvalidConfigError(Exception):
    """
    Raised when the config value is not applicable.

    Args:
        Exception (Exception): InvalidConfigError
    """


class MaliciousQueryError(Exception):
    """
    Raised when a malicious query is generated.

    Args:
        Exception (Exception): MaliciousQueryError
    """

17 changes: 16 additions & 1 deletion pandasai/helpers/code_manager.py
Expand Up @@ -28,9 +28,15 @@

 class CodeExecutionContext:
     _prompt_id: uuid.UUID = None
+    _can_direct_sql: bool = False
     _skills_manager: SkillsManager = None

-    def __init__(self, prompt_id: uuid.UUID, skills_manager: SkillsManager):
+    def __init__(
+        self,
+        prompt_id: uuid.UUID,
+        skills_manager: SkillsManager,
+        _can_direct_sql: bool = False,
+    ):
         """
         Additional Context for code execution
         Args:
Expand All @@ -39,6 +45,7 @@ def __init__(self, prompt_id: uuid.UUID, skills_manager: SkillsManager):
"""
         self._skills_manager = skills_manager
         self._prompt_id = prompt_id
+        self._can_direct_sql = _can_direct_sql

     @property
     def prompt_id(self):
Expand All @@ -48,6 +55,10 @@ def prompt_id(self):
     def skills_manager(self):
         return self._skills_manager

+    @property
+    def can_direct_sql(self):
+        return self._can_direct_sql
+

 class CodeManager:
     _dfs: List
Expand Down Expand Up @@ -283,6 +294,10 @@ def execute_code(self, code: str, context: CodeExecutionContext) -> Any:

         analyze_data = environment.get("analyze_data")

+        if context.can_direct_sql:
+            environment["execute_sql_query"] = self._dfs[0].get_query_exec_func()
+            return analyze_data()
+
         return analyze_data(self._get_originals(dfs))

     def _get_samples(self, dfs):
Expand Down
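
The direct SQL branch above relies on a property of `exec`: a function defined by executed code looks up global names in the environment dict at call time, so `execute_sql_query` can be injected after the code has been executed. The following minimal sketch illustrates that mechanism only. The generated code string and `fake_execute_sql_query` are hypothetical, and the fake executor returns a list of dicts where the real one would return a DataFrame.

```python
# Stand-in for LLM-generated code; it calls a function it never defines.
generated_code = '''
def analyze_data() -> dict:
    rows = execute_sql_query("SELECT 1 AS answer")
    return {"type": "rows", "value": rows}
'''


def fake_execute_sql_query(sql_query: str) -> list:
    """Hypothetical executor; a real one would run the query on a connector."""
    return [{"answer": 1, "query": sql_query}]


environment: dict = {}
exec(generated_code, environment)  # defines analyze_data inside `environment`
analyze_data = environment.get("analyze_data")

# Injected after exec: globals are resolved when the function is called.
environment["execute_sql_query"] = fake_execute_sql_query
result = analyze_data()
```

Because the generated function receives no arguments in this mode, everything it needs has to arrive through the environment, which is why the query executor is placed there before the call.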
5 changes: 2 additions & 3 deletions pandasai/helpers/viz_library_types/_viz_library_types.py
@@ -1,13 +1,12 @@
 from abc import abstractmethod, ABC
 from typing import Any, Iterable
+from pandasai.prompts.generate_python_code import VizLibraryPrompt


 class BaseVizLibraryType(ABC):
     @property
     def template_hint(self) -> str:
-        return f"""When a user requests to create a chart, utilize the Python
-{self.name} library to generate high-quality graphics that will be saved
-directly to a file."""
+        return VizLibraryPrompt(library=self.name)

     @property
     @abstractmethod
Expand Down
10 changes: 5 additions & 5 deletions pandasai/prompts/base.py
Expand Up @@ -2,6 +2,7 @@
In order to better handle the instructions, this prompt module is written.
"""
from abc import ABC, abstractmethod
import string


class AbstractPrompt(ABC):
Expand Down Expand Up @@ -92,12 +93,11 @@ def to_string(self):
         prompt_args = {}
         for key, value in self._args.items():
             if isinstance(value, AbstractPrompt):
+                args = [
+                    arg[1] for arg in string.Formatter().parse(value.template) if arg[1]
+                ]
                 value.set_vars(
-                    {
-                        k: v
-                        for k, v in self._args.items()
-                        if k != key and not isinstance(v, AbstractPrompt)
-                    }
+                    {k: v for k, v in self._args.items() if k != key and k in args}
                 )
                 prompt_args[key] = value.to_string()
             else:
Expand Down
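
The change above filters the variables passed to a nested prompt down to the fields its template actually references. As a small illustration of how `string.Formatter().parse` yields those field names, consider the sketch below; the template text and the argument values are made up for this example.

```python
import string

# Hypothetical nested template with three placeholders.
template = "Use the {library} library. {viz_library_type} Return {output_type_hint}."

# parse() yields (literal_text, field_name, format_spec, conversion) tuples;
# field_name is None for trailing literal text, hence the truthiness filter.
fields = [part[1] for part in string.Formatter().parse(template) if part[1]]

all_args = {
    "library": "matplotlib",
    "conversation": "...",
    "output_type_hint": "a dict",
}

# Keep only the variables the template actually references.
passed = {k: v for k, v in all_args.items() if k in fields}
```

Here `conversation` is dropped because the template never mentions it, which mirrors the `k in args` condition in the new code.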
40 changes: 40 additions & 0 deletions pandasai/prompts/direct_sql_prompt.py
@@ -0,0 +1,40 @@
""" Prompt to explain code generation by the LLM
The previous conversation we had

<Conversation>
{conversation}
</Conversation>

Based on the last conversation you generated the following code:

<Code>
{code}
</Code>

Explain how you came up with code for non-technical people without
mentioning technical details or mentioning the libraries used?

"""
Comment on lines +1 to +17
Contributor

This docstring seems to be misplaced or irrelevant to the DirectSQLPrompt class. It should be updated to reflect the purpose and usage of the DirectSQLPrompt class.

from .file_based_prompt import FileBasedPrompt


class DirectSQLPrompt(FileBasedPrompt):
    """Prompt to explain code generation by the LLM"""

    _path_to_template = "assets/prompt_templates/direct_sql_connector.tmpl"

    def _prepare_tables_data(self, tables):
        tables_join = []
        for table in tables:
            table_description_tag = (
                f' description="{table.table_description}"'
                if table.table_description is not None
                else ""
            )
            table_head_tag = f'<table name="{table.table_name}"{table_description_tag}>'
            table = f"{table_head_tag}\n{table.head_csv}\n</table>"
            tables_join.append(table)
        return "\n\n".join(tables_join)

    def setup(self, tables) -> None:
        self.set_var("tables", self._prepare_tables_data(tables))
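
To show what ends up in the `{tables}` slot of the template, here is a standalone sketch of the rendering logic. `DemoTable` is a hypothetical stand-in that exposes only the three attributes the method reads (`table_name`, `table_description`, `head_csv`), and the sample rows are invented.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DemoTable:
    # Hypothetical stand-in exposing the attributes read by the method.
    table_name: str
    head_csv: str
    table_description: Optional[str] = None


def prepare_tables_data(tables) -> str:
    """Standalone copy of the rendering logic from the diff above."""
    rendered = []
    for table in tables:
        description_tag = (
            f' description="{table.table_description}"'
            if table.table_description is not None
            else ""
        )
        head_tag = f'<table name="{table.table_name}"{description_tag}>'
        rendered.append(f"{head_tag}\n{table.head_csv}\n</table>")
    return "\n\n".join(rendered)


tables = [
    DemoTable("orders", "id,total\n1,9.5"),
    DemoTable("order_details", "order_id,product_id\n1,7", "Contain user order details"),
]
output = prepare_tables_data(tables)
print(output)
```

Each table becomes one `<table>` element holding the CSV head of the data, and the `description` attribute appears only when a description was provided.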