diff --git a/extensions/ee/connectors/bigquery/pandasai_bigquery/google_big_query.py b/extensions/ee/connectors/bigquery/pandasai_bigquery/google_big_query.py deleted file mode 100644 index 7c5a5d0b5..000000000 --- a/extensions/ee/connectors/bigquery/pandasai_bigquery/google_big_query.py +++ /dev/null @@ -1,113 +0,0 @@ -""" -Google Big Query connector is used to connect to dataset from -google big query api. -""" - -from typing import Union, Optional -from pydantic import Field - -from sqlalchemy import create_engine - -from pandasai.exceptions import InvalidConfigError - -from pandasai.connectors.base import BaseConnectorConfig -from pandasai_sql.sql import SQLBaseConnectorConfig, SQLConnector - - -class GoogleBigQueryConnectorConfig(SQLBaseConnectorConfig): - """ - Connector configuration for big query. - """ - - dialect: str = "bigquery" - database: str - table: str - projectID: str - credentials_path: Optional[str] = None - credentials_base64: Optional[str] = Field(default=None) - - class Config: - extra = "allow" - - -class GoogleBigQueryConnector(SQLConnector): - """ - GoogleBigQuery Connectors are used to connect to BigQuery Data Cloud. - """ - - def __init__(self, config: Union[GoogleBigQueryConnectorConfig, dict]): - """ - Initialize the GoogleBigQuery connector with the given configuration. - - Args: - config (ConnectorConfig): The config for the GoogleBigQuery connector. - """ - config["dialect"] = "bigquery" - if isinstance(config, dict): - env_vars = { - "database": "BIG_QUERY_DATABASE", - "credentials_path": "KEYFILE_PATH", - "projectID": "PROJECT_ID", - } - config = self._populate_config_from_env(config, env_vars) - - if "credentials_base64" not in config and "credentials_path" not in config: - raise InvalidConfigError( - "credentials_path or credentials_base64 is needed to connect" - ) - - super().__init__(config) - - def _load_connector_config(self, config: Union[BaseConnectorConfig, dict]): - return GoogleBigQueryConnectorConfig(**config) - - def _init_connection(self, config: GoogleBigQueryConnectorConfig): - """ - Initialize Database Connection - - Args: - config (GoogleBigQueryConnectorConfig): Configurations to load database - - """ - if config.credentials_path: - self._engine = create_engine( - f"{config.dialect}://{config.projectID}/{config.database}", - credentials_path=config.credentials_path, - ) - else: - self._engine = create_engine( - f"{config.dialect}://{config.projectID}/{config.database}?credentials_base64={config.credentials_base64}" - ) - - self._connection = self._engine.connect() - - def __repr__(self): - """ - Return the string representation of the Google big query connector. - - Returns: - str: The string representation of the Google big query connector. 
- """ - return ( - f"<{self.__class__.__name__} dialect={self.config.dialect} " - f"projectid= {self.config.projectID} database={self.config.database} >" - ) - - def equals(self, other): - if isinstance(other, self.__class__): - return ( - self.config.dialect, - self.config.driver, - self.config.credentials_path, - self.config.credentials_base64, - self.config.database, - self.config.projectID, - ) == ( - other.config.dialect, - other.config.driver, - other.config.credentials_path, - other.config.credentials_base64, - other.config.database, - other.config.projectID, - ) - return False diff --git a/extensions/ee/connectors/bigquery/tests/test_google_big_query.py b/extensions/ee/connectors/bigquery/tests/test_google_big_query.py deleted file mode 100644 index 0516ebe50..000000000 --- a/extensions/ee/connectors/bigquery/tests/test_google_big_query.py +++ /dev/null @@ -1,325 +0,0 @@ -import unittest -from unittest.mock import Mock, patch -import pandas as pd -import base64 -import json -from pandasai_sql.sql import PostgreSQLConnector -from pandasai_bigquery.google_big_query import ( - GoogleBigQueryConnector, - GoogleBigQueryConnectorConfig, -) -from pandasai_bigquery import load_from_bigquery - - -def get_mock_credentials_base64(): - mock_service_account = { - "type": "service_account", - "project_id": "test", - "private_key_id": "test", - "private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC9QFi6I8lJ1GiX\n1JUoaXPE0UZlWl1nCtJ4kGmZOJ7JrxFyB7O+Qw9G+2KFGe8FVqmvX4RX8Y0IPWSM\nrKqJ6B9H8PIxGZPKjH7t8qMUDGXkuXS7dWCHgKBqPnY+1V4yLwfwB3UC2HgoYtM8\n1YhZdV5ExHYc3nqHGxHqFu1YgIBVLHMA4hJ0tS6SGzEBR6mP8nqxw6O7M0zZ6Xpe\nMp7hCZ+YDqrHp/uBt1IgXlktYPKQJ/Z6WY4Jv9KQIE1B1h96mY0QyO7c0CJYzHPP\nM2mHmX9wXtfmE9NbRFRHzL9Xt4tXGQHdmzJ4fkRt3UUYv2uH1ZEIhtQwM+peH2Nt\nkXeK1vYnAgMBAAECggEABD19YhqDhOHjPEX8wqpqBmdGhcKB0jHrxXWW7YtB6iq5\n1h1xJuYHUVAhytbpLWpAI1ZG5r7Vz3Z+MXn+RFTEyPo8GxrYxXzjxHR9XyHEVzgd\nxjuWEZBFxR3Qnl3TFB4f5YHXmoN7K+HjrNrwz9+thJA9p9VMFS0ePxgWVLQNQm4U\nXvmyJVbXhkYwMDgkv6i7U9RhNGJbBKx/VR3Enz8hQhZRHGPqYJiHXHQSQH4/kh4M\n5A7BxqY7W1n1Ot1jX3kCQYu0Qs5Ky6iQKRos3VMQGqZAstA7QYPEwCbXb/3AXheY\nGZBb/1ZXU5Q0tb5Qz7njfqOqnN0DVCK7UxDGDjpSYQKBgQD4u2yx3ux3QqHf8Zl5\nLhbRyHto0LL4Rn+I3hXEUL4s/qJxzXAM5KNxv7H96hgCXZWXnzuKr4JSwHGj8qO5\nS8YUgxjwrVTv0jscY0Rk7xcKGRJGH/+1+9qh4qoN2/PC8m7jv5YEfqPR77T4h1yQ\nJGE7YkeHnxPjwHDhphTPmmGsYQKBgQDCwXxUVs7yL7d0cP6u0gULm/HBc5ERQvwB\nO7ywXxFw7EvGhxh6YEF3MDFVh0yjZOYj0qB3OKXuAEzIYYQGc+oVm/6TUJRjqX8b\nG7HTXUy8WQr0XKYE4xHVPYZT7KD5mKOgGkPTjX5GzG5oQHJfVx5k6R0YvXb/mpIj\nEUF2qQ+thwKBgBXQR6Q6VFgqGWUzEA6ZS3hEL5mc4QZDhH4qK4HAs3TCXH9qphY1\n/0W8GH5GWbfPGQ0qHsz1y+6QYGb6R2tV4F+lrGE0D3Y7iI6aCkS5qjz+Qj3YHrVD\nFGb6I+xv2k0PQtXxjudT1ydx1QmLxB9QQxJz8seBLPuSx0+Jk5jv2ZchAoGAam0/\n8jE4zQzjYMzQDEhZPUPFv6O6MOLQaQZ3aGRMeHGj6PUV1R2541ARqTw0Z8oXx1yd\nW1JxUwx3yL1ivI3LEHZYJh0pm4UJk+UZwwE5C8tdDw+7jO0PZ1N8K4Rt3CP9eZy8\n9DLghfwJlBjFVzv0QNJx+BtLRpmhQwV3K8PSTk8CgYEAjq5EX5LX7gKhPBZkjqHp\nE7Jc54+13HRJZvPmWkFj7ZiAAmx7+QnEHCpVzKDRk7U+yhZwzWqyHNHHvhPQQyfG\nZZZQVLEv6p2Bs2BXnQZvjRKB8K7Bpq5GCZVXPSn8YTVVHNBCnvqYR3oGrQhQ0eGH\nL/dA50N6oYMMPkwX0JyRUHo=\n-----END PRIVATE KEY-----\n", - "client_email": "test@test.iam.gserviceaccount.com", - "client_id": "test", - "auth_uri": "https://accounts.google.com/o/oauth2/auth", - "token_uri": "https://oauth2.googleapis.com/token", - "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", - "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test%40test.iam.gserviceaccount.com", - } - return 
base64.b64encode(json.dumps(mock_service_account).encode()).decode() - - -class TestGoogleBigQueryConnector(unittest.TestCase): - @patch( - "pandasai_bigquery.google_big_query.create_engine", - autospec=True, - ) - def setUp(self, mock_create_engine) -> None: - self.mock_engine = Mock() - self.mock_connection = Mock() - self.mock_engine.connect.return_value = self.mock_connection - mock_create_engine.return_value = self.mock_engine - - self.config = GoogleBigQueryConnectorConfig( - dialect="bigquery", - database="database", - table="yourtable", - credentials_base64=get_mock_credentials_base64(), # base64 encoded mock service account with PEM key - projectID="project_id", - ).dict() - - self.connector = GoogleBigQueryConnector(self.config) - - @patch( - "pandasai_bigquery.google_big_query.GoogleBigQueryConnector._load_connector_config" - ) - @patch( - "pandasai_bigquery.google_big_query.GoogleBigQueryConnector._init_connection" - ) - def test_constructor_and_properties( - self, mock_load_connector_config, mock_init_connection - ): - # Test constructor and properties - self.assertEqual(self.connector.config.model_dump(), self.config) - self.assertEqual(self.connector._engine, self.mock_engine) - self.assertEqual(self.connector._connection, self.mock_connection) - self.assertEqual(self.connector._cache_interval, 600) - GoogleBigQueryConnector(self.config) - mock_load_connector_config.assert_called() - mock_init_connection.assert_called() - - @patch( - "pandasai_bigquery.google_big_query.create_engine", - autospec=True, - ) - def test_constructor_and_properties_with_base64_string(self, mock_create_engine): - self.mock_engine = Mock() - self.mock_connection = Mock() - self.mock_engine.connect.return_value = self.mock_connection - mock_create_engine.return_value = self.mock_engine - - self.config = GoogleBigQueryConnectorConfig( - dialect="bigquery", - database="database", - table="yourtable", - credentials_base64=get_mock_credentials_base64(), # base64 encoded mock service account with PEM key - projectID="project_id", - ).dict() - - self.connector = GoogleBigQueryConnector(self.config) - mock_create_engine.assert_called_with( - "bigquery://project_id/database?credentials_base64=" - + get_mock_credentials_base64() - ) - - def test_repr_method(self): - # Test __repr__ method - expected_repr = ( - "<GoogleBigQueryConnector dialect=bigquery " - "projectid= project_id database=database >" - ) - self.assertEqual(repr(self.connector), expected_repr) - - @patch("pandasai_sql.sql.pd.read_sql", autospec=True) - def test_head_method(self, mock_read_sql): - expected_data = pd.DataFrame({"Column1": [1, 2, 3], "Column2": [4, 5, 6]}) - mock_read_sql.return_value = expected_data - head_data = self.connector.head() - pd.testing.assert_frame_equal(head_data, expected_data) - - def test_rows_count_property(self): - # Test rows_count property - self.connector._rows_count = None - self.mock_connection.execute.return_value.fetchone.return_value = ( - 50, - ) # Sample rows count - rows_count = self.connector.rows_count - self.assertEqual(rows_count, 50) - - def test_columns_count_property(self): - # Test columns_count property - self.connector._columns_count = None - mock_df = Mock() - mock_df.columns = ["Column1", "Column2"] - self.connector.head = Mock(return_value=mock_df) - columns_count = self.connector.columns_count - self.assertEqual(columns_count, 2) - - def test_column_hash_property(self): - # Test column_hash property - mock_df = Mock() - mock_df.columns = ["Column1", "Column2"] - self.connector.head = Mock(return_value=mock_df) - column_hash = self.connector.column_hash -
self.assertIsNotNone(column_hash) - self.assertEqual( - column_hash, - "0d045cff164deef81e24b0ed165b7c9c2789789f013902115316cde9d214fe63", - ) - - def test_fallback_name_property(self): - # Test fallback_name property - fallback_name = self.connector.fallback_name - self.assertEqual(fallback_name, "yourtable") - - @patch( - "extensions.ee.connectors.bigquery.pandasai_bigquery.google_big_query.create_engine", - autospec=True, - ) - def test_constructor_and_properties_equal_func(self, mock_create_engine): - self.mock_engine = Mock() - self.mock_connection = Mock() - self.mock_engine.connect.return_value = self.mock_connection - mock_create_engine.return_value = self.mock_engine - - self.config = GoogleBigQueryConnectorConfig( - dialect="bigquery", - database="database", - table="yourtable", - credentials_base64=get_mock_credentials_base64(), # base64 encoded mock service account with PEM key - projectID="project_id", - ).dict() - - self.connector = GoogleBigQueryConnector(self.config) - connector_2 = GoogleBigQueryConnector(self.config) - - assert self.connector.equals(connector_2) - - @patch( - "extensions.ee.connectors.bigquery.pandasai_bigquery.google_big_query.create_engine", - autospec=True, - ) - def test_constructor_and_properties_not_equal_func(self, mock_create_engine): - self.mock_engine = Mock() - self.mock_connection = Mock() - self.mock_engine.connect.return_value = self.mock_connection - mock_create_engine.return_value = self.mock_engine - - self.config = GoogleBigQueryConnectorConfig( - dialect="bigquery", - database="database", - table="yourtable", - credentials_base64=get_mock_credentials_base64(), # base64 encoded mock service account with PEM key - projectID="project_id", - ).dict() - - config2 = GoogleBigQueryConnectorConfig( - dialect="bigquery", - database="database2", - table="yourtable", - credentials_base64=get_mock_credentials_base64(), # base64 encoded mock service account with PEM key - projectID="project_id", - ).dict() - - self.connector = GoogleBigQueryConnector(self.config) - connector_2 = GoogleBigQueryConnector(config2) - - assert not self.connector.equals(connector_2) - - @patch( - "extensions.ee.connectors.bigquery.pandasai_bigquery.google_big_query.create_engine", - autospec=True, - ) - @patch("pandasai_sql.SQLConnector._init_connection") - def test_constructor_and_properties_different_type( - self, mock_connection, mock_create_engine - ): - self.mock_engine = Mock() - self.mock_connection = Mock() - self.mock_engine.connect.return_value = self.mock_connection - mock_create_engine.return_value = self.mock_engine - - self.config = GoogleBigQueryConnectorConfig( - dialect="bigquery", - database="database", - table="yourtable", - credentials_base64=get_mock_credentials_base64(), # base64 encoded mock service account with PEM key - projectID="project_id", - ).dict() - - config = { - "username": "your_username_differ", - "password": "your_password", - "host": "your_host", - "port": 443, - "database": "your_database", - "table": "your_table", - "where": [["column_name", "=", "value"]], - } - - # Create an instance of SQLConnector - connector_2 = PostgreSQLConnector(config) - - self.connector = GoogleBigQueryConnector(self.config) - - assert not self.connector.equals(connector_2) - - -class TestLoadFromBigQuery(unittest.TestCase): - @patch("extensions.ee.connectors.bigquery.pandasai_bigquery.bigquery.Client") - def test_load_from_bigquery(self, mock_client): - # Mock the connection info - connection_info = { - "project_id": "test_project", - "credentials": 
"test_credentials", - } - - # Mock the query - query = "SELECT * FROM test_table" - - # Mock the query job and its result - mock_query_job = Mock() - mock_result = [{"col1": 1, "col2": "a"}, {"col1": 2, "col2": "b"}] - mock_query_job.result.return_value = mock_result - - # Set up the mock client to return our mock query job - mock_client_instance = mock_client.return_value - mock_client_instance.query.return_value = mock_query_job - - # Call the function - result = load_from_bigquery(connection_info, query) - - # Assert that the client was created with the correct arguments - mock_client.assert_called_once_with( - project=connection_info["project_id"], - credentials=connection_info["credentials"], - ) - - # Assert that the query was executed - mock_client_instance.query.assert_called_once_with(query) - - # Assert that the result is a pandas DataFrame with the expected data - expected_df = pd.DataFrame(mock_result) - pd.testing.assert_frame_equal(result, expected_df) - - @patch("extensions.ee.connectors.bigquery.pandasai_bigquery.bigquery.Client") - def test_load_from_bigquery_without_credentials(self, mock_client): - # Mock the connection info without credentials - connection_info = {"project_id": "test_project"} - - query = "SELECT * FROM test_table" - - mock_query_job = Mock() - mock_result = [{"col1": 1, "col2": "a"}, {"col1": 2, "col2": "b"}] - mock_query_job.result.return_value = mock_result - - mock_client_instance = mock_client.return_value - mock_client_instance.query.return_value = mock_query_job - - result = load_from_bigquery(connection_info, query) - - # Assert that the client was created with the correct arguments - mock_client.assert_called_once_with( - project=connection_info["project_id"], credentials=None - ) - - mock_client_instance.query.assert_called_once_with(query) - - expected_df = pd.DataFrame(mock_result) - pd.testing.assert_frame_equal(result, expected_df) - - @patch("extensions.ee.connectors.bigquery.pandasai_bigquery.bigquery.Client") - def test_load_from_bigquery_empty_result(self, mock_client): - connection_info = { - "project_id": "test_project", - "credentials": "test_credentials", - } - - query = "SELECT * FROM empty_table" - - mock_query_job = Mock() - mock_result = [] # Empty result - mock_query_job.result.return_value = mock_result - - mock_client_instance = mock_client.return_value - mock_client_instance.query.return_value = mock_query_job - - result = load_from_bigquery(connection_info, query) - - mock_client.assert_called_once_with( - project=connection_info["project_id"], - credentials=connection_info["credentials"], - ) - - mock_client_instance.query.assert_called_once_with(query) - - expected_df = pd.DataFrame() # Empty DataFrame - pd.testing.assert_frame_equal(result, expected_df) diff --git a/extensions/llms/bedrock/pandasai_bedrock/base.py b/extensions/llms/bedrock/pandasai_bedrock/base.py index 28c1485b7..606248747 100644 --- a/extensions/llms/bedrock/pandasai_bedrock/base.py +++ b/extensions/llms/bedrock/pandasai_bedrock/base.py @@ -3,16 +3,17 @@ from abc import abstractmethod from typing import TYPE_CHECKING, Optional +from pandasai.chat.prompts.base import BasePrompt from pandasai.helpers.memory import Memory from pandasai.exceptions import ( MethodNotImplementedError, ) -from pandasai.prompts.base import BasePrompt from pandasai.llm.base import LLM + if TYPE_CHECKING: - from pandasai.pipelines.pipeline_context import PipelineContext + from pandasai.agent.state import AgentState class BaseGoogle(LLM): @@ -77,13 +78,13 @@ def 
_generate_text(self, prompt: str, memory: Optional[Memory] = None) -> str: """ raise MethodNotImplementedError("method has not been implemented") - def call(self, instruction: BasePrompt, context: PipelineContext = None) -> str: + def call(self, instruction: BasePrompt, context: AgentState = None) -> str: """ Call the Google LLM. Args: instruction (BasePrompt): Instruction to pass. - context (PipelineContext): Pass PipelineContext. + context (AgentState): Pass AgentState. Returns: str: LLM response. diff --git a/extensions/llms/bedrock/pandasai_bedrock/claude.py b/extensions/llms/bedrock/pandasai_bedrock/claude.py index ce8fe0ca8..cc15ee8d4 100644 --- a/extensions/llms/bedrock/pandasai_bedrock/claude.py +++ b/extensions/llms/bedrock/pandasai_bedrock/claude.py @@ -3,13 +3,13 @@ import json from typing import TYPE_CHECKING, Any, Dict, Optional +from pandasai.chat.prompts.base import BasePrompt from pandasai.exceptions import APIKeyNotFoundError, UnsupportedModelError from pandasai.helpers import load_dotenv -from pandasai.prompts.base import BasePrompt from pandasai.llm.base import LLM if TYPE_CHECKING: - from pandasai.pipelines.pipeline_context import PipelineContext + from pandasai.agent.state import AgentState load_dotenv() @@ -78,7 +78,7 @@ def _default_params(self) -> Dict[str, Any]: "stop_sequences": self.stop_sequences, } - def call(self, instruction: BasePrompt, context: PipelineContext = None) -> str: + def call(self, instruction: BasePrompt, context: AgentState = None) -> str: prompt = instruction.to_string() memory = context.memory if context else None diff --git a/extensions/llms/bedrock/tests/test_bedrock_claude.py b/extensions/llms/bedrock/tests/test_bedrock_claude.py index 38627b2d1..614e5153e 100644 --- a/extensions/llms/bedrock/tests/test_bedrock_claude.py +++ b/extensions/llms/bedrock/tests/test_bedrock_claude.py @@ -1,13 +1,14 @@ """Unit tests for the openai LLM class""" + import io import json from unittest.mock import MagicMock import pytest +from pandasai.chat.prompts.base import BasePrompt from pandasai.exceptions import APIKeyNotFoundError, UnsupportedModelError from extensions.llms.bedrock.pandasai_bedrock.claude import BedrockClaude -from pandasai.prompts import BasePrompt class MockBedrockRuntimeClient: diff --git a/extensions/llms/google/pandasai_google/base.py b/extensions/llms/google/pandasai_google/base.py index 91e57264f..606248747 100644 --- a/extensions/llms/google/pandasai_google/base.py +++ b/extensions/llms/google/pandasai_google/base.py @@ -3,17 +3,17 @@ from abc import abstractmethod from typing import TYPE_CHECKING, Optional +from pandasai.chat.prompts.base import BasePrompt from pandasai.helpers.memory import Memory from pandasai.exceptions import ( MethodNotImplementedError, ) -from pandasai.prompts.base import BasePrompt from pandasai.llm.base import LLM -if TYPE_CHECKING: - from pandasai.pipelines.pipeline_context import PipelineContext +if TYPE_CHECKING: + from pandasai.agent.state import AgentState class BaseGoogle(LLM): @@ -78,13 +78,13 @@ def _generate_text(self, prompt: str, memory: Optional[Memory] = None) -> str: """ raise MethodNotImplementedError("method has not been implemented") - def call(self, instruction: BasePrompt, context: PipelineContext = None) -> str: + def call(self, instruction: BasePrompt, context: AgentState = None) -> str: """ Call the Google LLM. Args: instruction (BasePrompt): Instruction to pass. - context (PipelineContext): Pass PipelineContext. + context (AgentState): Pass AgentState. Returns: str: LLM response. 
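Every LLM extension below repeats the same two-part migration: BasePrompt now comes from pandasai.chat.prompts.base, and call() accepts an AgentState (pandasai.agent.state) in place of the old PipelineContext. A minimal sketch of the resulting contract for a custom extension; EchoLLM is hypothetical and only illustrates the signature, it is not part of this diff:

    from typing import TYPE_CHECKING

    from pandasai.chat.prompts.base import BasePrompt
    from pandasai.llm.base import LLM

    if TYPE_CHECKING:
        from pandasai.agent.state import AgentState


    class EchoLLM(LLM):
        """Toy LLM that returns the rendered prompt; for illustration only."""

        @property
        def type(self) -> str:
            return "echo-llm"

        def call(self, instruction: BasePrompt, context: "AgentState" = None) -> str:
            # Memory now hangs off AgentState rather than PipelineContext.
            prompt = instruction.to_string()
            memory = context.memory if context else None
            history = memory.get_conversation() if memory else ""
            return history + prompt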
diff --git a/extensions/llms/huggingface/pandasai_huggingface/huggingface_text_gen.py b/extensions/llms/huggingface/pandasai_huggingface/huggingface_text_gen.py index 1ba20bd5c..a6dd3fbaf 100644 --- a/extensions/llms/huggingface/pandasai_huggingface/huggingface_text_gen.py +++ b/extensions/llms/huggingface/pandasai_huggingface/huggingface_text_gen.py @@ -2,12 +2,14 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from pandasai.chat.prompts.base import BasePrompt from pandasai.helpers import load_dotenv -from pandasai.prompts.base import BasePrompt + from pandasai.llm.base import LLM if TYPE_CHECKING: - from pandasai.pipelines.pipeline_context import PipelineContext + from pandasai.agent.state import AgentState load_dotenv() @@ -81,7 +83,7 @@ def _default_params(self) -> Dict[str, Any]: "seed": self.seed, } - def call(self, instruction: BasePrompt, context: PipelineContext = None) -> str: + def call(self, instruction: BasePrompt, context: AgentState = None) -> str: prompt = instruction.to_string() memory = context.memory if context else None diff --git a/extensions/llms/ibm/pandasai_ibm/ibm_watsonx.py b/extensions/llms/ibm/pandasai_ibm/ibm_watsonx.py index d8b0e2a96..be6858dc2 100644 --- a/extensions/llms/ibm/pandasai_ibm/ibm_watsonx.py +++ b/extensions/llms/ibm/pandasai_ibm/ibm_watsonx.py @@ -3,14 +3,15 @@ import os from typing import TYPE_CHECKING, Optional +from pandasai.chat.code_execution.environment import import_dependency +from pandasai.chat.prompts.base import BasePrompt from pandasai.exceptions import APIKeyNotFoundError from pandasai.helpers import load_dotenv -from pandasai.helpers.optional import import_dependency -from pandasai.prompts.base import BasePrompt + from pandasai.llm.base import LLM if TYPE_CHECKING: - from pandasai.pipelines.pipeline_context import PipelineContext + from pandasai.agent.state import AgentState load_dotenv() @@ -131,7 +132,7 @@ def _set_params(self, **kwargs): f"Parameter {key} is invalid. Accepted parameters: {[*valid_params]}" ) - def call(self, instruction: BasePrompt, context: PipelineContext = None) -> str: + def call(self, instruction: BasePrompt, context: AgentState = None) -> str: prompt = instruction.to_string() memory = context.memory if context else None diff --git a/extensions/llms/langchain/pandasai_langchain/langchain.py b/extensions/llms/langchain/pandasai_langchain/langchain.py index 4b1049227..067c77131 100644 --- a/extensions/llms/langchain/pandasai_langchain/langchain.py +++ b/extensions/llms/langchain/pandasai_langchain/langchain.py @@ -1,5 +1,7 @@ from __future__ import annotations +from pandasai.chat.prompts.base import BasePrompt + try: from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models.chat_models import BaseChatModel @@ -12,12 +14,11 @@ from typing import TYPE_CHECKING -from pandasai.prompts.base import BasePrompt from pandasai.llm.base import LLM if TYPE_CHECKING: - from pandasai.pipelines.pipeline_context import PipelineContext + from pandasai.agent.state import AgentState """Langchain LLM @@ -39,13 +40,13 @@ class LangchainLLM(LLM): with LangChain. 
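    Example (illustrative, assuming the langchain-openai package is installed):
        >>> from langchain_openai import ChatOpenAI
        >>> llm = LangchainLLM(ChatOpenAI())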
""" - langchain_llm: BaseLanguageModel + langchain_llm: BaseLanguageModel # type: ignore - def __init__(self, langchain_llm: BaseLanguageModel): + def __init__(self, langchain_llm: BaseLanguageModel): # type: ignore self.langchain_llm = langchain_llm def call( - self, instruction: BasePrompt, context: PipelineContext = None, suffix: str = "" + self, instruction: BasePrompt, context: AgentState = None, suffix: str = "" ) -> str: prompt = instruction.to_string() + suffix memory = context.memory if context else None diff --git a/extensions/llms/local/pandasai_local/local_llm.py b/extensions/llms/local/pandasai_local/local_llm.py index 7bac75d12..d8761159e 100644 --- a/extensions/llms/local/pandasai_local/local_llm.py +++ b/extensions/llms/local/pandasai_local/local_llm.py @@ -4,12 +4,13 @@ from openai import OpenAI +from pandasai.chat.prompts.base import BasePrompt from pandasai.helpers.memory import Memory -from pandasai.prompts.base import BasePrompt + from pandasai.llm.base import LLM if TYPE_CHECKING: - from pandasai.pipelines.pipeline_context import PipelineContext + from pandasai.agent.state import AgentState class LocalLLM(LLM): @@ -37,7 +38,7 @@ def chat_completion(self, value: str, memory: Memory) -> str: return response.choices[0].message.content - def call(self, instruction: BasePrompt, context: PipelineContext = None) -> str: + def call(self, instruction: BasePrompt, context: AgentState = None) -> str: self.last_prompt = instruction.to_string() memory = context.memory if context else None diff --git a/extensions/llms/openai/pandasai_openai/base.py b/extensions/llms/openai/pandasai_openai/base.py index 5d08312e2..75afc74a9 100644 --- a/extensions/llms/openai/pandasai_openai/base.py +++ b/extensions/llms/openai/pandasai_openai/base.py @@ -2,12 +2,13 @@ from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Tuple, Union +from pandasai.chat.prompts.base import BasePrompt from pandasai.helpers.memory import Memory -from pandasai.prompts.base import BasePrompt + from pandasai.llm.base import LLM if TYPE_CHECKING: - from pandasai.pipelines.pipeline_context import PipelineContext + from pandasai.agent.state import AgentState class BaseOpenAI(LLM): @@ -166,13 +167,13 @@ def chat_completion(self, value: str, memory: Memory) -> str: return response.choices[0].message.content - def call(self, instruction: BasePrompt, context: PipelineContext = None): + def call(self, instruction: BasePrompt, context: AgentState = None): """ Call the OpenAI LLM. Args: instruction (BasePrompt): A prompt object with instruction for LLM. - context (PipelineContext): context to pass. + context (AgentState): context to pass. 
Raises: UnsupportedModelError: Unsupported model diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py index 50884f2ed..dcfc8b8d9 100644 --- a/pandasai/agent/__init__.py +++ b/pandasai/agent/__init__.py @@ -1,4 +1,3 @@ -from .agent import Agent -from .base import BaseAgent +from .base import Agent -__all__ = ["Agent", "BaseAgent"] +__all__ = ["Agent"] diff --git a/pandasai/agent/agent.py b/pandasai/agent/agent.py deleted file mode 100644 index 8be3bb428..000000000 --- a/pandasai/agent/agent.py +++ /dev/null @@ -1,56 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional, Type, Union - - -from pandasai.agent.base import BaseAgent -from pandasai.agent.base_judge import BaseJudge -from pandasai.agent.base_security import BaseSecurity -from pandasai.pipelines.chat.generate_chat_pipeline import GenerateChatPipeline -from pandasai.schemas.df_config import Config -from pandasai.vectorstores.vectorstore import VectorStore - -if TYPE_CHECKING: - from pandasai.dataframe import DataFrame - - -class Agent(BaseAgent): - def __init__( - self, - dfs: Union[DataFrame, List[DataFrame]], - config: Optional[Union[Config, dict]] = None, - memory_size: Optional[int] = 10, - pipeline: Optional[Type[GenerateChatPipeline]] = None, - vectorstore: Optional[VectorStore] = None, - description: str = None, - judge: BaseJudge = None, - security: BaseSecurity = None, - ): - super().__init__( - dfs, config, memory_size, vectorstore, description, security=security - ) - - self.pipeline = ( - pipeline( - self.context, - self.logger, - on_prompt_generation=self._callbacks.on_prompt_generation, - on_code_generation=self._callbacks.on_code_generation, - before_code_execution=self._callbacks.before_code_execution, - on_result=self._callbacks.on_result, - judge=judge, - ) - if pipeline - else GenerateChatPipeline( - self.context, - self.logger, - on_prompt_generation=self._callbacks.on_prompt_generation, - on_code_generation=self._callbacks.on_code_generation, - before_code_execution=self._callbacks.before_code_execution, - on_result=self._callbacks.on_result, - judge=judge, - ) - ) - - @property - def last_error(self): - return self.pipeline.last_error diff --git a/pandasai/agent/base.py b/pandasai/agent/base.py index 7094fc371..d3f86d0f4 100644 --- a/pandasai/agent/base.py +++ b/pandasai/agent/base.py @@ -1,17 +1,27 @@ import os -import re +import traceback import uuid -from typing import List, Optional, Union - -import pandas as pd -from pandasai.agent.base_security import BaseSecurity +from typing import Any, List, Optional, Tuple, Union + +from pandasai.chat.cache import Cache +from pandasai.chat.code_execution.code_executor import CodeExecutor +from pandasai.chat.code_generation.base import CodeGenerator +from pandasai.chat.prompts import ( + get_chat_prompt, + get_chat_prompt_for_sql, + get_correct_error_prompt, + get_correct_error_prompt_for_sql, + get_correct_output_type_error_prompt, +) +from pandasai.chat.response.base import ResponseParser +from pandasai.chat.user_query import UserQuery +from pandasai.dataframe.base import DataFrame +from pandasai.dataframe.virtual_dataframe import VirtualDataFrame +from .state import AgentState +from pandasai.chat.prompts.base import BasePrompt from pandasai.data_loader.schema_validator import is_schema_source_same from pandasai.llm.bamboo_llm import BambooLLM -from pandasai.pipelines.chat.chat_pipeline_input import ChatPipelineInput -from pandasai.pipelines.chat.code_execution_pipeline_input import ( - 
CodeExecutionPipelineInput, -) from pandasai.vectorstores.vectorstore import VectorStore from ..config import load_config_from_json @@ -19,7 +29,6 @@ from ..exceptions import ( InvalidLLMOutputType, InvalidConfigError, - MaliciousQueryError, MissingVectorStoreError, ) from ..helpers.folder import Folder @@ -27,25 +36,23 @@ from ..helpers.memory import Memory from ..llm.base import LLM from importlib.util import find_spec -from ..pipelines.pipeline_context import PipelineContext -from ..prompts.base import BasePrompt -from ..schemas.df_config import Config -from .callbacks import Callbacks +from ..config import Config -class BaseAgent: +class Agent: """ Base Agent class to improve the conversational experience in PandasAI """ def __init__( self, - dfs: Union[pd.DataFrame, List[pd.DataFrame]], + dfs: Union[ + Union[DataFrame, VirtualDataFrame], List[Union[DataFrame, VirtualDataFrame]] + ], config: Optional[Union[Config, dict]] = None, memory_size: Optional[int] = 10, vectorstore: Optional[VectorStore] = None, description: str = None, - security: BaseSecurity = None, ): """ Args: @@ -53,40 +60,35 @@ def __init__( memory_size (int, optional): Conversation history to use during chat. Defaults to 1. """ - self.last_prompt = None - self.last_prompt_id = None - self.last_result = None - self.last_code_generated = None - self.last_code_executed = None + + self._state = AgentState() + self.agent_info = description self.conversation_id = uuid.uuid4() - self.dfs = dfs if isinstance(dfs, list) else [dfs] + # Instantiate dfs + self._state.dfs = dfs if isinstance(dfs, list) else [dfs] + + # Instantiate the config + self._state.config = self._get_config(config) - # Instantiate the context - self.config = self.get_config(config) + # Set llm in state + self._state.llm = self._get_llm(self._state.config.llm) # Validate df input with configurations - self.validate_input() + self._validate_input() # Initialize the context - self.context = PipelineContext( - dfs=self.dfs, - config=self.config, - memory=Memory(memory_size, agent_info=description), - vectorstore=vectorstore, - ) + self._state.memory = Memory(memory_size, agent_info=description) # Instantiate the logger - self.logger = Logger( - save_logs=self.config.save_logs, verbose=self.config.verbose + self._state.logger = Logger( + save_logs=self._state.config.save_logs, verbose=self._state.config.verbose ) - # Instantiate the vectorstore - self._vectorstore = vectorstore - - if self._vectorstore is None and os.environ.get("PANDASAI_API_KEY"): + # Initiate VectorStore + if vectorstore is None and os.environ.get("PANDASAI_API_KEY"): try: from pandasai.vectorstores.bamboo_vectorstore import BambooVectorStore except ImportError as e: @@ -94,102 +96,33 @@ def __init__( "Could not import BambooVectorStore. Please install the required dependencies." 
) from e - - self._vectorstore = BambooVectorStore(logger=self.logger) - self.context.vectorstore = self._vectorstore + self._state.vectorstore = BambooVectorStore(logger=self._state.logger) - self._callbacks = Callbacks(self) + # Initialize Cache + self._state.cache = Cache() if self._state.config.enable_cache else None - self.configure() + # Set up directory paths for cache and charts + self._configure() - self.pipeline = None - self.security = security + # Initialize Code Generator + self._code_generator = CodeGenerator(self._state) - def validate_input(self): - from pandasai.dataframe.virtual_dataframe import VirtualDataFrame + # Initialize Response Generator + self._response_parser = ResponseParser() - # Check if all DataFrames are VirtualDataFrame, and set direct_sql accordingly - all_virtual = all(isinstance(df, VirtualDataFrame) for df in self.dfs) - if all_virtual: - self.config.direct_sql = True - - # Validate the configurations based on direct_sql flag all have same source - if self.config.direct_sql and all_virtual: - base_schema_source = self.dfs[0].schema - for df in self.dfs[1:]: - # Ensure all DataFrames have the same source in direct_sql mode - - if not is_schema_source_same(base_schema_source, df.schema): - raise InvalidConfigError( - "Direct SQL requires all connectors to be of the same type, " - "belong to the same datasource, and have the same credentials." - ) - else: - # If not using direct_sql, ensure all DataFrames have the same source - if any(isinstance(df, VirtualDataFrame) for df in self.dfs): - base_schema_source = self.dfs[0].schema - for df in self.dfs[1:]: - if not is_schema_source_same(base_schema_source, df.schema): - raise InvalidConfigError( - "All DataFrames must belong to the same source." - ) - self.config.direct_sql = True - else: - # Means all are none virtual - self.config.direct_sql = False - - def configure(self): - # Add project root path if save_charts_path is default - if ( - self.config.save_charts - and self.config.save_charts_path == DEFAULT_CHART_DIRECTORY - ): - Folder.create(self.config.save_charts_path) - - # Add project root path if cache_path is default - if self.config.enable_cache: - Folder.create(DEFAULT_CACHE_DIRECTORY) - - def get_config(self, config: Union[Config, dict]): + def chat(self, query: str, output_type: Optional[str] = None): """ - Load a config to be used to run the queries. - - Args: - config (Union[Config, dict]): Config to be used + Start a new chat interaction with the assistant on Dataframe. """ + self.start_new_conversation() - config = load_config_from_json(config) - - if isinstance(config, dict) and config.get("llm") is not None: - config["llm"] = self.get_llm(config["llm"]) - - config = Config(**config) - - if config.llm is None: - config.llm = BambooLLM() - - return config + return self._process_query(query, output_type) - def get_llm(self, llm: LLM) -> LLM: + def follow_up(self, query: str, output_type: Optional[str] = None): """ - Load a LLM to be used to run the queries. - Check if it is a PandasAI LLM or a Langchain LLM. - If it is a Langchain LLM, wrap it in a PandasAI LLM. - - Args: - llm (object): LLMs option to be used for API access - - Raises: - BadImportError: If the LLM is a Langchain LLM but the langchain package - is not installed + Continue the existing chat interaction with the assistant on Dataframe.
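+        Unlike chat(), this keeps the current conversation memory instead of starting a new one. Example (illustrative): agent.chat("Plot revenue by month") followed by agent.follow_up("Now only for 2023").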
""" - # Check if pandasai_langchain is installed - if find_spec("pandasai_langchain") is not None: - from pandasai_langchain.langchain import LangchainLLM, is_langchain_llm - - if is_langchain_llm(llm): - llm = LangchainLLM(llm) - - return llm + return self._process_query(query, output_type) def call_llm_with_prompt(self, prompt: BasePrompt): """ @@ -198,141 +131,82 @@ def call_llm_with_prompt(self, prompt: BasePrompt): prompt (BasePrompt): BasePrompt to pass to LLM's """ retry_count = 0 - while retry_count < self.context.config.max_retries: + while retry_count < self._state.config.max_retries: try: - result: str = self.context.config.llm.call(prompt) + result: str = self._state.config.llm.call(prompt) if prompt.validate(result): return result else: raise InvalidLLMOutputType("Response validation failed!") except Exception: if ( - not self.context.config.use_error_correction_framework - or retry_count >= self.context.config.max_retries - 1 + not self._state.config.use_error_correction_framework + or retry_count >= self._state.config.max_retries - 1 ): raise retry_count += 1 - def check_malicious_keywords_in_query(self, query): - dangerous_pattern = re.compile( - r"\b(os|io|chr|b64decode)\b|" - r"(\.os|\.io|'os'|'io'|\"os\"|\"io\"|chr\(|chr\)|chr |\(chr)" - ) - return bool(dangerous_pattern.search(query)) + def generate_code( + self, query: Union[UserQuery, str] + ) -> Tuple[str, Optional[List[str]]]: + """Generate code using the LLM.""" - def chat(self, query: str, output_type: Optional[str] = None): - """ - Start a new chat interaction with the assistant on Dataframe. - """ - self.start_new_conversation() - return self._process_query(query, output_type) - - def follow_up(self, query: str, output_type: Optional[str] = None): - """ - Continue the existing chat interaction with the assistant on Dataframe. - """ - return self._process_query(query, output_type) - - def _process_query(self, query: str, output_type: Optional[str] = None): - """ - Process a query and return the result. - """ - if not self.pipeline: - return ( - "Unfortunately, I was not able to get your answers, " - "because of the following error: No pipeline exists" - ) - - try: - self.logger.log(f"Question: {query}") - self.logger.log( - f"Running PandasAI with {self.context.config.llm.type} LLM..." - ) - - self.assign_prompt_id() - - if self.check_malicious_keywords_in_query(query): - raise MaliciousQueryError( - "The query contains references to io or os modules or b64decode method which can be used to execute or access system resources in unsafe ways." - ) - - if self.security and self.security.evaluate(query): - raise MaliciousQueryError("Query can result in a malicious code") - - pipeline_input = ChatPipelineInput( - query, output_type, self.conversation_id, self.last_prompt_id - ) - - return self.pipeline.run(pipeline_input) - - except Exception as exception: - return ( - "Unfortunately, I was not able to get your answers, " - "because of the following error:\n" - f"\n{exception}\n" - ) - - def generate_code(self, query: str, output_type: Optional[str] = None): - """ - Simulate code generation with the assistant on Dataframe. - """ - if not self.pipeline: - return ( - "Unfortunately, I was not able to get your answers, " - "because of the following error: No pipeline exists" - ) - try: - self.logger.log(f"Question: {query}") - self.logger.log( - f"Running PandasAI with {self.context.config.llm.type} LLM..." 
- ) - - self.assign_prompt_id() - - pipeline_input = ChatPipelineInput( - query, output_type, self.conversation_id, self.last_prompt_id - ) - - return self.pipeline.run_generate_code(pipeline_input) - except Exception as exception: - return ( - "Unfortunately, I was not able to get your answers, " - "because of the following error:\n" - f"\n{exception}\n" + self._state.memory.add(str(query), is_user=True) + if self._state.config.enable_cache: + cached_code = self._state.cache.get( + self._state.cache.get_cache_key(self._state) ) + if cached_code: + self._state.logger.log("Using cached code.") + return self._code_generator.validate_and_clean_code(cached_code) + + self._state.logger.log("Generating new code...") + prompt = ( + get_chat_prompt_for_sql(self._state) + if self._state.config.direct_sql + else get_chat_prompt(self._state) + ) + code, additional_dependencies = self._code_generator.generate_code(prompt) + self._state.last_prompt_used = prompt + return code, additional_dependencies def execute_code( - self, code: Optional[str] = None, output_type: Optional[str] = None - ): - """ - Execute code Generated with the assistant on Dataframe. - """ - if not self.pipeline: - return ( - "Unfortunately, I was not able to get your answers, " - "because of the following error: No pipeline exists to execute try Agent class" - ) - try: - if code is None: - code = self.last_code_generated - self.logger.log(f"Code: {code}") - self.logger.log( - f"Running PandasAI with {self.context.config.llm.type} LLM..." + self, code: str, additional_dependencies: Optional[List[str]] + ) -> dict: + """Execute the generated code.""" + self._state.logger.log(f"Executing code: {code}") + code_executor = CodeExecutor(additional_dependencies) + code_executor.add_to_env("dfs", self._state.dfs) + + if self._state.config.direct_sql: + code_executor.add_to_env( + "execute_sql_query", self._state.dfs[0].execute_sql_query ) - self.assign_prompt_id() + return code_executor.execute_and_return_result(code) - pipeline_input = CodeExecutionPipelineInput( - code, output_type, self.conversation_id, self.last_prompt_id - ) + def execute_with_retries( + self, code: str, additional_dependencies: Optional[List[str]] + ) -> Any: + """Execute the code with retry logic.""" + max_retries = self._state.config.max_retries + retries = 0 - return self.pipeline.run_execute_code(pipeline_input) - except Exception as exception: - return ( - "Unfortunately, I was not able to get your answers, " - "because of the following error:\n" - f"\n{exception}\n" - ) + while retries <= max_retries: + try: + result = self.execute_code(code, additional_dependencies) + return self._response_parser.parse(result) + except Exception as e: + retries += 1 + if retries > max_retries: + self._state.logger.log(f"Max retries reached. Error: {e}") + raise + self._state.logger.log( + f"Retrying execution ({retries}/{max_retries})..." + ) + code, additional_dependencies = self._regenerate_code_after_error( + code, e + ) def train( self, @@ -349,7 +223,7 @@ def train( Raises: ImportError: if default vector db lib is not installed it raises an error """ - if self._vectorstore is None: + if self._state.vectorstore is None: raise MissingVectorStoreError( "No vector store provided. Please provide a vector store to train the agent." 
) @@ -360,18 +234,18 @@ def train( ) if docs is not None: - self._vectorstore.add_docs(docs) + self._state.vectorstore.add_docs(docs) if queries and codes: - self._vectorstore.add_question_answer(queries, codes) + self._state.vectorstore.add_question_answer(queries, codes) - self.logger.log("Agent successfully trained on the data") + self._state.logger.log("Agent successfully trained on the data") def clear_memory(self): """ Clears the memory """ - self.context.memory.clear() + self._state.memory.clear() self.conversation_id = uuid.uuid4() def add_message(self, message, is_user=False): @@ -380,15 +254,7 @@ def add_message(self, message, is_user=False): to the memory without calling the chat function (for example, when you need to add a message from the agent). """ - self.context.memory.add(message, is_user=is_user) - - def assign_prompt_id(self): - """Assign a prompt ID""" - - self.last_prompt_id = uuid.uuid4() - - if self.logger: - self.logger.log(f"Prompt ID: {self.last_prompt_id}") + self._state.memory.add(message, is_user=is_user) def start_new_conversation(self): """ @@ -396,10 +262,161 @@ def start_new_conversation(self): """ self.clear_memory() + def _validate_input(self): + from pandasai.dataframe.virtual_dataframe import VirtualDataFrame + + # Check if all DataFrames are VirtualDataFrame, and set direct_sql accordingly + all_virtual = all(isinstance(df, VirtualDataFrame) for df in self._state.dfs) + if all_virtual: + self._state.config.direct_sql = True + + # Validate the configurations based on direct_sql flag all have same source + if self._state.config.direct_sql and all_virtual: + base_schema_source = self._state.dfs[0].schema + for df in self._state.dfs[1:]: + # Ensure all DataFrames have the same source in direct_sql mode + + if not is_schema_source_same(base_schema_source, df.schema): + raise InvalidConfigError( + "Direct SQL requires all connectors to be of the same type, " + "belong to the same datasource, and have the same credentials." + ) + else: + # If not using direct_sql, ensure all DataFrames have the same source + if any(isinstance(df, VirtualDataFrame) for df in self._state.dfs): + base_schema_source = self._state.dfs[0].schema + for df in self._state.dfs[1:]: + if not is_schema_source_same(base_schema_source, df.schema): + raise InvalidConfigError( + "All DataFrames must belong to the same source." + ) + self._state.config.direct_sql = True + else: + # Means all are none virtual + self._state.config.direct_sql = False + + def _process_query(self, query: str, output_type: Optional[str] = None): + """Process a user query and return the result.""" + query = UserQuery(query) + self._state.logger.log(f"Question: {query}") + self._state.logger.log( + f"Running PandasAI with {self._state.config.llm.type} LLM..." 
+ ) + + self._state.output_type = output_type + try: + self._assign_prompt_id() + + # Generate code + code, additional_dependencies = self.generate_code(query) + + # Execute code with retries + result = self.execute_with_retries(code, additional_dependencies) + + # Cache the result if caching is enabled + if self._state.config.enable_cache: + self._state.cache.set( + self._state.cache.get_cache_key(self._state), code + ) + + self._state.logger.log("Response Generated Successfully.") + # Generate and return the final response + return result + + except Exception as e: + return self._handle_exception(e) + + def _regenerate_code_after_error(self, code: str, error: Exception) -> str: + """Generate a new code snippet based on the error.""" + error_trace = traceback.format_exc() + self._state.logger.log(f"Execution failed with error: {error_trace}") + + if isinstance(error, InvalidLLMOutputType): + prompt = get_correct_output_type_error_prompt( + self._state, code, error_trace + ) + elif self._state.config.direct_sql: + prompt = get_correct_error_prompt_for_sql(self._state, code, error_trace) + else: + prompt = get_correct_error_prompt(self._state, code, error_trace) + + return self._code_generator.generate_code(prompt) + + def _configure(self): + # Add project root path if save_charts_path is default + if ( + self._state.config.save_charts + and self._state.config.save_charts_path == DEFAULT_CHART_DIRECTORY + ): + Folder.create(self._state.config.save_charts_path) + + # Add project root path if cache_path is default + if self._state.config.enable_cache: + Folder.create(DEFAULT_CACHE_DIRECTORY) + + def _get_config(self, config: Union[Config, dict]): + """ + Load a config to be used to run the queries. + + Args: + config (Union[Config, dict]): Config to be used + """ + + config = load_config_from_json(config) + return Config(**config) + + def _get_llm(self, llm: Optional[LLM] = None) -> LLM: + """ + Load a LLM to be used to run the queries. + Check if it is a PandasAI LLM or a Langchain LLM. + If it is a Langchain LLM, wrap it in a PandasAI LLM. 
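+        If no LLM is given, a default BambooLLM instance is returned.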
+ + Args: + llm (object): LLM option to be used for API access + + Raises: + BadImportError: If the LLM is a Langchain LLM but the langchain package + is not installed + """ + + if llm is None: + return BambooLLM() + + # Check if pandasai_langchain is installed + if find_spec("pandasai_langchain") is not None: + from pandasai_langchain.langchain import LangchainLLM, is_langchain_llm + + if is_langchain_llm(llm): + llm = LangchainLLM(llm) + + return llm + + def _assign_prompt_id(self): + """Assign a prompt ID""" + + self._state.last_prompt_id = uuid.uuid4() + + if self._state.logger: + self._state.logger.log(f"Prompt ID: {self._state.last_prompt_id}") + + def _handle_exception(self, exception: Exception) -> str: + """Handle exceptions and return an error message.""" + error_message = traceback.format_exc() + self._state.logger.log(f"Processing failed with error: {error_message}") + return ( + "Unfortunately, I was not able to get your answers, " + "because of the following error:\n" + f"\n{exception}\n" + ) + + @property + def last_generated_code(self): + return self._state.last_code_generated + @property - def logs(self): - return self.logger.logs + def last_code_executed(self): + return self._state.last_code_executed @property - def last_error(self): - raise NotImplementedError + def last_prompt_used(self): + return self._state.last_prompt_used diff --git a/pandasai/agent/base_judge.py b/pandasai/agent/base_judge.py deleted file mode 100644 index d4c6b6136..000000000 --- a/pandasai/agent/base_judge.py +++ /dev/null @@ -1,18 +0,0 @@ -from pandasai.helpers.logger import Logger -from pandasai.pipelines.pipeline import Pipeline -from pandasai.pipelines.pipeline_context import PipelineContext - - -class BaseJudge: - context: PipelineContext - pipeline: Pipeline - logger: Logger - - def __init__( - self, - pipeline: Pipeline, - ) -> None: - self.pipeline = pipeline - - def evaluate(self, query: str, code: str) -> bool: - raise NotImplementedError diff --git a/pandasai/agent/base_security.py b/pandasai/agent/base_security.py deleted file mode 100644 index 29ddd67bb..000000000 --- a/pandasai/agent/base_security.py +++ /dev/null @@ -1,18 +0,0 @@ -from pandasai.helpers.logger import Logger -from pandasai.pipelines.pipeline import Pipeline -from pandasai.pipelines.pipeline_context import PipelineContext - - -class BaseSecurity: - context: PipelineContext - pipeline: Pipeline - logger: Logger - - def __init__( - self, - pipeline: Pipeline, - ) -> None: - self.pipeline = pipeline - - def evaluate(self, query: str) -> bool: - raise NotImplementedError diff --git a/pandasai/agent/callbacks.py b/pandasai/agent/callbacks.py deleted file mode 100644 index 8705c69de..000000000 --- a/pandasai/agent/callbacks.py +++ /dev/null @@ -1,42 +0,0 @@ -from ..prompts import BasePrompt - - -class Callbacks: - def __init__(self, agent): - self.agent = agent - - def on_prompt_generation(self, prompt: BasePrompt) -> str: - """ - A method to be called after prompt generation. - - Args: - prompt (str): A prompt - """ - self.agent.last_prompt = str(prompt) - - def on_code_generation(self, code: str): - """ - A method to be called after code generation. - - Args: - code (str): A python code - """ - self.agent.last_code_generated = code - - def before_code_execution(self, code: str): - """ - A method to be called after code execution. - - Args: - code (str): A python code - """ - self.agent.last_code_executed = code - - def on_result(self, result): - """ - A method to be called after code execution.
- - Args: - result (Any): A python code - """ - self.agent.last_result = result diff --git a/pandasai/agent/state.py b/pandasai/agent/state.py new file mode 100644 index 000000000..f9bfefc5c --- /dev/null +++ b/pandasai/agent/state.py @@ -0,0 +1,60 @@ +from __future__ import annotations +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from pandasai.helpers.cache import Cache +from pandasai.helpers.logger import Logger +from pandasai.helpers.memory import Memory + +from pandasai.config import Config +from pandasai.vectorstores.vectorstore import VectorStore + +if TYPE_CHECKING: + from pandasai.dataframe import DataFrame + from pandasai.dataframe import VirtualDataFrame + from pandasai.llm.base import LLM + + +@dataclass +class AgentState: + """ + Context class for managing pipeline attributes and passing them between steps. + """ + + dfs: List[Union[DataFrame, VirtualDataFrame]] = field(default_factory=list) + config: Union[Config, dict] = field(default_factory=dict) + memory: Memory = field(default_factory=Memory) + cache: Optional[Cache] = None + llm: LLM = None + vectorstore: Optional[VectorStore] = None + intermediate_values: Dict[str, Any] = field(default_factory=dict) + logger: Optional[Logger] = None + last_code_generated: Optional[str] = None + last_code_executed: Optional[str] = None + last_prompt_id: str = None + last_prompt_used: str = None + output_type: Optional[str] = None + + def __post_init__(self): + if isinstance(self.config, dict): + self.config = Config(**self.config) + + # Initialize cache only if enabled in config + if getattr(self.config, "enable_cache", False) and self.cache is None: + self.cache = Cache() + + def reset_intermediate_values(self): + """Resets the intermediate values dictionary.""" + self.intermediate_values.clear() + + def add(self, key: str, value: Any): + """Adds a single key-value pair to intermediate values.""" + self.intermediate_values[key] = value + + def add_many(self, values: Dict[str, Any]): + """Adds multiple key-value pairs to intermediate values.""" + self.intermediate_values.update(values) + + def get(self, key: str, default: Any = "") -> Any: + """Fetches a value from intermediate values or returns a default.""" + return self.intermediate_values.get(key, default) diff --git a/pandasai/chat/cache.py b/pandasai/chat/cache.py new file mode 100644 index 000000000..9acd6b29d --- /dev/null +++ b/pandasai/chat/cache.py @@ -0,0 +1,105 @@ +import glob +import os +from typing import Any + +try: + import duckdb +except ImportError: + duckdb = None + +from pandasai.constants import CACHE_TOKEN, DEFAULT_FILE_PERMISSIONS +from pandasai.helpers.path import find_project_root + + +class Cache: + """Cache class for caching queries. It is used to cache queries + to save time and money. + + Args: + filename (str): filename to store the cache. 
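+        abs_path (str, optional): absolute directory for the cache database; when omitted, a cache/ directory under the project root (or the current working directory) is used, as in __init__ below.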
+ """ + + def __init__(self, filename="cache_db_0.11", abs_path=None): + # Define cache directory and create directory if it does not exist + if abs_path: + cache_dir = abs_path + else: + try: + cache_dir = os.path.join(find_project_root(), "cache") + except ValueError: + cache_dir = os.path.join(os.getcwd(), "cache") + + os.makedirs(cache_dir, mode=DEFAULT_FILE_PERMISSIONS, exist_ok=True) + + self.filepath = os.path.join(cache_dir, f"{filename}.db") + self.connection = duckdb.connect(self.filepath) + self.connection.execute( + "CREATE TABLE IF NOT EXISTS cache (key STRING, value STRING)" + ) + + def versioned_key(self, key: str) -> str: + return f"{CACHE_TOKEN}-{key}" + + def set(self, key: str, value: str) -> None: + """Set a key value pair in the cache. + + Args: + key (str): key to store the value. + value (str): value to store in the cache. + """ + self.connection.execute( + "INSERT INTO cache VALUES (?, ?)", [self.versioned_key(key), value] + ) + + def get(self, key: str) -> str: + """Get a value from the cache. + + Args: + key (str): key to get the value from the cache. + + Returns: + str: value from the cache. + """ + result = self.connection.execute( + "SELECT value FROM cache WHERE key=?", [self.versioned_key(key)] + ) + return row[0] if (row := result.fetchone()) else None + + def delete(self, key: str) -> None: + """Delete a key value pair from the cache. + + Args: + key (str): key to delete the value from the cache. + """ + self.connection.execute( + "DELETE FROM cache WHERE key=?", [self.versioned_key(key)] + ) + + def close(self) -> None: + """Close the cache.""" + self.connection.close() + + def clear(self) -> None: + """Clean the cache.""" + self.connection.execute("DELETE FROM cache") + + def destroy(self) -> None: + """Destroy the cache.""" + self.connection.close() + for cache_file in glob.glob(f"{self.filepath}.*"): + os.remove(cache_file) + + def get_cache_key(self, context: Any) -> str: + """ + Return the cache key for the current conversation. + + Returns: + str: The cache key for the current conversation + """ + cache_key = context.memory.get_conversation() + + # make the cache key unique for each combination of dfs + for df in context.dfs: + cache_key += str(df.column_hash) + + return cache_key diff --git a/pandasai/chat/code_execution/__init__.py b/pandasai/chat/code_execution/__init__.py new file mode 100644 index 000000000..7f7786591 --- /dev/null +++ b/pandasai/chat/code_execution/__init__.py @@ -0,0 +1,3 @@ +from .code_executor import CodeExecutor + +__all__ = ["CodeExecutor"] diff --git a/pandasai/chat/code_execution/code_executor.py b/pandasai/chat/code_execution/code_executor.py new file mode 100644 index 000000000..b79345012 --- /dev/null +++ b/pandasai/chat/code_execution/code_executor.py @@ -0,0 +1,123 @@ +import ast +from pandasai.chat.code_execution.environment import get_environment +from pandasai.exceptions import NoResultFoundError + +from typing import Any, List + + +class CodeExecutor: + """ + Handle the logic on how to handle different lines of code + """ + + _environment: dict + + def __init__(self, additional_dependencies: List[dict] = []) -> None: + self._environment = get_environment(additional_dependencies) + self._plots = [] + + def add_to_env(self, key: str, value: Any) -> None: + """ + Expose extra variables in the code to be used + Args: + key (str): Name of variable or lib alias + value (Any): It can any value int, float, function, class etc. 
+ """ + self._environment[key] = value + + def execute(self, code: str) -> dict: + exec(code, self._environment) + return self._environment + + def execute_and_return_result(self, code: str) -> Any: + """ + Executes the return updated environment + """ + exec(code, self._environment) + + # Get the result + if "result" not in self._environment: + var_name, subscript = self._get_variable_last_line_of_code(code) + if var_name and var_name in self._environment: + if subscript is not None: + result = self._environment[var_name][subscript] + else: + result = self._environment[var_name] + + raise NoResultFoundError("No result returned") + else: + result = self._environment["result"] + + if isinstance(result, dict) and result["type"] == "plot": + for plot in self._plots: + if plot["type"] == "plot": + result["value"] = plot["value"] + + return self._environment.get("result", None) + + def _get_variable_last_line_of_code(self, code: str) -> str: + """ + Returns variable name from the last line if it is a variable name or assigned. + Args: + code (str): Code in string. + + Returns: + str: Variable name. + """ + try: + tree = ast.parse(code) + last_statement = tree.body[-1] + + if isinstance(last_statement, ast.Assign): + return self._get_assign_variable(last_statement) + elif isinstance(last_statement, ast.Expr): + return self._get_expr_variable(last_statement) + + return ast.unparse(last_statement).strip() + + except SyntaxError: + return None + + def _get_assign_variable(self, assign_node): + """ + Extracts the variable name from an assignment node. + + Args: + assign_node (ast.Assign): Assignment node. + + Returns: + str: Variable name. + """ + if isinstance(assign_node.targets[0], ast.Subscript): + return self._get_subscript_variable(assign_node.targets[0]) + elif isinstance(assign_node.targets[0], ast.Name): + return assign_node.targets[0].id, None + + def _get_expr_variable(self, expr_node): + """ + Extracts the variable name from an expression node. + + Args: + expr_node (ast.Expr): Expression node. + + Returns: + str: Variable name. + """ + if isinstance(expr_node.value, ast.Subscript): + return self._get_subscript_variable(expr_node.value) + elif isinstance(expr_node.value, ast.Name): + return expr_node.value.id, None + + def _get_subscript_variable(self, subscript_node): + """ + Extracts the variable name from a subscript node. + + Args: + subscript_node (ast.Subscript): Subscript node. + + Returns: + str: Variable name. 
+ """ + if isinstance(subscript_node.value, ast.Name): + variable_name = subscript_node.value.id + return variable_name, subscript_node.slice.value diff --git a/pandasai/helpers/optional.py b/pandasai/chat/code_execution/environment.py similarity index 81% rename from pandasai/helpers/optional.py rename to pandasai/chat/code_execution/environment.py index b54ec0f32..412c255ec 100644 --- a/pandasai/helpers/optional.py +++ b/pandasai/chat/code_execution/environment.py @@ -3,22 +3,26 @@ Source: Taken from pandas/compat/_optional.py """ -from __future__ import annotations - import importlib import sys import warnings -from typing import TYPE_CHECKING, List +from typing import List -import matplotlib.pyplot as plt -import numpy as np from pandas.util.version import Version -import pandas as pd +from .safe_libs.restricted_base64 import RestrictedBase64 +from .safe_libs.restricted_datetime import ( + RestrictedDatetime, +) +from .safe_libs.restricted_json import RestrictedJson +from .safe_libs.restricted_matplotlib import ( + RestrictedMatplotlib, +) +from .safe_libs.restricted_numpy import RestrictedNumpy +from .safe_libs.restricted_pandas import RestrictedPandas +from .safe_libs.restricted_seaborn import RestrictedSeaborn from pandasai.constants import WHITELISTED_BUILTINS - -if TYPE_CHECKING: - import types +import types # Minimum version required for each optional dependency @@ -48,10 +52,7 @@ def get_environment(additional_deps: List[dict]) -> dict: Returns (dict): A dictionary of environment variables """ - return { - "pd": pd, - "plt": plt, - "np": np, + env = { **{ lib["alias"]: ( getattr(import_dependency(lib["module"]), lib["name"]) @@ -67,6 +68,25 @@ def get_environment(additional_deps: List[dict]) -> dict: }, } + env["pd"] = RestrictedPandas() + env["plt"] = RestrictedMatplotlib() + env["np"] = RestrictedNumpy() + + for lib in additional_deps: + if lib["name"] == "seaborn": + env["sns"] = RestrictedSeaborn() + + if lib["name"] == "datetime": + env["datetime"] = RestrictedDatetime() + + if lib["name"] == "json": + env["json"] = RestrictedJson() + + if lib["name"] == "base64": + env["base64"] = RestrictedBase64() + + return env + def import_dependency( name: str, diff --git a/pandasai/chat/code_execution/safe_libs/base_restricted_module.py b/pandasai/chat/code_execution/safe_libs/base_restricted_module.py new file mode 100644 index 000000000..ce3bf21a5 --- /dev/null +++ b/pandasai/chat/code_execution/safe_libs/base_restricted_module.py @@ -0,0 +1,27 @@ +class BaseRestrictedModule: + def _wrap_function(self, func): + def wrapper(*args, **kwargs): + # Check for any suspicious arguments that might be used for importing + for arg in args + tuple(kwargs.values()): + if isinstance(arg, str): + # Check if the string is exactly one of the restricted modules + restricted_modules = ["io", "os", "subprocess", "sys", "importlib"] + if any(arg.lower() == module for module in restricted_modules): + raise SecurityError( + f"Potential security risk: '{arg}' is not allowed" + ) + return func(*args, **kwargs) + + return wrapper + + def _wrap_class(self, cls): + class WrappedClass(cls): + def __getattribute__(self, name): + attr = super().__getattribute__(name) + return self._wrap_function(self, attr) if callable(attr) else attr + + return WrappedClass + + +class SecurityError(Exception): + pass diff --git a/pandasai/chat/code_execution/safe_libs/restricted_base64.py b/pandasai/chat/code_execution/safe_libs/restricted_base64.py new file mode 100644 index 000000000..eb305885e --- /dev/null +++ 
b/pandasai/chat/code_execution/safe_libs/restricted_base64.py @@ -0,0 +1,21 @@ +import base64 + +from .base_restricted_module import BaseRestrictedModule + + +class RestrictedBase64(BaseRestrictedModule): + def __init__(self): + self.allowed_functions = [ + "b64encode", # Safe function to encode data into base64 + "b64decode", # Safe function to decode base64 encoded data + ] + + # Bind the allowed functions to the object + for func in self.allowed_functions: + if hasattr(base64, func): + setattr(self, func, self._wrap_function(getattr(base64, func))) + + def __getattr__(self, name): + if name not in self.allowed_functions: + raise AttributeError(f"'{name}' is not allowed in RestrictedBase64") + return getattr(base64, name) diff --git a/pandasai/chat/code_execution/safe_libs/restricted_datetime.py b/pandasai/chat/code_execution/safe_libs/restricted_datetime.py new file mode 100644 index 000000000..0fc48290a --- /dev/null +++ b/pandasai/chat/code_execution/safe_libs/restricted_datetime.py @@ -0,0 +1,64 @@ +import datetime + +from .base_restricted_module import BaseRestrictedModule + + +class RestrictedDatetime(BaseRestrictedModule): + def __init__(self): + self.allowed_attributes = [ + # Classes + "date", + "time", + "datetime", + "timedelta", + "tzinfo", + "timezone", + # Constants + "MINYEAR", + "MAXYEAR", + # Time zone constants + "UTC", + # Functions + "now", + "utcnow", + "today", + "fromtimestamp", + "utcfromtimestamp", + "fromordinal", + "combine", + "strptime", + # Timedelta operations + "timedelta", + # Date operations + "weekday", + "isoweekday", + "isocalendar", + "isoformat", + "ctime", + "strftime", + "year", + "month", + "day", + "hour", + "minute", + "second", + "microsecond", + # Time operations + "replace", + "tzname", + "dst", + "utcoffset", + # Comparison methods + "min", + "max", + ] + + for attr in self.allowed_attributes: + if hasattr(datetime, attr): + setattr(self, attr, self._wrap_function(getattr(datetime, attr))) + + def __getattr__(self, name): + if name not in self.allowed_attributes: + raise AttributeError(f"'{name}' is not allowed in RestrictedDatetime") + + return getattr(datetime, name) diff --git a/pandasai/chat/code_execution/safe_libs/restricted_json.py b/pandasai/chat/code_execution/safe_libs/restricted_json.py new file mode 100644 index 000000000..7f13b6112 --- /dev/null +++ b/pandasai/chat/code_execution/safe_libs/restricted_json.py @@ -0,0 +1,23 @@ +import json + +from .base_restricted_module import BaseRestrictedModule + + +class RestrictedJson(BaseRestrictedModule): + def __init__(self): + self.allowed_functions = [ + "load", + "loads", + "dump", + "dumps", + ] + + # Bind the allowed functions to the object + for func in self.allowed_functions: + if hasattr(json, func): + setattr(self, func, self._wrap_function(getattr(json, func))) + + def __getattr__(self, name): + if name not in self.allowed_functions: + raise AttributeError(f"'{name}' is not allowed in RestrictedJson") + return getattr(json, name) diff --git a/pandasai/chat/code_execution/safe_libs/restricted_matplotlib.py b/pandasai/chat/code_execution/safe_libs/restricted_matplotlib.py new file mode 100644 index 000000000..c07b6b1e3 --- /dev/null +++ b/pandasai/chat/code_execution/safe_libs/restricted_matplotlib.py @@ -0,0 +1,82 @@ +import matplotlib.axes as axes +import matplotlib.figure as figure +import matplotlib.pyplot as plt + +from .base_restricted_module import BaseRestrictedModule + + +class RestrictedMatplotlib(BaseRestrictedModule): + def __init__(self): + self.allowed_attributes = 
[ + # Figure and Axes creation + "figure", + "subplots", + "subplot", + # Plotting functions + "plot", + "scatter", + "bar", + "barh", + "hist", + "boxplot", + "violinplot", + "pie", + "errorbar", + "contour", + "contourf", + "imshow", + "pcolor", + "pcolormesh", + # Axis manipulation + "xlabel", + "ylabel", + "title", + "legend", + "xlim", + "ylim", + "axis", + "xticks", + "yticks", + "grid", + "axhline", + "axvline", + # Colorbar + "colorbar", + # Text and annotations + "text", + "annotate", + # Styling + "style", + # Save and show + "show", + "savefig", + # Color maps + "get_cmap", + # 3D plotting + "axes3d", + # Utility functions + "close", + "clf", + "cla", + # Constants + "rcParams", + "gca", + "invert_yaxis", + # Additional attributes needed + "set_title", # Allow setting title directly + "set_xlabel", + "set_ylabel", + ] + + for attr in self.allowed_attributes: + if hasattr(plt, attr): + setattr(self, attr, self._wrap_function(getattr(plt, attr))) + + # Special handling for figure and axes + self.Figure = self._wrap_class(figure.Figure) + self.Axes = self._wrap_class(axes.Axes) + + def __getattr__(self, name): + if name not in self.allowed_attributes: + raise AttributeError(f"'{name}' is not allowed in RestrictedMatplotlib") + return getattr(plt, name) diff --git a/pandasai/chat/code_execution/safe_libs/restricted_numpy.py b/pandasai/chat/code_execution/safe_libs/restricted_numpy.py new file mode 100644 index 000000000..855fb70d6 --- /dev/null +++ b/pandasai/chat/code_execution/safe_libs/restricted_numpy.py @@ -0,0 +1,182 @@ +import numpy as np + +from .base_restricted_module import BaseRestrictedModule + + +class RestrictedNumpy(BaseRestrictedModule): + def __init__(self): + self.allowed_attributes = [ + # Array creation + "array", + "zeros", + "ones", + "empty", + "full", + "zeros_like", + "ones_like", + "empty_like", + "full_like", + "eye", + "identity", + "diag", + "arange", + "linspace", + "logspace", + "geomspace", + "fromfunction", + "fromiter", + # Array manipulation + "reshape", + "ravel", + "flatten", + "moveaxis", + "rollaxis", + "swapaxes", + "transpose", + "split", + "hsplit", + "vsplit", + "dsplit", + "stack", + "column_stack", + "dstack", + "row_stack", + "concatenate", + "vstack", + "hstack", + "tile", + "repeat", + # Mathematical operations + "add", + "subtract", + "multiply", + "divide", + "power", + "mod", + "remainder", + "divmod", + "negative", + "positive", + "absolute", + "fabs", + "rint", + "floor", + "ceil", + "trunc", + "exp", + "expm1", + "exp2", + "log", + "log10", + "log2", + "log1p", + "sqrt", + "square", + "cbrt", + "reciprocal", + # Trigonometric functions + "sin", + "cos", + "tan", + "arcsin", + "arccos", + "arctan", + "arctan2", + "hypot", + "sinh", + "cosh", + "tanh", + "arcsinh", + "arccosh", + "arctanh", + "deg2rad", + "rad2deg", + # Statistical functions + "mean", + "average", + "median", + "std", + "var", + "min", + "max", + "argmin", + "argmax", + "sum", + "prod", + "percentile", + "quantile", + "histogram", + "histogram2d", + "histogramdd", + "bincount", + "digitize", + # Linear algebra + "dot", + "vdot", + "inner", + "outer", + "matmul", + "tensordot", + "einsum", + "trace", + "diagonal", + # Sorting and searching + "sort", + "argsort", + "partition", + "argpartition", + "searchsorted", + "nonzero", + "where", + "extract", + # Logic functions + "all", + "any", + "greater", + "greater_equal", + "less", + "less_equal", + "equal", + "not_equal", + "logical_and", + "logical_or", + "logical_not", + "logical_xor", + "isfinite", + "isinf", + 
"isnan", + "isneginf", + "isposinf", + # Set operations + "unique", + "intersect1d", + "union1d", + "setdiff1d", + "setxor1d", + # Basic array information + "shape", + "size", + "ndim", + "dtype", + # Utility functions + "clip", + "round", + "sign", + "conj", + "real", + "imag", + "copy", + "asarray", + "asanyarray", + "ascontiguousarray", + "asfortranarray", + ] + + for attr in self.allowed_attributes: + if hasattr(np, attr): + setattr(self, attr, self._wrap_function(getattr(np, attr))) + + def __getattr__(self, name): + if name not in self.allowed_attributes: + raise AttributeError(f"'{name}' is not allowed in RestrictedNumPy") + return getattr(np, name) diff --git a/pandasai/chat/code_execution/safe_libs/restricted_pandas.py b/pandasai/chat/code_execution/safe_libs/restricted_pandas.py new file mode 100644 index 000000000..75e5a083c --- /dev/null +++ b/pandasai/chat/code_execution/safe_libs/restricted_pandas.py @@ -0,0 +1,110 @@ +import pandas as pd + +from .base_restricted_module import BaseRestrictedModule + + +class RestrictedPandas(BaseRestrictedModule): + def __init__(self): + self.allowed_attributes = [ + # DataFrame creation and basic operations + "DataFrame", + "Series", + "concat", + "merge", + "join", + # Data manipulation + "groupby", + "pivot", + "pivot_table", + "melt", + "crosstab", + "cut", + "qcut", + "get_dummies", + "factorize", + # Indexing and selection + "loc", + "iloc", + "at", + "iat", + # Function application + "apply", + "applymap", + "pipe", + # Reshaping and sorting + "sort_values", + "sort_index", + "nlargest", + "nsmallest", + "rank", + "reindex", + "reset_index", + "set_index", + # Computations / descriptive stats + "sum", + "prod", + "min", + "max", + "mean", + "median", + "var", + "std", + "sem", + "skew", + "kurt", + "quantile", + "count", + "nunique", + "value_counts", + "describe", + "cov", + "corr", + # Date functionality + "to_datetime", + "date_range", + # String methods + "str", + # Categorical methods + "Categorical", + "cut", + "qcut", + # Plotting (if visualization is allowed) + "plot", + # Utility functions + "isnull", + "notnull", + "isna", + "notna", + "fillna", + "dropna", + "replace", + "astype", + "copy", + "drop_duplicates", + # Window functions + "rolling", + "expanding", + "ewm", + # Time series functionality + "resample", + "shift", + "diff", + "pct_change", + # Aggregation + "agg", + "aggregate", + ] + + for attr in self.allowed_attributes: + if hasattr(pd, attr): + setattr(self, attr, self._wrap_function(getattr(pd, attr))) + elif attr in ["loc", "iloc", "at", "iat"]: + # These are properties, not functions + setattr( + self, attr, property(lambda self, a=attr: getattr(pd.DataFrame, a)) + ) + + def __getattr__(self, name): + if name not in self.allowed_attributes: + raise AttributeError(f"'{name}' is not allowed in RestrictedPandas") + return getattr(pd, name) diff --git a/pandasai/chat/code_execution/safe_libs/restricted_seaborn.py b/pandasai/chat/code_execution/safe_libs/restricted_seaborn.py new file mode 100644 index 000000000..a5ef4c6e8 --- /dev/null +++ b/pandasai/chat/code_execution/safe_libs/restricted_seaborn.py @@ -0,0 +1,74 @@ +import seaborn as sns + +from .base_restricted_module import BaseRestrictedModule + + +class RestrictedSeaborn(BaseRestrictedModule): + def __init__(self): + self.allowed_attributes = [ + # Plot functions + "scatterplot", + "lineplot", + "relplot", + "displot", + "histplot", + "kdeplot", + "ecdfplot", + "rugplot", + "distplot", + "boxplot", + "violinplot", + "boxenplot", + "stripplot", + 
"swarmplot", + "barplot", + "countplot", + "heatmap", + "clustermap", + "regplot", + "lmplot", + "residplot", + "jointplot", + "pairplot", + "catplot", + # Axis styling + "set_style", + "set_context", + "set_palette", + "despine", + "move_legend", + "axes_style", + "plotting_context", + # Color palette functions + "color_palette", + "palplot", + "cubehelix_palette", + "light_palette", + "dark_palette", + "diverging_palette", + # Utility functions + "load_dataset", + # Figure-level interface + "FacetGrid", + "PairGrid", + "JointGrid", + # Regression and statistical estimation + "lmplot", + "regplot", + "residplot", + # Matrix plots + "heatmap", + "clustermap", + # Miscellaneous + "kdeplot", + "rugplot", + ] + + for attr in self.allowed_attributes: + if hasattr(sns, attr): + setattr(self, attr, self._wrap_function(getattr(sns, attr))) + + def __getattr__(self, name): + if name not in self.allowed_attributes: + raise AttributeError(f"'{name}' is not allowed in RestrictedSeaborn") + return getattr(sns, name) diff --git a/pandasai/chat/code_generation/__init__.py b/pandasai/chat/code_generation/__init__.py new file mode 100644 index 000000000..8c62e09ca --- /dev/null +++ b/pandasai/chat/code_generation/__init__.py @@ -0,0 +1,11 @@ +from .code_cleaning import CodeCleaner +from .base import CodeGenerator +from .code_security import CodeSecurityChecker +from .code_validation import CodeRequirementValidator + +__all__ = [ + "CodeCleaner", + "CodeGenerator", + "CodeSecurityChecker", + "CodeRequirementValidator", +] diff --git a/pandasai/chat/code_generation/base.py b/pandasai/chat/code_generation/base.py new file mode 100644 index 000000000..17c184749 --- /dev/null +++ b/pandasai/chat/code_generation/base.py @@ -0,0 +1,66 @@ +import traceback + +from pandasai.agent.state import AgentState +from pandasai.chat.prompts.base import BasePrompt +from .code_cleaning import CodeCleaner +from .code_security import CodeSecurityChecker +from .code_validation import CodeRequirementValidator + + +class CodeGenerator: + def __init__(self, context: AgentState): + self._context = context + self._code_cleaner = CodeCleaner(self._context) + self._code_security = CodeSecurityChecker(self._context) + self._code_validator = CodeRequirementValidator(self._context) + + def generate_code(self, prompt: BasePrompt) -> tuple[str, list]: + """ + Generates code using a given LLM and performs validation and cleaning steps. + + Args: + context (PipelineContext): The pipeline context containing dataframes and logger. + prompt (BasePrompt): The prompt to guide code generation. + + Returns: + str: The final cleaned and validated code. + + Raises: + Exception: If any step fails during the process. 
+ """ + try: + self._context.logger.log(f"Using Prompt: {prompt}") + + # Generate the code + code = self._context.config.llm.generate_code(prompt, self._context) + self._context.last_code_generated = code + self._context.logger.log(f"Code Generated:\n{code}") + + return self.validate_and_clean_code(code) + + except Exception as e: + error_message = f"An error occurred during code generation: {e}" + stack_trace = traceback.format_exc() + + self._context.logger.log( + error_message, + ) + self._context.logger.log(f"Stack Trace:\n{stack_trace}") + + raise e + + def validate_and_clean_code(self, code) -> tuple[str, list]: + # Check for malicious code + self._context.logger.log("Checking for malicious code...") + self._code_security.check(code) + self._context.logger.log("Malicious code check passed.") + + # Validate code requirements + self._context.logger.log("Validating code requirements...") + if not self._code_validator.validate(code): + raise ValueError("Code validation failed due to unmet requirements.") + self._context.logger.log("Code validation successful.") + + # Clean the code + self._context.logger.log("Cleaning the generated code...") + return self._code_cleaner.clean_code(code) diff --git a/pandasai/chat/code_generation/code_cleaning.py b/pandasai/chat/code_generation/code_cleaning.py new file mode 100644 index 000000000..a0526d9a1 --- /dev/null +++ b/pandasai/chat/code_generation/code_cleaning.py @@ -0,0 +1,318 @@ +import ast +import copy +import re +from typing import Union + +import astor +from pandasai.agent.state import AgentState + +from pandasai.chat.code_execution.code_executor import CodeExecutor +from pandasai.helpers.path import find_project_root +from pandasai.helpers.sql import extract_table_names + +from ...constants import WHITELISTED_LIBRARIES +from ...exceptions import BadImportError, MaliciousQueryError +from ...helpers.save_chart import add_save_chart + + +class CodeCleaner: + def __init__(self, context: AgentState): + """ + Initialize the CodeCleaner with the provided context. + + Args: + context (AgentState): The pipeline context for cleaning and validation. + """ + self.context = context + + def _check_imports( + self, node: Union[ast.Import, ast.ImportFrom] + ) -> Union[dict, None]: + """ + Add whitelisted imports to additional dependencies. + + Args: + node (Union[ast.Import, ast.ImportFrom]): AST node for import statements. + + Raises: + BadImportError: If the import is not whitelisted. + """ + module = node.names[0].name if isinstance(node, ast.Import) else node.module + library = module.split(".")[0] + + if library == "pandas": + return + + if ( + library + in WHITELISTED_LIBRARIES + + self.context.config.custom_whitelisted_dependencies + ): + for alias in node.names: + return { + "module": module, + "name": alias.name, + "alias": alias.asname or alias.name, + } + + raise BadImportError( + f"The library '{library}' is not in the list of whitelisted libraries. " + "To learn how to whitelist custom dependencies, visit: " + "https://docs.pandas-ai.com/custom-whitelisted-dependencies#custom-whitelisted-dependencies" + ) + + def _check_is_df_declaration(self, node: ast.AST) -> bool: + """ + Check if the node represents a pandas DataFrame declaration. 
+ """ + value = node.value + return ( + isinstance(value, ast.Call) + and isinstance(value.func, ast.Attribute) + and isinstance(value.func.value, ast.Name) + and hasattr(value.func.value, "id") + and value.func.value.id == "pd" + and value.func.attr == "DataFrame" + ) + + def _get_target_names(self, targets): + """ + Extract target names from AST nodes. + """ + target_names = [] + is_slice = False + + for target in targets: + if isinstance(target, ast.Name) or ( + isinstance(target, ast.Subscript) and isinstance(target.value, ast.Name) + ): + target_names.append( + target.id if isinstance(target, ast.Name) else target.value.id + ) + is_slice = isinstance(target, ast.Subscript) + + return target_names, is_slice, target + + def _check_direct_sql_func_def_exists(self, node: ast.AST) -> bool: + """ + Check if the node defines a direct SQL execution function. + """ + return ( + self.context.config.direct_sql + and isinstance(node, ast.FunctionDef) + and node.name == "execute_sql_query" + ) + + def _replace_table_names( + self, sql_query: str, table_names: list, allowed_table_names: list + ) -> str: + """ + Replace table names in the SQL query with case-sensitive or authorized table names. + """ + regex_patterns = { + table_name: re.compile(r"\b" + re.escape(table_name) + r"\b") + for table_name in table_names + } + for table_name in table_names: + if table_name in allowed_table_names: + quoted_table_name = allowed_table_names[table_name] + sql_query = regex_patterns[table_name].sub(quoted_table_name, sql_query) + else: + raise MaliciousQueryError( + f"Query uses unauthorized table: {table_name}." + ) + return sql_query + + def _clean_sql_query(self, sql_query: str) -> str: + """ + Clean the SQL query by trimming semicolons and validating table names. + """ + sql_query = sql_query.rstrip(";") + table_names = extract_table_names(sql_query) + allowed_table_names = {df.name: df.name for df in self.context.dfs} | { + f'"{df.name}"': df.name for df in self.context.dfs + } + return self._replace_table_names(sql_query, table_names, allowed_table_names) + + def _validate_and_make_table_name_case_sensitive(self, node: ast.AST) -> ast.AST: + """ + Validate table names and convert them to case-sensitive names in the SQL query. 
+ """ + if isinstance(node, ast.Assign): + if ( + isinstance(node.value, ast.Constant) + and isinstance(node.value.value, str) + and isinstance(node.targets[0], ast.Name) + and node.targets[0].id in ["sql_query", "query"] + ): + sql_query = self._clean_sql_query(node.value.value) + node.value.value = sql_query + elif ( + isinstance(node.value, ast.Call) + and isinstance(node.value.func, ast.Name) + and node.value.func.id == "execute_sql_query" + and len(node.value.args) == 1 + and isinstance(node.value.args[0], ast.Constant) + and isinstance(node.value.args[0].value, str) + ): + sql_query = self._clean_sql_query(node.value.args[0].value) + node.value.args[0].value = sql_query + + if isinstance(node, ast.Expr) and isinstance(node.value, ast.Call): + if ( + isinstance(node.value.func, ast.Name) + and node.value.func.id == "execute_sql_query" + and len(node.value.args) == 1 + and isinstance(node.value.args[0], ast.Constant) + and isinstance(node.value.args[0].value, str) + ): + sql_query = self._clean_sql_query(node.value.args[0].value) + node.value.args[0].value = sql_query + + return node + + def extract_fix_dataframe_redeclarations( + self, + node: ast.AST, + code_lines: list[str], + additional_deps: list[dict], + ) -> ast.AST: + """ + Checks if dataframe reclaration in the code like pd.DataFrame({...}) + Args: + node (ast.AST): Code Node + code_lines (list[str]): List of code str line by line + + Returns: + ast.AST: Updated Ast Node fixing redeclaration + """ + if isinstance(node, ast.Assign): + target_names, is_slice, target = self.get_target_names(node.targets) + + if target_names and self.check_is_df_declaration(node): + # Construct dataframe from node + code = "\n".join(code_lines) + code_executor = CodeExecutor(additional_deps) + code_executor.add_to_env("dfs", copy.deepcopy(self.context.dfs)) + env = code_executor.execute(code) + + df_generated = ( + env[target_names[0]][target.slice.value] + if is_slice + else env[target_names[0]] + ) + + # check if exists in provided dfs + for index, df in enumerate(self.context.dfs): + head = df.get_head() + if head.shape == df_generated.shape and head.columns.equals( + df_generated.columns + ): + target_var = ( + ast.Subscript( + value=ast.Name(id=target_names[0], ctx=ast.Load()), + slice=target.slice, + ctx=ast.Store(), + ) + if is_slice + else ast.Name(id=target_names[0], ctx=ast.Store()) + ) + return ast.Assign( + targets=[target_var], + value=ast.Subscript( + value=ast.Name(id="dfs", ctx=ast.Load()), + slice=ast.Index(value=ast.Num(n=index)), + ctx=ast.Load(), + ), + ) + return None + + def get_target_names(self, targets): + target_names = [] + is_slice = False + + for target in targets: + if isinstance(target, ast.Name) or ( + isinstance(target, ast.Subscript) and isinstance(target.value, ast.Name) + ): + target_names.append( + target.id if isinstance(target, ast.Name) else target.value.id + ) + is_slice = isinstance(target, ast.Subscript) + + return target_names, is_slice, target + + def check_is_df_declaration(self, node: ast.AST): + value = node.value + return ( + isinstance(value, ast.Call) + and isinstance(value.func, ast.Attribute) + and isinstance(value.func.value, ast.Name) + and hasattr(value.func.value, "id") + and value.func.value.id == "pd" + and value.func.attr == "DataFrame" + ) + + def clean_code(self, code: str) -> tuple[str, list]: + """ + Clean the provided code by validating imports, handling SQL queries, and processing charts. + + Args: + code (str): The code to clean. 
+ + Returns: + tuple: Cleaned code as a string and a list of additional dependencies. + """ + code = self._handle_charts(code) + additional_dependencies = [] + clean_code_lines = [] + + tree = ast.parse(code) + new_body = [] + + for node in tree.body: + if isinstance(node, (ast.Import, ast.ImportFrom)): + imported_lib = self._check_imports(node) + if imported_lib: + additional_dependencies.append(imported_lib) + continue + + if self._check_direct_sql_func_def_exists(node): + continue + + if self.context.config.direct_sql: + node = self._validate_and_make_table_name_case_sensitive(node) + + clean_code_lines.append(astor.to_source(node)) + + new_body.append( + self.extract_fix_dataframe_redeclarations( + node, clean_code_lines, additional_dependencies + ) + or node + ) + + new_tree = ast.Module(body=new_body) + return ( + astor.to_source(new_tree, pretty_source=lambda x: "".join(x)).strip(), + additional_dependencies, + ) + + def _handle_charts(self, code: str) -> str: + """ + Handle chart-related code modifications. + """ + code = re.sub(r"""(['"])([^'"]*\.png)\1""", r"\1temp_chart.png\1", code) + if self.context.config.save_charts: + return add_save_chart( + code, + logger=self.context.logger, + file_name=str(self.context.prompt_id), + save_charts_path_str=self.context.config.save_charts_path, + ) + return add_save_chart( + code, + logger=self.context.logger, + file_name="temp_chart", + save_charts_path_str=f"{find_project_root()}/exports/charts", + ) diff --git a/pandasai/chat/code_generation/code_security.py b/pandasai/chat/code_generation/code_security.py new file mode 100644 index 000000000..09860eaa0 --- /dev/null +++ b/pandasai/chat/code_generation/code_security.py @@ -0,0 +1,162 @@ +import ast +import re +import astor +from pandasai.agent.state import AgentState +from pandasai.constants import RESTRICTED_LIBS +from pandasai.exceptions import MaliciousCodeGenerated + + +class CodeSecurityChecker: + """ + A class to perform checks for malicious and unsafe code execution. + """ + + def __init__(self, context: AgentState): + """ + Initialize the CodeSecurityChecker with the provided context. + + Args: + context (AgentState): The pipeline context for the code checks. + """ + self.context = context + self.dangerous_modules = [ + " os", + " io", + ".os", + ".io", + "'os'", + "'io'", + '"os"', + '"io"', + "chr(", + "chr)", + "chr ", + "(chr", + "b64decode", + ] + self.dangerous_builtins = ["__subclasses__", "__builtins__", "__import__"] + self.unsafe_methods = [ + ".to_csv", + ".to_excel", + ".to_json", + ".to_sql", + ".to_feather", + ".to_hdf", + ".to_parquet", + ".to_pickle", + ".to_gbq", + ".to_stata", + ".to_records", + ".to_latex", + ".to_html", + ".to_markdown", + ".to_clipboard", + ] + + def _is_malicious_code(self, code: str) -> bool: + """ + Check if the provided code contains malicious content. + + Args: + code (str): The code to be checked. + + Returns: + bool: True if malicious code is found, otherwise False. + """ + tree = ast.parse(code) + + def check_restricted_access(node): + """Check if the node accesses restricted modules or private attributes.""" + if isinstance(node, ast.Attribute): + attr_chain = [] + while isinstance(node, ast.Attribute): + if node.attr.startswith("_"): + raise MaliciousCodeGenerated( + f"Access to private attribute '{node.attr}' is not allowed." 
+ ) + attr_chain.insert(0, node.attr) + node = node.value + if isinstance(node, ast.Name): + attr_chain.insert(0, node.id) + if any(module in RESTRICTED_LIBS for module in attr_chain): + raise MaliciousCodeGenerated( + f"Restricted access detected in attribute chain: {'.'.join(attr_chain)}" + ) + elif isinstance(node, ast.Subscript) and isinstance( + node.value, ast.Attribute + ): + check_restricted_access(node.value) + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + sub_module_names = alias.name.split(".") + if any(module in RESTRICTED_LIBS for module in sub_module_names): + raise MaliciousCodeGenerated( + f"Restricted library import detected: {alias.name}" + ) + elif isinstance(node, ast.ImportFrom): + sub_module_names = node.module.split(".") + if any(module in RESTRICTED_LIBS for module in sub_module_names): + raise MaliciousCodeGenerated( + f"Restricted library import detected: {node.module}" + ) + if any(alias.name in RESTRICTED_LIBS for alias in node.names): + raise MaliciousCodeGenerated( + "Restricted library import detected in 'from ... import ...'" + ) + elif isinstance(node, (ast.Attribute, ast.Subscript)): + check_restricted_access(node) + + return any( + re.search(r"\b" + re.escape(module) + r"\b", code) + for module in self.dangerous_modules + ) + + def _is_jailbreak(self, node: ast.stmt) -> bool: + """ + Check if the code node contains jailbreak methods. + + Args: + node (ast.stmt): A code node to be checked. + + Returns: + bool: True if jailbreak methods are found, otherwise False. + """ + node_str = ast.dump(node) + return any(builtin in node_str for builtin in self.dangerous_builtins) + + def _is_unsafe(self, node: ast.stmt) -> bool: + """ + Check if the code node contains unsafe operations. + + Args: + node (ast.stmt): A code node to be checked. + + Returns: + bool: True if unsafe operations are found, otherwise False. + """ + code = astor.to_source(node) + return any(method in code for method in self.unsafe_methods) + + def check(self, code: str) -> None: + """ + Perform all checks on the provided code. + + Args: + code (str): The code to be checked. + + Raises: + MaliciousCodeGenerated: If malicious or unsafe code is detected. + """ + if self._is_malicious_code(code): + raise MaliciousCodeGenerated("Malicious code is generated!") + + tree = ast.parse(code) + for node in tree.body: + if self._is_jailbreak(node): + raise MaliciousCodeGenerated("Restricted builtins are used!") + if self._is_unsafe(node): + raise MaliciousCodeGenerated( + "The code is unsafe and can lead to I/O operations or other malicious operations that are not permitted!" + ) diff --git a/pandasai/chat/code_generation/code_validation.py b/pandasai/chat/code_generation/code_validation.py new file mode 100644 index 000000000..1ddf1a060 --- /dev/null +++ b/pandasai/chat/code_generation/code_validation.py @@ -0,0 +1,71 @@ +import ast + +from pandasai.agent.state import AgentState +from pandasai.exceptions import ExecuteSQLQueryNotUsed + + +class CodeRequirementValidator: + """ + Class to validate code requirements based on a pipeline context. + """ + + class _FunctionCallVisitor(ast.NodeVisitor): + """ + AST visitor to collect all function calls in a given Python code. + """ + + def __init__(self): + self.function_calls = [] + + def visit_Call(self, node: ast.Call): + """ + Visits a function call and records its name or attribute. 
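# The three layers of the security check at a glance, assuming a checker
# built from a valid AgentState:
# checker = CodeSecurityChecker(context=state)
# checker.check("import os")                      # restricted import -> MaliciousCodeGenerated
# checker.check("().__class__.__subclasses__()")  # private attrs / jailbreak builtins
# checker.check("df.to_csv('out.csv')")           # unsafe I/O method
# checker.check("result = dfs[0].describe()")     # passes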
+ """ + if isinstance(node.func, ast.Name): + self.function_calls.append(node.func.id) + elif isinstance(node.func, ast.Attribute) and isinstance( + node.func.value, ast.Name + ): + self.function_calls.append(f"{node.func.value.id}.{node.func.attr}") + self.generic_visit(node) # Continue visiting child nodes + + def __init__(self, context: AgentState): + """ + Initialize the validator with the pipeline context. + + Args: + context (AgentState): The agent state containing the configuration. + """ + self.context = context + + def validate(self, code: str) -> bool: + """ + Validates whether the code meets the requirements specified by the pipeline context. + + Args: + code (str): The code to validate. + + Returns: + bool: True if the code meets the requirements, False otherwise. + + Raises: + ExecuteSQLQueryNotUsed: If the `direct_sql` configuration is enabled and + `execute_sql_query` is not used in the code. + """ + # Parse the code into an AST + tree = ast.parse(code) + + # Use the visitor to collect function calls + func_call_visitor = self._FunctionCallVisitor() + func_call_visitor.visit(tree) + + # Validate requirements + if ( + self.context.config.direct_sql + and "execute_sql_query" not in func_call_visitor.function_calls + ): + raise ExecuteSQLQueryNotUsed( + "The code must execute SQL queries using the `execute_sql_query` function, which is already defined!" + ) + + return True diff --git a/pandasai/chat/prompts/__init__.py b/pandasai/chat/prompts/__init__.py new file mode 100644 index 000000000..d3e0a1a13 --- /dev/null +++ b/pandasai/chat/prompts/__init__.py @@ -0,0 +1,78 @@ +from __future__ import annotations +from typing import TYPE_CHECKING +from pandasai.chat.prompts.correct_execute_sql_query_usage_error_prompt import ( + CorrectExecuteSQLQueryUsageErrorPrompt, +) +from pandasai.chat.prompts.correct_output_type_error_prompt import ( + CorrectOutputTypeErrorPrompt, +) +from .generate_python_code_with_sql import GeneratePythonCodeWithSQLPrompt +from .base import BasePrompt +from .correct_error_prompt import CorrectErrorPrompt +from .generate_python_code import GeneratePythonCodePrompt + +if TYPE_CHECKING: + from pandasai.agent.state import AgentState + + +def get_chat_prompt(context: AgentState) -> BasePrompt: + viz_lib = "matplotlib" + if context.config.data_viz_library: + viz_lib = context.config.data_viz_library + + return GeneratePythonCodePrompt( + context=context, + last_code_generated=context.get("last_code_generated"), + viz_lib=viz_lib, + output_type=context.output_type, + ) + + +def get_chat_prompt_for_sql(context: AgentState) -> BasePrompt: + viz_lib = "matplotlib" + if context.config.data_viz_library: + viz_lib = context.config.data_viz_library + + return GeneratePythonCodeWithSQLPrompt( + context=context, + last_code_generated=context.get("last_code_generated"), + viz_lib=viz_lib, + output_type=context.output_type, + ) + + +def get_correct_error_prompt( + context: AgentState, code: str, traceback_error: str +) -> BasePrompt: + return CorrectErrorPrompt( + context=context, + code=code, + error=traceback_error, + ) + + +def get_correct_error_prompt_for_sql( + context: AgentState, code: str, traceback_error: str +) -> BasePrompt: + return CorrectExecuteSQLQueryUsageErrorPrompt( + context=context, code=code, error=traceback_error + ) + + +def get_correct_output_type_error_prompt( + context: AgentState, code: str, traceback_error: str +) -> BasePrompt: + return CorrectOutputTypeErrorPrompt( + context=context, + code=code, + error=traceback_error, + 
output_type=context.output_type, + ) + + +__all__ = [ + "BasePrompt", + "CorrectErrorPrompt", + "GeneratePythonCodePrompt", + "GeneratePythonCodeWithSQLPrompt", +] diff --git a/pandasai/prompts/base.py b/pandasai/chat/prompts/base.py similarity index 100% rename from pandasai/prompts/base.py rename to pandasai/chat/prompts/base.py diff --git a/pandasai/prompts/correct_error_prompt.py b/pandasai/chat/prompts/correct_error_prompt.py similarity index 100% rename from pandasai/prompts/correct_error_prompt.py rename to pandasai/chat/prompts/correct_error_prompt.py diff --git a/pandasai/prompts/correct_execute_sql_query_usage_error_prompt.py b/pandasai/chat/prompts/correct_execute_sql_query_usage_error_prompt.py similarity index 94% rename from pandasai/prompts/correct_execute_sql_query_usage_error_prompt.py rename to pandasai/chat/prompts/correct_execute_sql_query_usage_error_prompt.py index 44e496ee9..e373fbd17 100644 --- a/pandasai/prompts/correct_execute_sql_query_usage_error_prompt.py +++ b/pandasai/chat/prompts/correct_execute_sql_query_usage_error_prompt.py @@ -1,4 +1,4 @@ -from pandasai.prompts.base import BasePrompt +from pandasai.chat.prompts.base import BasePrompt class CorrectExecuteSQLQueryUsageErrorPrompt(BasePrompt): diff --git a/pandasai/prompts/correct_output_type_error_prompt.py b/pandasai/chat/prompts/correct_output_type_error_prompt.py similarity index 100% rename from pandasai/prompts/correct_output_type_error_prompt.py rename to pandasai/chat/prompts/correct_output_type_error_prompt.py diff --git a/pandasai/prompts/file_based_prompt.py b/pandasai/chat/prompts/file_based_prompt.py similarity index 100% rename from pandasai/prompts/file_based_prompt.py rename to pandasai/chat/prompts/file_based_prompt.py diff --git a/pandasai/prompts/generate_python_code.py b/pandasai/chat/prompts/generate_python_code.py similarity index 100% rename from pandasai/prompts/generate_python_code.py rename to pandasai/chat/prompts/generate_python_code.py diff --git a/pandasai/prompts/generate_python_code_with_sql.py b/pandasai/chat/prompts/generate_python_code_with_sql.py similarity index 70% rename from pandasai/prompts/generate_python_code_with_sql.py rename to pandasai/chat/prompts/generate_python_code_with_sql.py index 4250c49a2..1f39aa962 100644 --- a/pandasai/prompts/generate_python_code_with_sql.py +++ b/pandasai/chat/prompts/generate_python_code_with_sql.py @@ -1,4 +1,4 @@ -from pandasai.prompts.generate_python_code import GeneratePythonCodePrompt +from pandasai.chat.prompts.generate_python_code import GeneratePythonCodePrompt class GeneratePythonCodeWithSQLPrompt(GeneratePythonCodePrompt): diff --git a/pandasai/prompts/generate_system_message.py b/pandasai/chat/prompts/generate_system_message.py similarity index 100% rename from pandasai/prompts/generate_system_message.py rename to pandasai/chat/prompts/generate_system_message.py diff --git a/pandasai/prompts/templates/correct_error_prompt.tmpl b/pandasai/chat/prompts/templates/correct_error_prompt.tmpl similarity index 100% rename from pandasai/prompts/templates/correct_error_prompt.tmpl rename to pandasai/chat/prompts/templates/correct_error_prompt.tmpl diff --git a/pandasai/prompts/templates/correct_execute_sql_query_usage_error_prompt.tmpl b/pandasai/chat/prompts/templates/correct_execute_sql_query_usage_error_prompt.tmpl similarity index 100% rename from pandasai/prompts/templates/correct_execute_sql_query_usage_error_prompt.tmpl rename to pandasai/chat/prompts/templates/correct_execute_sql_query_usage_error_prompt.tmpl diff --git 
a/pandasai/prompts/templates/correct_output_type_error_prompt.tmpl b/pandasai/chat/prompts/templates/correct_output_type_error_prompt.tmpl similarity index 100% rename from pandasai/prompts/templates/correct_output_type_error_prompt.tmpl rename to pandasai/chat/prompts/templates/correct_output_type_error_prompt.tmpl diff --git a/pandasai/prompts/templates/generate_python_code.tmpl b/pandasai/chat/prompts/templates/generate_python_code.tmpl similarity index 100% rename from pandasai/prompts/templates/generate_python_code.tmpl rename to pandasai/chat/prompts/templates/generate_python_code.tmpl diff --git a/pandasai/prompts/templates/generate_python_code_with_sql.tmpl b/pandasai/chat/prompts/templates/generate_python_code_with_sql.tmpl similarity index 100% rename from pandasai/prompts/templates/generate_python_code_with_sql.tmpl rename to pandasai/chat/prompts/templates/generate_python_code_with_sql.tmpl diff --git a/pandasai/prompts/templates/generate_system_message.tmpl b/pandasai/chat/prompts/templates/generate_system_message.tmpl similarity index 100% rename from pandasai/prompts/templates/generate_system_message.tmpl rename to pandasai/chat/prompts/templates/generate_system_message.tmpl diff --git a/pandasai/chat/prompts/templates/shared/dataframe.tmpl b/pandasai/chat/prompts/templates/shared/dataframe.tmpl new file mode 100644 index 000000000..6892571d7 --- /dev/null +++ b/pandasai/chat/prompts/templates/shared/dataframe.tmpl @@ -0,0 +1 @@ +{{ df.serialize_dataframe(index-1, context.config.direct_sql, context.config.enforce_privacy) }} diff --git a/pandasai/prompts/templates/shared/output_type_template.tmpl b/pandasai/chat/prompts/templates/shared/output_type_template.tmpl similarity index 100% rename from pandasai/prompts/templates/shared/output_type_template.tmpl rename to pandasai/chat/prompts/templates/shared/output_type_template.tmpl diff --git a/pandasai/prompts/templates/shared/vectordb_docs.tmpl b/pandasai/chat/prompts/templates/shared/vectordb_docs.tmpl similarity index 100% rename from pandasai/prompts/templates/shared/vectordb_docs.tmpl rename to pandasai/chat/prompts/templates/shared/vectordb_docs.tmpl diff --git a/pandasai/chat/response/__init__.py b/pandasai/chat/response/__init__.py new file mode 100644 index 000000000..9209eb94e --- /dev/null +++ b/pandasai/chat/response/__init__.py @@ -0,0 +1,4 @@ +from .base import ResponseParser +from .response_types import Chart, DataFrame, Number, String + +__all__ = ["ResponseParser", "Chart", "DataFrame", "Number", "String"] diff --git a/pandasai/chat/response/base.py b/pandasai/chat/response/base.py new file mode 100644 index 000000000..76a26491b --- /dev/null +++ b/pandasai/chat/response/base.py @@ -0,0 +1,66 @@ +import re +import numpy as np +import pandas as pd +from .response_types import Chart, DataFrame, Number, String +from pandasai.exceptions import InvalidOutputValueMismatch + + +class ResponseParser: + def parse(self, result: dict): + self._validate_response(result) + return self._generate_response(result) + + def _generate_response(self, result: dict): + if result["type"] == "number": + return Number(result) + elif result["type"] == "string": + return String(result) + elif result["type"] == "dataframe": + return DataFrame(result) + elif result["type"] == "plot": + return Chart(result) + + def _validate_response(self, result: dict): + if ( + not isinstance(result, dict) + or "type" not in result + or "value" not in result + ): + raise InvalidOutputValueMismatch( + 'Result must be in the format of dictionary of type and 
value like `result = {"type": ..., "value": ... }`'
+            )
+        elif result["type"] == "number":
+            if not isinstance(result["value"], (int, float, np.int64)):
+                raise InvalidOutputValueMismatch(
+                    "Invalid output: Expected a numeric value for result type 'number', but received a non-numeric value."
+                )
+        elif result["type"] == "string":
+            if not isinstance(result["value"], str):
+                raise InvalidOutputValueMismatch(
+                    "Invalid output: Expected a string value for result type 'string', but received a non-string value."
+                )
+        elif result["type"] == "dataframe":
+            if not isinstance(result["value"], (pd.DataFrame, pd.Series, dict)):
+                raise InvalidOutputValueMismatch(
+                    "Invalid output: Expected a Pandas DataFrame or Series, but received an incompatible type."
+                )
+
+        elif result["type"] == "plot":
+            if not isinstance(result["value"], (str, dict)):
+                raise InvalidOutputValueMismatch(
+                    "Invalid output: Expected a plot save path str but received an incompatible type."
+                )
+
+            if isinstance(result["value"], dict) or (
+                isinstance(result["value"], str)
+                and "data:image/png;base64" in result["value"]
+            ):
+                return True
+
+            path_to_plot_pattern = r"^(\/[\w.-]+)+(/[\w.-]+)*$|^[^\s/]+(/[\w.-]+)*$"
+            if not bool(re.match(path_to_plot_pattern, result["value"])):
+                raise InvalidOutputValueMismatch(
+                    "Invalid output: Expected a plot save path str but received an incompatible type."
+                )
+
+        return True
diff --git a/pandasai/chat/response/response_types.py b/pandasai/chat/response/response_types.py
new file mode 100644
index 000000000..4d2a79aa8
--- /dev/null
+++ b/pandasai/chat/response/response_types.py
@@ -0,0 +1,117 @@
+import base64
+import json
+from typing import Any
+
+import pandas as pd
+from PIL import Image
+
+
+class Base:
+    """
+    Base class for different types of response values.
+    """
+
+    def __init__(self, result: dict):
+        if not isinstance(result, dict):
+            raise ValueError(
+                f"Expected a dictionary result, but got {type(result).__name__}."
+            )
+        self.result = result
+        self.value = result.get("value")
+
+    def __str__(self) -> str:
+        """Return the string representation of the response."""
+        return str(self.value)
+
+    def __repr__(self) -> str:
+        """Return a detailed string representation for debugging."""
+        return f"{self.__class__.__name__}(type={self.result.get('type')!r}, value={self.value!r})"
+
+    def to_dict(self) -> dict:
+        """Return a dictionary representation."""
+        return self.result
+
+    def to_json(self) -> str:
+        """Return a JSON representation."""
+        return json.dumps(self.to_dict())
+
+    def get_value(self) -> Any:
+        """Return the value from the result."""
+        return self.value
+
+
+class String(Base):
+    """
+    Class for handling string responses.
+    """
+
+    def __init__(self, result: dict):
+        super().__init__(result)
+
+
+class Number(Base):
+    """
+    Class for handling numerical responses.
+    """
+
+    def __init__(self, result: dict):
+        super().__init__(result)
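# ResponseParser dispatch in a nutshell: a result dict from the executed code
# is validated, then wrapped in the matching typed response object.
from pandasai.chat.response import ResponseParser

response = ResponseParser().parse({"type": "number", "value": 42})
print(type(response).__name__, response.get_value())  # Number 42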
+ """ + + def __init__(self, result: dict): + result["value"] = self.format_value(result["value"]) + super().__init__(result) + + def format_value(self, value): + if isinstance(value, dict): + return pd.Dataframe(value) + return value + + def to_csv(self, file_path: str) -> None: + self.value.to_csv(file_path, index=False) + + def to_excel(self, file_path: str) -> None: + self.value.to_excel(file_path, index=False) + + def head(self, n: int = 5) -> pd.DataFrame: + return self.value.head(n) + + def tail(self, n: int = 5) -> pd.DataFrame: + return self.value.tail(n) + + def to_json(self): + json_data = json.loads(self.value.to_json(orient="split", date_format="iso")) + self.result["value"] = { + "headers": json_data["columns"], + "rows": json_data["data"], + } + return self.result + + def to_dict(self) -> dict: + self.result["value"] = self.value.to_dict(orient="split", date_format="iso") + return self.result + + +class Chart(Base): + def __init__(self, result: dict): + super().__init__(result) + + def show(self): + img = Image.open(self.value) + img.show() + + def to_dict(self): + with open(self.value["value"], "rb") as image_file: + image_data = image_file.read() + + # Encode the image data to Base64 + self.result[ + "value" + ] = f"data:image/png;base64,{base64.b64encode(image_data).decode()}" + + return self.result diff --git a/pandasai/chat/user_query.py b/pandasai/chat/user_query.py new file mode 100644 index 000000000..20825f9ae --- /dev/null +++ b/pandasai/chat/user_query.py @@ -0,0 +1,30 @@ +import re +from pandasai.exceptions import MaliciousQueryError + + +class UserQuery: + def __init__(self, user_query: str): + self._check_malicious_keywords_in_query(user_query) + self.value = user_query + + def __str__(self): + return self.value + + def __repr__(self): + return f"UserQuery(value={self._value})" + + def _check_malicious_keywords_in_query(self, user_query): + dangerous_pattern = re.compile( + r"\b(os|io|chr|b64decode)\b|" + r"(\.os|\.io|'os'|'io'|\"os\"|\"io\"|chr\(|chr\)|chr |\(chr)" + ) + if bool(dangerous_pattern.search(user_query)): + raise MaliciousQueryError( + "The query contains references to io or os modules or b64decode method which can be used to execute or access system resources in unsafe ways." + ) + + def __dict__(self): + return self.value + + def to_json(self): + return self.value diff --git a/pandasai/config.py b/pandasai/config.py index 6126819ca..c7edd5e18 100644 --- a/pandasai/config.py +++ b/pandasai/config.py @@ -1,9 +1,36 @@ import json -from typing import Optional, Union -from . 
import llm +import pandasai.llm as llm +from pandasai.llm.base import LLM + from .helpers.path import find_closest -from .schemas.df_config import Config + +from typing import Any, List, Optional, Dict, Union +from pydantic import BaseModel, Field, ConfigDict + +from pandasai.constants import DEFAULT_CHART_DIRECTORY + + +class Config(BaseModel): + save_logs: bool = True + verbose: bool = False + enforce_privacy: bool = False + enable_cache: bool = True + use_error_correction_framework: bool = True + save_charts: bool = False + save_charts_path: str = DEFAULT_CHART_DIRECTORY + custom_whitelisted_dependencies: List[str] = Field(default_factory=list) + max_retries: int = 3 + + llm: Optional[LLM] = None + data_viz_library: Optional[str] = None + direct_sql: bool = False + + model_config = ConfigDict(arbitrary_types_allowed=True) + + @classmethod + def from_dict(cls, config: Dict[str, Any]) -> "Config": + return cls(**config) def load_config_from_json( diff --git a/pandasai/constants.py b/pandasai/constants.py index edb9ee5c5..86cffed94 100644 --- a/pandasai/constants.py +++ b/pandasai/constants.py @@ -90,10 +90,31 @@ "numpy", "datetime", "json", - "io", "base64", ] +# List of restricted libs +RESTRICTED_LIBS = [ + "os", # OS-level operations (file handling, environment variables) + "sys", # System-level access + "subprocess", # Run system commands + "shutil", # File operations, including delete + "multiprocessing", # Spawn new processes + "threading", # Thread-level operations + "socket", # Network connections + "http", # HTTP requests + "ftplib", # FTP connections + "paramiko", # SSH operations + "tempfile", # Create temporary files + "pathlib", # Filesystem path handling + "resource", # Access resource usage limits (system-related) + "ssl", # SSL socket connections + "pickle", # Unsafe object serialization + "ctypes", # C-level interaction with memory + "psutil", # System and process utilities + "io", # System io operations +] + PANDASBI_SETUP_MESSAGE = ( "The api_key client option must be set either by passing api_key to the client " "or by setting the PANDASAI_API_KEY environment variable. To get the key follow below steps:\n" diff --git a/pandasai/dataframe/__init__.py b/pandasai/dataframe/__init__.py index 36288a4ae..3fb3e08ef 100644 --- a/pandasai/dataframe/__init__.py +++ b/pandasai/dataframe/__init__.py @@ -1,5 +1,5 @@ from .base import DataFrame +from .virtual_dataframe import VirtualDataFrame -__all__ = [ - "DataFrame", -] + +__all__ = ["DataFrame", "VirtualDataFrame"] diff --git a/pandasai/dataframe/base.py b/pandasai/dataframe/base.py index cbe6a7a8f..bfcd3d46a 100644 --- a/pandasai/dataframe/base.py +++ b/pandasai/dataframe/base.py @@ -1,7 +1,8 @@ +from __future__ import annotations import pandas as pd -from typing import Optional, Union, Dict, Any, ClassVar -from pandasai.agent.agent import Agent -from pandasai.schemas.df_config import Config +from typing import TYPE_CHECKING, Optional, Union, Dict, Any, ClassVar + +from pandasai.config import Config import hashlib from pandasai.helpers.dataframe_serializer import ( DataframeSerializer, @@ -9,6 +10,10 @@ ) +if TYPE_CHECKING: + from pandasai.agent.base import Agent + + class DataFrame(pd.DataFrame): """ PandasAI DataFrame that extends pandas DataFrame with natural language capabilities. 
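# Building a Config from a plain dict; unspecified fields keep their pydantic
# defaults, and an LLM instance can be attached later.
from pandasai.config import Config

config = Config.from_dict({"save_logs": False, "max_retries": 5})
assert config.enable_cache is True  # default
assert config.max_retries == 5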
@@ -43,7 +48,7 @@ def __init__(self, *args, **kwargs): self._agent: Optional[Agent] = None self._column_hash = self._calculate_column_hash() - def _validate_schema(self, schema: Dict) -> None: + def _validate_schema(self, schema: Optional[Dict]) -> None: """Validates the provided schema format.""" if not isinstance(schema, dict): raise ValueError("Schema must be a dictionary") @@ -81,7 +86,7 @@ def chat(self, prompt: str, config: Optional[Union[dict, Config]] = None) -> str self.config = Config(**config) if isinstance(config, dict) else config if self._agent is None: - from pandasai.agent.agent import ( + from pandasai.agent import ( Agent, ) # Import here to avoid circular import @@ -124,7 +129,6 @@ def serialize_dataframe( self, index: int, is_direct_sql: bool, - serializer_type: DataframeSerializerType, enforce_privacy: bool, ) -> str: """ @@ -148,7 +152,7 @@ def serialize_dataframe( "is_direct_sql": is_direct_sql, "enforce_privacy": enforce_privacy, }, - type_=serializer_type, + type_=DataframeSerializerType.CSV, ) def get_head(self): diff --git a/pandasai/exceptions.py b/pandasai/exceptions.py index 968db66d5..745098fae 100644 --- a/pandasai/exceptions.py +++ b/pandasai/exceptions.py @@ -256,3 +256,11 @@ class InvalidDataSourceType(Exception): """Raised error with invalid data source provided""" pass + + +class MaliciousCodeGenerated(Exception): + """ + Raise error if malicious code is generated + Args: + Exception (Exception): MaliciousCodeGenerated + """ diff --git a/pandasai/helpers/request.py b/pandasai/helpers/request.py index e038c6d6d..fffd4a77a 100644 --- a/pandasai/helpers/request.py +++ b/pandasai/helpers/request.py @@ -1,6 +1,7 @@ import logging import os import traceback +from typing import Optional from urllib.parse import urljoin import requests @@ -15,7 +16,10 @@ class Session: _logger: Logger def __init__( - self, endpoint_url: str = None, api_key: str = None, logger: Logger = None + self, + endpoint_url: Optional[str] = None, + api_key: Optional[str] = None, + logger: Optional[Logger] = None, ) -> None: if api_key is None: api_key = os.environ.get("PANDASAI_API_KEY") or None diff --git a/pandasai/llm/bamboo_llm.py b/pandasai/llm/bamboo_llm.py index 40e32d642..c7cddc6df 100644 --- a/pandasai/llm/bamboo_llm.py +++ b/pandasai/llm/bamboo_llm.py @@ -1,7 +1,11 @@ -from typing import Optional +from __future__ import annotations +from typing import TYPE_CHECKING, Optional + + +if TYPE_CHECKING: + from pandasai.chat.prompts.base import BasePrompt from ..helpers.request import Session -from ..prompts.base import BasePrompt from .base import LLM diff --git a/pandasai/llm/base.py b/pandasai/llm/base.py index b05aaf7ea..df63c2b27 100644 --- a/pandasai/llm/base.py +++ b/pandasai/llm/base.py @@ -5,18 +5,18 @@ from abc import abstractmethod from typing import TYPE_CHECKING, Any, Optional +from pandasai.chat.prompts.base import BasePrompt +from pandasai.chat.prompts.generate_system_message import GenerateSystemMessagePrompt from pandasai.helpers.memory import Memory -from pandasai.prompts.generate_system_message import GenerateSystemMessagePrompt from ..exceptions import ( APIKeyNotFoundError, MethodNotImplementedError, NoCodeFoundError, ) -from ..prompts.base import BasePrompt if TYPE_CHECKING: - from pandasai.pipelines.pipeline_context import PipelineContext + from pandasai.agent.state import AgentState class LLM: @@ -135,13 +135,13 @@ def get_messages(self, memory: Memory) -> Any: return memory.get_previous_conversation() @abstractmethod - def call(self, instruction: BasePrompt, 
context: PipelineContext = None) -> str: + def call(self, instruction: BasePrompt, context: AgentState = None) -> str: """ Execute the LLM with given prompt. Args: instruction (BasePrompt): A prompt object with instruction for LLM. - context (PipelineContext, optional): PipelineContext. Defaults to None. + context (AgentState, optional): AgentState. Defaults to None. Raises: MethodNotImplementedError: Call method has not been implemented @@ -149,7 +149,7 @@ def call(self, instruction: BasePrompt, context: PipelineContext = None) -> str: """ raise MethodNotImplementedError("Call method has not been implemented") - def generate_code(self, instruction: BasePrompt, context: PipelineContext) -> str: + def generate_code(self, instruction: BasePrompt, context: AgentState) -> str: """ Generate the code based on the instruction and the given prompt. diff --git a/pandasai/llm/fake.py b/pandasai/llm/fake.py index 1b23597ed..84cf47321 100644 --- a/pandasai/llm/fake.py +++ b/pandasai/llm/fake.py @@ -2,9 +2,9 @@ from typing import Optional -from pandasai.pipelines.pipeline_context import PipelineContext +from pandasai.agent.state import AgentState +from pandasai.chat.prompts.base import BasePrompt -from ..prompts.base import BasePrompt from .base import LLM @@ -22,7 +22,7 @@ def __init__(self, output: Optional[str] = None, type: str = "fake"): self.last_prompt = None self.response = "Mocked response" - def call(self, instruction: BasePrompt, context: PipelineContext = None) -> str: + def call(self, instruction: BasePrompt, context: AgentState = None) -> str: self.called = True self.last_prompt = instruction.to_string() return self.response diff --git a/pandasai/pipelines/__init__.py b/pandasai/pipelines/__init__.py deleted file mode 100644 index 0e748a585..000000000 --- a/pandasai/pipelines/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .abstract_pipeline import AbstractPipeline -from .base_logic_unit import BaseLogicUnit -from .pipeline import Pipeline - -__all__ = ["Pipeline", "AbstractPipeline", "BaseLogicUnit", "GenerateSDFPipeline"] diff --git a/pandasai/pipelines/abstract_pipeline.py b/pandasai/pipelines/abstract_pipeline.py deleted file mode 100644 index 2f3e3d388..000000000 --- a/pandasai/pipelines/abstract_pipeline.py +++ /dev/null @@ -1,14 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any - - -class AbstractPipeline(ABC): - def __init__(self) -> None: - pass - - @abstractmethod - def run(self, input: Any) -> Any: - """ - This method will return output according to - Implementation.""" - raise NotImplementedError("Run method must be implemented") diff --git a/pandasai/pipelines/base_logic_unit.py b/pandasai/pipelines/base_logic_unit.py deleted file mode 100644 index 93705c07a..000000000 --- a/pandasai/pipelines/base_logic_unit.py +++ /dev/null @@ -1,32 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any - -from pandasai.pipelines.logic_unit_output import LogicUnitOutput - - -class BaseLogicUnit(ABC): - """ - Logic units for pipeline each logic unit should be inherited from this Logic unit - """ - - def __init__(self, skip_if=None, on_execution=None, before_execution=None): - super().__init__() - self.skip_if = skip_if - self.on_execution = on_execution - self.before_execution = before_execution - - @abstractmethod - def execute(self, input: Any, **kwargs) -> LogicUnitOutput: - """ - This method will return output according to - Implementation. - - :param input: Your input data. - :param kwargs: A dictionary of keyword arguments. 
- - 'logger' (any): The logger for logging. - - 'config' (Config): Global configurations for the test - - 'context' (any): The execution context. - - :return: The result of the execution. - """ - raise NotImplementedError("execute method is not implemented.") diff --git a/pandasai/pipelines/chat/cache_lookup.py b/pandasai/pipelines/chat/cache_lookup.py deleted file mode 100644 index 2cf7cba0f..000000000 --- a/pandasai/pipelines/chat/cache_lookup.py +++ /dev/null @@ -1,47 +0,0 @@ -from typing import Any - -from pandasai.pipelines.logic_unit_output import LogicUnitOutput - -from ...helpers.logger import Logger -from ..base_logic_unit import BaseLogicUnit -from ..pipeline_context import PipelineContext - - -class CacheLookup(BaseLogicUnit): - """ - Cache Lookup of Code Stage - """ - - pass - - def execute(self, input: Any, **kwargs) -> Any: - """ - This method will return output according to - Implementation. - - :param input: Your input data. - :param kwargs: A dictionary of keyword arguments. - - 'logger' (any): The logger for logging. - - 'config' (Config): Global configurations for the test - - 'context' (any): The execution context. - - :return: The result of the execution. - """ - pipeline_context: PipelineContext = kwargs.get("context") - logger: Logger = kwargs.get("logger") - if ( - pipeline_context.config.enable_cache - and pipeline_context.cache - and pipeline_context.cache.get( - pipeline_context.cache.get_cache_key(pipeline_context) - ) - ): - logger.log("Using cached response") - - code = pipeline_context.cache.get( - pipeline_context.cache.get_cache_key(pipeline_context) - ) - - pipeline_context.add("found_in_cache", True) - - return LogicUnitOutput(code, True, "Cache Hit") diff --git a/pandasai/pipelines/chat/cache_population.py b/pandasai/pipelines/chat/cache_population.py deleted file mode 100644 index 6393eddf1..000000000 --- a/pandasai/pipelines/chat/cache_population.py +++ /dev/null @@ -1,44 +0,0 @@ -from typing import Any - -from pandasai.pipelines.logic_unit_output import LogicUnitOutput - -from ..base_logic_unit import BaseLogicUnit -from ..pipeline_context import PipelineContext - - -class CachePopulation(BaseLogicUnit): - """ - Cache Population Stage - """ - - pass - - def execute(self, input: Any, **kwargs) -> Any: - """ - This method will return output according to - Implementation. - - :param input: Your input data. - :param kwargs: A dictionary of keyword arguments. - - 'logger' (any): The logger for logging. - - 'config' (Config): Global configurations for the test - - 'context' (any): The execution context. - - :return: The result of the execution. 
- """ - pipeline_context: PipelineContext = kwargs.get("context") - - code = input - - if pipeline_context.config.enable_cache and pipeline_context.cache: - pipeline_context.cache.set( - pipeline_context.cache.get_cache_key(pipeline_context), code - ) - - return LogicUnitOutput( - code, - True, - "Prompt Cached Successfully" - if pipeline_context.config.enable_cache - else "Caching disabled", - ) diff --git a/pandasai/pipelines/chat/chat_pipeline_input.py b/pandasai/pipelines/chat/chat_pipeline_input.py deleted file mode 100644 index aa9541b75..000000000 --- a/pandasai/pipelines/chat/chat_pipeline_input.py +++ /dev/null @@ -1,27 +0,0 @@ -import uuid -from dataclasses import dataclass - - -@dataclass -class ChatPipelineInput: - """ - Contain all the data needed by the chat pipeline - """ - - query: str - output_type: str - instance: str - conversation_id: uuid.UUID - prompt_id: uuid.UUID - - def __init__( - self, - query: str, - output_type: str, - conversation_id: uuid.UUID, - prompt_id: uuid.UUID, - ) -> None: - self.query = query - self.output_type = output_type - self.conversation_id = conversation_id - self.prompt_id = prompt_id diff --git a/pandasai/pipelines/chat/code_execution_pipeline_input.py b/pandasai/pipelines/chat/code_execution_pipeline_input.py deleted file mode 100644 index d369b3d61..000000000 --- a/pandasai/pipelines/chat/code_execution_pipeline_input.py +++ /dev/null @@ -1,29 +0,0 @@ -import uuid -from dataclasses import dataclass - - -@dataclass -class CodeExecutionPipelineInput: - """ - Contain all the data needed by the Code Execution pipeline - """ - - code: str - output_type: str - instance: str - conversation_id: uuid.UUID - prompt_id: uuid.UUID - query: str - - def __init__( - self, - code: str, - output_type: str, - conversation_id: uuid.UUID, - prompt_id: uuid.UUID, - ) -> None: - self.code = code - self.output_type = output_type - self.conversation_id = conversation_id - self.prompt_id = prompt_id - self.query = "" diff --git a/pandasai/pipelines/chat/code_generator.py b/pandasai/pipelines/chat/code_generator.py deleted file mode 100644 index 5f82c6a6d..000000000 --- a/pandasai/pipelines/chat/code_generator.py +++ /dev/null @@ -1,54 +0,0 @@ -from typing import Any - -from pandasai.pipelines.logic_unit_output import LogicUnitOutput - -from ...helpers.logger import Logger -from ..base_logic_unit import BaseLogicUnit -from ..pipeline_context import PipelineContext - - -class CodeGenerator(BaseLogicUnit): - """ - LLM Code Generation Stage - """ - - pass - - def execute(self, input: Any, **kwargs) -> Any: - """ - This method will return output according to - Implementation. - - :param input: Your input data. - :param kwargs: A dictionary of keyword arguments. - - 'logger' (any): The logger for logging. - - 'config' (Config): Global configurations for the test - - 'context' (any): The execution context. - - :return: The result of the execution. 
- """ - pipeline_context: PipelineContext = kwargs.get("context") - logger: Logger = kwargs.get("logger") - - code = pipeline_context.config.llm.generate_code(input, pipeline_context) - - pipeline_context.add("last_code_generated", code) - logger.log( - f"""Prompt used: - {pipeline_context.config.llm.last_prompt} - """ - ) - logger.log( - f"""Code generated: - ``` - {code} - ``` - """ - ) - - return LogicUnitOutput( - code, - True, - "Code Generated Successfully", - {"content_type": "code", "value": code}, - ) diff --git a/pandasai/pipelines/chat/error_correction_pipeline/error_correction_pipeline.py b/pandasai/pipelines/chat/error_correction_pipeline/error_correction_pipeline.py deleted file mode 100644 index a8e7db557..000000000 --- a/pandasai/pipelines/chat/error_correction_pipeline/error_correction_pipeline.py +++ /dev/null @@ -1,45 +0,0 @@ -from typing import Optional - -from pandasai.helpers.logger import Logger -from pandasai.pipelines.chat.code_cleaning import CodeCleaning -from pandasai.pipelines.chat.code_generator import CodeGenerator -from pandasai.pipelines.chat.error_correction_pipeline.error_correction_pipeline_input import ( - ErrorCorrectionPipelineInput, -) -from pandasai.pipelines.chat.error_correction_pipeline.error_prompt_generation import ( - ErrorPromptGeneration, -) -from pandasai.pipelines.pipeline import Pipeline -from pandasai.pipelines.pipeline_context import PipelineContext - - -class ErrorCorrectionPipeline: - """ - Error Correction Pipeline to regenerate prompt and code - """ - - _context: PipelineContext - _logger: Logger - - def __init__( - self, - context: Optional[PipelineContext] = None, - logger: Optional[Logger] = None, - on_prompt_generation=None, - on_code_generation=None, - ): - self.pipeline = Pipeline( - context=context, - logger=logger, - steps=[ - ErrorPromptGeneration(on_prompt_generation=on_prompt_generation), - CodeGenerator(on_execution=on_code_generation), - CodeCleaning(), - ], - ) - self._context = context - self._logger = logger - - def run(self, input: ErrorCorrectionPipelineInput): - self._logger.log(f"Executing Pipeline: {self.__class__.__name__}") - return self.pipeline.run(input) diff --git a/pandasai/pipelines/chat/error_correction_pipeline/error_correction_pipeline_input.py b/pandasai/pipelines/chat/error_correction_pipeline/error_correction_pipeline_input.py deleted file mode 100644 index 6914f2ce3..000000000 --- a/pandasai/pipelines/chat/error_correction_pipeline/error_correction_pipeline_input.py +++ /dev/null @@ -1,11 +0,0 @@ -from dataclasses import dataclass - - -@dataclass -class ErrorCorrectionPipelineInput: - code: str - exception: Exception - - def __init__(self, code: str, exception: Exception): - self.code = code - self.exception = exception diff --git a/pandasai/pipelines/chat/error_correction_pipeline/error_prompt_generation.py b/pandasai/pipelines/chat/error_correction_pipeline/error_prompt_generation.py deleted file mode 100644 index 4826cf78c..000000000 --- a/pandasai/pipelines/chat/error_correction_pipeline/error_prompt_generation.py +++ /dev/null @@ -1,98 +0,0 @@ -import traceback -from typing import Any, Callable - -from pandasai.exceptions import ExecuteSQLQueryNotUsed, InvalidLLMOutputType -from pandasai.helpers.logger import Logger -from pandasai.pipelines.base_logic_unit import BaseLogicUnit -from pandasai.pipelines.chat.error_correction_pipeline.error_correction_pipeline_input import ( - ErrorCorrectionPipelineInput, -) -from pandasai.pipelines.logic_unit_output import LogicUnitOutput -from 
pandasai.pipelines.pipeline_context import PipelineContext -from pandasai.prompts.base import BasePrompt -from pandasai.prompts.correct_error_prompt import CorrectErrorPrompt -from pandasai.prompts.correct_execute_sql_query_usage_error_prompt import ( - CorrectExecuteSQLQueryUsageErrorPrompt, -) -from pandasai.prompts.correct_output_type_error_prompt import ( - CorrectOutputTypeErrorPrompt, -) - - -class ErrorPromptGeneration(BaseLogicUnit): - on_prompt_generation: Callable[[str], None] - - def __init__( - self, - on_prompt_generation=None, - skip_if=None, - on_execution=None, - before_execution=None, - ): - self.on_prompt_generation = on_prompt_generation - super().__init__(skip_if, on_execution, before_execution) - - def execute(self, input: ErrorCorrectionPipelineInput, **kwargs) -> Any: - """ - A method to retry the code execution with error correction framework. - - Args: - code (str): A python code - context (PipelineContext) : Pipeline Context - logger (Logger) : Logger - e (Exception): An exception - dataframes - - Returns (str): A python code - """ - self.context: PipelineContext = kwargs.get("context") - self.logger: Logger = kwargs.get("logger") - e = input.exception - - prompt = self.get_prompt(e, input.code) - if self.on_prompt_generation: - self.on_prompt_generation(prompt) - - self.logger.log(f"Using prompt: {prompt}") - - return LogicUnitOutput( - prompt, - True, - "Prompt Generated Successfully", - { - "content_type": "prompt", - "value": prompt.to_string(), - }, - ) - - def get_prompt(self, e: Exception, code: str) -> BasePrompt: - """ - Return a prompt by key. - - Args: - values (dict): The values to use for the prompt - - Returns: - BasePrompt: The prompt - """ - traceback_errors = traceback.format_exc() - return ( - CorrectOutputTypeErrorPrompt( - context=self.context, - code=code, - error=traceback_errors, - output_type=self.context.get("output_type"), - ) - if isinstance(e, InvalidLLMOutputType) - else ( - CorrectExecuteSQLQueryUsageErrorPrompt( - context=self.context, code=code, error=traceback_errors - ) - if isinstance(e, ExecuteSQLQueryNotUsed) - else CorrectErrorPrompt( - context=self.context, - code=code, - error=traceback_errors, - ) - ) - ) diff --git a/pandasai/pipelines/chat/generate_chat_pipeline.py b/pandasai/pipelines/chat/generate_chat_pipeline.py deleted file mode 100644 index b1cf3470a..000000000 --- a/pandasai/pipelines/chat/generate_chat_pipeline.py +++ /dev/null @@ -1,261 +0,0 @@ -from typing import Optional - -from pandasai.agent.base_judge import BaseJudge -from pandasai.pipelines.chat.chat_pipeline_input import ( - ChatPipelineInput, -) -from pandasai.pipelines.chat.code_execution_pipeline_input import ( - CodeExecutionPipelineInput, -) -from pandasai.pipelines.chat.error_correction_pipeline.error_correction_pipeline import ( - ErrorCorrectionPipeline, -) -from pandasai.pipelines.chat.error_correction_pipeline.error_correction_pipeline_input import ( - ErrorCorrectionPipelineInput, -) -from pandasai.pipelines.chat.validate_pipeline_input import ( - ValidatePipelineInput, -) - -from ...helpers.logger import Logger -from ..pipeline import Pipeline -from ..pipeline_context import PipelineContext -from .cache_lookup import CacheLookup -from .cache_population import CachePopulation -from .code_cleaning import CodeCleaning -from .code_execution import CodeExecution -from .code_generator import CodeGenerator -from .prompt_generation import PromptGeneration -from .result_parsing import ResultParsing -from .result_validation import ResultValidation - - 
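The nested conditional that `get_prompt` returns in the deleted `error_prompt_generation.py` above selects one of three correction prompts by exception type. For readability, the same dispatch written with early returns; a behavior-equivalent sketch only, using the classes, arguments, and imports exactly as they appear in that deleted module:

```python
# Early-return form of the get_prompt dispatch above (sketch, not part of the diff).
def get_prompt(self, e: Exception, code: str) -> BasePrompt:
    error = traceback.format_exc()
    if isinstance(e, InvalidLLMOutputType):
        return CorrectOutputTypeErrorPrompt(
            context=self.context,
            code=code,
            error=error,
            output_type=self.context.get("output_type"),
        )
    if isinstance(e, ExecuteSQLQueryNotUsed):
        return CorrectExecuteSQLQueryUsageErrorPrompt(
            context=self.context, code=code, error=error
        )
    return CorrectErrorPrompt(context=self.context, code=code, error=error)
```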
-class GenerateChatPipeline: - code_generation_pipeline = Pipeline - code_execution_pipeline = Pipeline - context: PipelineContext - _logger: Logger - last_error: str - - def __init__( - self, - context: Optional[PipelineContext] = None, - logger: Optional[Logger] = None, - judge: BaseJudge = None, - on_prompt_generation=None, - on_code_generation=None, - before_code_execution=None, - on_result=None, - ): - self.code_generation_pipeline = Pipeline( - context=context, - logger=logger, - steps=[ - ValidatePipelineInput(), - CacheLookup(), - PromptGeneration( - skip_if=self.is_cached, - on_execution=on_prompt_generation, - ), - CodeGenerator( - skip_if=self.is_cached, - on_execution=on_code_generation, - ), - CachePopulation(skip_if=self.is_cached), - CodeCleaning( - skip_if=self.no_code, - on_retry=self.on_code_retry, - ), - ], - ) - - self.code_execution_pipeline = Pipeline( - context=context, - logger=logger, - steps=[ - CodeExecution( - before_execution=before_code_execution, - on_retry=self.on_code_retry, - ), - ResultValidation(), - ResultParsing( - before_execution=on_result, - ), - ], - ) - - self.code_exec_error_pipeline = ErrorCorrectionPipeline( - context=context, - logger=logger, - on_code_generation=on_code_generation, - on_prompt_generation=on_prompt_generation, - ) - - self.judge = judge - - if self.judge: - if self.judge.pipeline.pipeline.context: - self.judge.pipeline.pipeline.context.memory = context.memory - else: - self.judge.pipeline.pipeline.context = context - - self.judge.pipeline.pipeline.logger = logger - - self.context = context - self._logger = logger - self.last_error = None - - def on_code_retry(self, code: str, exception: Exception): - correction_input = ErrorCorrectionPipelineInput(code, exception) - return self.code_exec_error_pipeline.run(correction_input) - - def no_code(self, context: PipelineContext): - return context.get("last_code_generated") is None - - def is_cached(self, context: PipelineContext): - return context.get("found_in_cache") - - def run_generate_code(self, input: ChatPipelineInput) -> dict: - """ - Executes the code generation pipeline with user input and return the result - Args: - input (ChatPipelineInput): _description_ - - Returns: - The `output` dictionary is expected to have the following keys: - - 'type': The type of the output. - - 'value': The value of the output. - """ - self._logger.log(f"Executing Pipeline: {self.__class__.__name__}") - - # Reset intermediate values - self.context.reset_intermediate_values() - - # Add Query to memory - self.context.memory.add(input.query, True) - - self.context.add_many( - { - "output_type": input.output_type, - "last_prompt_id": input.prompt_id, - } - ) - try: - return self.code_generation_pipeline.run(input) - - except Exception as e: - # Show the full traceback - import traceback - - traceback.print_exc() - - self.last_error = str(e) - - return ( - "Unfortunately, I was not able to answer your question, " - "because of the following error:\n" - f"\n{e}\n" - ) - - def run_execute_code(self, input: CodeExecutionPipelineInput) -> dict: - """ - Executes the chat pipeline with user input and return the result - Args: - input (CodeExecutionPipelineInput): _description_ - - Returns: - The `output` dictionary is expected to have the following keys: - - 'type': The type of the output. - - 'value': The value of the output. 
- """ - self._logger.log(f"Executing Pipeline: {self.__class__.__name__}") - - # Reset intermediate values - self.context.reset_intermediate_values() - - # Add Query to memory - self.context.memory.add(input.code, True) - - self.context.add_many( - { - "output_type": input.output_type, - "last_prompt_id": input.prompt_id, - } - ) - try: - return self.code_execution_pipeline.run(input.code) - - except Exception as e: - # Show the full traceback - import traceback - - traceback.print_exc() - - self.last_error = str(e) - - return ( - "Unfortunately, I was not able to answer your question, " - "because of the following error:\n" - f"\n{e}\n" - ) - - def run(self, input: ChatPipelineInput) -> dict: - """ - Executes the chat pipeline with user input and return the result - Args: - input (ChatPipelineInput): _description_ - - Returns: - The `output` dictionary is expected to have the following keys: - - 'type': The type of the output. - - 'value': The value of the output. - """ - self._logger.log(f"Executing Pipeline: {self.__class__.__name__}") - - # Reset intermediate values - self.context.reset_intermediate_values() - - # Add Query to memory - self.context.memory.add(input.query, True) - - self.context.add_many( - { - "output_type": input.output_type, - "last_prompt_id": input.prompt_id, - } - ) - try: - if self.judge: - code = self.code_generation_pipeline.run(input) - - retry_count = 0 - while retry_count < self.context.config.max_retries: - if self.judge.evaluate(query=input.query, code=code): - break - code = self.code_generation_pipeline.run(input) - retry_count += 1 - - output = self.code_execution_pipeline.run(code) - - elif self.code_execution_pipeline: - output = ( - self.code_generation_pipeline | self.code_execution_pipeline - ).run(input) - else: - output = self.code_generation_pipeline.run(input) - - return output - - except Exception as e: - # Show the full traceback - import traceback - - traceback.print_exc() - - self.last_error = str(e) - - return ( - "Unfortunately, I was not able to answer your question, " - "because of the following error:\n" - f"\n{e}\n" - ) diff --git a/pandasai/pipelines/chat/prompt_generation.py b/pandasai/pipelines/chat/prompt_generation.py deleted file mode 100644 index 6b48b05eb..000000000 --- a/pandasai/pipelines/chat/prompt_generation.py +++ /dev/null @@ -1,68 +0,0 @@ -from typing import Any, Union - -from pandasai.pipelines.logic_unit_output import LogicUnitOutput - -from ...helpers.logger import Logger -from ...prompts.base import BasePrompt -from ...prompts.generate_python_code import GeneratePythonCodePrompt -from ...prompts.generate_python_code_with_sql import GeneratePythonCodeWithSQLPrompt -from ..base_logic_unit import BaseLogicUnit -from ..pipeline_context import PipelineContext - - -class PromptGeneration(BaseLogicUnit): - """ - Code Prompt Generation Stage - """ - - pass - - def execute(self, input: Any, **kwargs) -> Any: - """ - This method will return output according to - Implementation. - - :param input: Your input data. - :param kwargs: A dictionary of keyword arguments. - - 'logger' (any): The logger for logging. - - 'config' (Config): Global configurations for the test - - 'context' (any): The execution context. - - :return: The result of the execution. 
- """ - self.context: PipelineContext = kwargs.get("context") - self.logger: Logger = kwargs.get("logger") - - prompt = self.get_chat_prompt(self.context) - self.logger.log(f"Using prompt: {prompt}") - - return LogicUnitOutput( - prompt, - True, - "Prompt Generated Successfully", - {"content_type": "prompt", "value": prompt.to_string()}, - ) - - def get_chat_prompt(self, context: PipelineContext) -> Union[str, BasePrompt]: - # set matplotlib as the default library - viz_lib = "matplotlib" - if context.config.data_viz_library: - viz_lib = context.config.data_viz_library - - output_type = context.get("output_type") - - return ( - GeneratePythonCodeWithSQLPrompt( - context=context, - last_code_generated=context.get("last_code_generated"), - viz_lib=viz_lib, - output_type=output_type, - ) - if context.config.direct_sql - else GeneratePythonCodePrompt( - context=context, - last_code_generated=context.get("last_code_generated"), - viz_lib=viz_lib, - output_type=output_type, - ) - ) diff --git a/pandasai/pipelines/chat/result_parsing.py b/pandasai/pipelines/chat/result_parsing.py deleted file mode 100644 index a955f86e0..000000000 --- a/pandasai/pipelines/chat/result_parsing.py +++ /dev/null @@ -1,66 +0,0 @@ -from typing import Any - -from pandasai.pipelines.logic_unit_output import LogicUnitOutput - -from ...responses.context import Context -from ...responses.response_parser import ResponseParser -from ..base_logic_unit import BaseLogicUnit -from ..pipeline_context import PipelineContext - - -class ResultParsing(BaseLogicUnit): - - """ - Result Parsing Stage - """ - - pass - - def response_parser(self, context: PipelineContext, logger) -> ResponseParser: - context = Context(context.config, logger=logger) - return ( - context.config.response_parser(context) - if context.config.response_parser - else ResponseParser(context) - ) - - def execute(self, input: Any, **kwargs) -> Any: - """ - This method will return output according to - Implementation. - - :param input: Your input data. - :param kwargs: A dictionary of keyword arguments. - - 'logger' (any): The logger for logging. - - 'config' (Config): Global configurations for the test - - 'context' (any): The execution context. - - :return: The result of the execution. - """ - pipeline_context: PipelineContext = kwargs.get("context") - - result = input - - self._add_result_to_memory(result=result, context=pipeline_context) - - parser = self.response_parser(pipeline_context, logger=kwargs.get("logger")) - result = parser.parse(result) - return LogicUnitOutput(result, True, "Results parsed successfully") - - def _add_result_to_memory(self, result: dict, context: PipelineContext): - """ - Add the result to the memory. 
- - Args: - result (dict): The result to add to the memory - context (PipelineContext) : Pipeline Context - """ - if result is None: - return - - if result["type"] in ["string", "number"]: - context.memory.add(str(result["value"]), False) - elif result["type"] == "dataframe": - context.memory.add("Check it out: ", False) - elif result["type"] == "plot": - context.memory.add("Check it out: ", False) diff --git a/pandasai/pipelines/chat/result_validation.py b/pandasai/pipelines/chat/result_validation.py deleted file mode 100644 index 1154ca5f3..000000000 --- a/pandasai/pipelines/chat/result_validation.py +++ /dev/null @@ -1,58 +0,0 @@ -import logging -from typing import Any - -from pandasai.helpers.logger import Logger -from pandasai.pipelines.logic_unit_output import LogicUnitOutput - -from ...helpers.output_validator import OutputValidator -from ..base_logic_unit import BaseLogicUnit -from ..pipeline_context import PipelineContext - - -class ResultValidation(BaseLogicUnit): - """ - Result Validation Stage - """ - - pass - - def execute(self, input: Any, **kwargs) -> Any: - """ - This method will return output according to - Implementation. - - :param input: Your input data. - :param kwargs: A dictionary of keyword arguments. - - 'logger' (any): The logger for logging. - - 'config' (Config): Global configurations for the test - - 'context' (any): The execution context. - - :return: The result of the execution. - """ - pipeline_context: PipelineContext = kwargs.get("context") - logger: Logger = kwargs.get("logger") - - result = input - success = False - message = None - if result is not None: - if isinstance(result, dict): - ( - validation_ok, - validation_logs, - ) = OutputValidator.validate( - pipeline_context.get("output_type"), result - ) - if not validation_ok: - logger.log("\n".join(validation_logs), level=logging.WARNING) - success = False - message = "Output Validation Failed" - - else: - success = True - message = "Output Validation Successful" - - pipeline_context.add("last_result", result) - logger.log(f"Answer: {result}") - - return LogicUnitOutput(result, success, message) diff --git a/pandasai/pipelines/judge/judge_pipeline_input.py b/pandasai/pipelines/judge/judge_pipeline_input.py deleted file mode 100644 index aaceea15c..000000000 --- a/pandasai/pipelines/judge/judge_pipeline_input.py +++ /dev/null @@ -1,11 +0,0 @@ -from dataclasses import dataclass - - -@dataclass -class JudgePipelineInput: - query: str - code: str - - def __init__(self, query: str, code: str) -> None: - self.query = query - self.code = code diff --git a/pandasai/pipelines/logic_unit_output.py b/pandasai/pipelines/logic_unit_output.py deleted file mode 100644 index bbbca7609..000000000 --- a/pandasai/pipelines/logic_unit_output.py +++ /dev/null @@ -1,29 +0,0 @@ -from dataclasses import dataclass -from typing import Any - - -@dataclass -class LogicUnitOutput: - """ - Pipeline step output - """ - - output: Any - message: str - success: bool - metadata: dict - final_track_output: bool - - def __init__( - self, - output: Any = None, - success: bool = False, - message: str = None, - metadata: dict = None, - final_track_output: bool = False, - ): - self.output = output - self.message = message - self.metadata = metadata - self.success = success - self.final_track_output = final_track_output diff --git a/pandasai/pipelines/logic_units/code_executor.py b/pandasai/pipelines/logic_units/code_executor.py deleted file mode 100644 index 82c6be639..000000000 --- a/pandasai/pipelines/logic_units/code_executor.py +++ 
/dev/null @@ -1,18 +0,0 @@ -from typing import Any - -from pandasai.pipelines.base_logic_unit import BaseLogicUnit - - -class BaseCodeExecutor(BaseLogicUnit): - """ - Executes the code generated by the prompt - """ - - def execute(self, input: Any, **kwargs) -> Any: - # Create an empty namespace dictionary - namespace = {} - - # Execute the code to populate the namespace - exec(input, namespace) - - return namespace diff --git a/pandasai/pipelines/logic_units/output_logic_unit.py b/pandasai/pipelines/logic_units/output_logic_unit.py deleted file mode 100644 index c13878299..000000000 --- a/pandasai/pipelines/logic_units/output_logic_unit.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import Any, Type, TypeVar - -from pandasai.pipelines.base_logic_unit import BaseLogicUnit -from pandasai.responses.context import Context -from pandasai.responses.response_parser import IResponseParser, ResponseParser - -IResponseParserImplementation = TypeVar( - "IResponseParserImplementation", bound=IResponseParser -) - - -class ProcessOutput(BaseLogicUnit): - """ - Executes the code generated by the prompt - """ - - _response_parser: Type[IResponseParserImplementation] - - def __init__( - self, response_parser: Type[IResponseParserImplementation] = ResponseParser - ): - super().__init__() - self._response_parser = response_parser - - def execute(self, input: Any, **kwargs) -> Any: - dfs = kwargs["context"].dfs - - context = Context(kwargs["config"], kwargs["logger"], dfs[0].engine) - - return self._response_parser(context).parse(input) diff --git a/pandasai/pipelines/logic_units/prompt_execution.py b/pandasai/pipelines/logic_units/prompt_execution.py deleted file mode 100644 index 29fbe14b9..000000000 --- a/pandasai/pipelines/logic_units/prompt_execution.py +++ /dev/null @@ -1,14 +0,0 @@ -from typing import Any - -from pandasai.exceptions import LLMNotFoundError -from pandasai.pipelines.base_logic_unit import BaseLogicUnit -from pandasai.prompts.file_based_prompt import FileBasedPrompt - - -class PromptExecution(BaseLogicUnit): - def execute(self, input: FileBasedPrompt, **kwargs) -> Any: - config = kwargs.get("config") - if config is None or getattr(config, "llm", None) is None: - raise LLMNotFoundError() - llm = getattr(config, "llm") - return llm.call(input) diff --git a/pandasai/pipelines/pipeline.py b/pandasai/pipelines/pipeline.py deleted file mode 100644 index c23ca76fb..000000000 --- a/pandasai/pipelines/pipeline.py +++ /dev/null @@ -1,162 +0,0 @@ -from __future__ import annotations -import logging -from typing import TYPE_CHECKING, Any, List, Optional, Union - -from pandasai.config import load_config_from_json - -from pandasai.exceptions import PipelineConcatenationError, UnSupportedLogicUnit -from pandasai.helpers.logger import Logger -from pandasai.pipelines.base_logic_unit import BaseLogicUnit -from pandasai.pipelines.logic_unit_output import LogicUnitOutput -from pandasai.pipelines.pipeline_context import PipelineContext - -from ..schemas.df_config import Config -from .abstract_pipeline import AbstractPipeline - -if TYPE_CHECKING: - from pandasai.dataframe.base import DataFrame - - -class Pipeline(AbstractPipeline): - """ - Base Pipeline class to be used to create custom pipelines - """ - - _context: PipelineContext - _logger: Logger - _steps: List[BaseLogicUnit] - - def __init__( - self, - context: Union[List[DataFrame], PipelineContext], - config: Optional[Union[Config, dict]] = None, - steps: Optional[List] = None, - logger: Optional[Logger] = None, - ): - """ - Initialize the pipeline with 
given context and configuration - parameters. - Args : - context (Context) : Context is required for ResponseParsers. - config (dict) : The configuration to pipeline. - steps: (list): List of logic Units - logger: (Logger): logger - """ - - if context and not isinstance(context, PipelineContext): - config = Config(**load_config_from_json(config)) - connectors = context - context = PipelineContext(connectors, config) - - self._logger = ( - Logger(save_logs=context.config.save_logs, verbose=context.config.verbose) - if logger is None and context - else logger - ) - - self._context = context - self._steps = steps or [] - - def add_step(self, logic: BaseLogicUnit): - """ - Adds new logics in the pipeline - Args: - logic (BaseLogicUnit): execution unit of logic - """ - if not isinstance(logic, BaseLogicUnit): - raise UnSupportedLogicUnit( - "Logic unit must be inherited from BaseLogicUnit and " - "must implement execute method" - ) - - self._steps.append(logic) - - def run(self, data: Any = None) -> Any: - """ - This functions is responsible to loop through logics - Args: - data (Any, optional): Input Data to run the pipeline. Defaults to None. - - Returns: - Any: Depends on the type can return anything - """ - try: - for index, logic in enumerate(self._steps): - # Callback function before execution - if logic.before_execution is not None: - logic.before_execution(data) - - self._logger.log(f"Executing Step {index}: {logic.__class__.__name__}") - - if logic.skip_if is not None and logic.skip_if(self._context): - self._logger.log(f"Executing Step {index}: Skipping...") - continue - - # Execute the logic unit - step_output = logic.execute( - data, - logger=self._logger, - config=self._context.config, - context=self._context, - ) - - # Track the execution step of pipeline - if isinstance(step_output, LogicUnitOutput): - data = step_output.output - else: - data = step_output - - # Callback function after execution - if logic.on_execution is not None: - logic.on_execution(data) - - except Exception as e: - self._logger.log(f"Pipeline failed on step {index}: {e}", logging.ERROR) - raise e - - return data - - def __or__(self, pipeline: "Pipeline") -> Any: - """ - This functions is responsible to pipe two pipelines - Args: - pipeline (Pipeline): Second Pipeline which will be Piped to the self. - data (Any, optional): Input Data to run the pipeline. Defaults to None. - - Returns: - Any: Depends on the type can return anything - """ - - if not isinstance(pipeline, Pipeline): - raise PipelineConcatenationError( - "Pipeline can be concatenated with Pipeline class only!" 
- ) - - combined_pipeline = Pipeline( - context=self._context, - logger=self._logger, - ) - - for step in self._steps: - combined_pipeline.add_step(step) - - for step in pipeline._steps: - combined_pipeline.add_step(step) - - return combined_pipeline - - @property - def context(self): - return self._context - - @context.setter - def context(self, context: PipelineContext): - self._context = context - - @property - def logger(self): - return self._logger - - @logger.setter - def logger(self, logger: Logger): - self._logger = logger diff --git a/pandasai/pipelines/pipeline_context.py b/pandasai/pipelines/pipeline_context.py deleted file mode 100644 index e2bab2f64..000000000 --- a/pandasai/pipelines/pipeline_context.py +++ /dev/null @@ -1,53 +0,0 @@ -import pandas as pd -from typing import Any, List, Optional, Union - -from pandasai.helpers.cache import Cache -from pandasai.helpers.memory import Memory -from pandasai.schemas.df_config import Config -from pandasai.vectorstores.vectorstore import VectorStore - - -class PipelineContext: - """ - Pass Context to the pipeline which is accessible to each step via kwargs - """ - - def __init__( - self, - dfs: List[pd.DataFrame], - config: Optional[Union[Config, dict]] = None, - memory: Optional[Memory] = None, - cache: Optional[Cache] = None, - vectorstore: VectorStore = None, - initial_values: dict = None, - ) -> None: - if isinstance(config, dict): - config = Config(**config) - - self.dfs = dfs - self.memory = memory or Memory() - - if config.enable_cache: - self.cache = cache if cache is not None else Cache() - else: - self.cache = None - - self.config = config - - self.intermediate_values = initial_values or {} - - self.vectorstore = vectorstore - - self._initial_values = initial_values - - def reset_intermediate_values(self): - self.intermediate_values = self._initial_values or {} - - def add(self, key: str, value: Any): - self.intermediate_values[key] = value - - def add_many(self, values: dict): - self.intermediate_values.update(values) - - def get(self, key: str, default: Any = ""): - return self.intermediate_values.get(key, default) diff --git a/pandasai/prompts/__init__.py b/pandasai/prompts/__init__.py deleted file mode 100644 index bd61e860a..000000000 --- a/pandasai/prompts/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .base import BasePrompt -from .correct_error_prompt import CorrectErrorPrompt -from .generate_python_code import GeneratePythonCodePrompt - -__all__ = [ - "BasePrompt", - "CorrectErrorPrompt", - "GeneratePythonCodePrompt", -] diff --git a/pandasai/prompts/templates/shared/dataframe.tmpl b/pandasai/prompts/templates/shared/dataframe.tmpl deleted file mode 100644 index 178df1093..000000000 --- a/pandasai/prompts/templates/shared/dataframe.tmpl +++ /dev/null @@ -1 +0,0 @@ -{{ df.serialize_dataframe(index-1, context.config.direct_sql, context.config.dataframe_serializer, context.config.enforce_privacy) }} diff --git a/pandasai/responses/__init__.py b/pandasai/responses/__init__.py deleted file mode 100644 index 79fa594e0..000000000 --- a/pandasai/responses/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -""" -Response Parsers for the user to customize response returned from the chat method -""" -from .context import Context -from .response_parser import IResponseParser, ResponseParser - -__all__ = ["IResponseParser", "ResponseParser", "Context"] diff --git a/pandasai/responses/context.py b/pandasai/responses/context.py deleted file mode 100644 index 31718a60e..000000000 --- a/pandasai/responses/context.py +++ /dev/null @@ -1,26 
+0,0 @@ -from pandasai.helpers.logger import Logger -from pandasai.schemas.df_config import Config - - -class Context: - """ - Context class that contains context from Agent for ResponseParsers - Context contain the application config and logger. - """ - - _config = None - _logger = None - - def __init__(self, config: Config, logger: Logger = None) -> None: - self._config = config - self._logger = logger - - @property - def config(self): - """Getter for _config attribute.""" - return self._config - - @property - def logger(self): - """Getter for _logger attribute.""" - return self._logger diff --git a/pandasai/responses/response_parser.py b/pandasai/responses/response_parser.py deleted file mode 100644 index fd202784d..000000000 --- a/pandasai/responses/response_parser.py +++ /dev/null @@ -1,76 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any - -from PIL import Image - -from pandasai.exceptions import MethodNotImplementedError - - -class IResponseParser(ABC): - @abstractmethod - def parse(self, result: dict) -> Any: - """ - Parses result from the chat input - Args: - result (dict): result contains type and value - Raises: - ValueError: if result is not a dictionary with valid key - - Returns: - Any: Returns depending on the user input - """ - raise MethodNotImplementedError - - -class ResponseParser(IResponseParser): - _context = None - - def __init__(self, context) -> None: - """ - Initialize the ResponseParser with Context from Agent - Args: - context (Context): context contains the config and logger - """ - self._context = context - - def parse(self, result: dict) -> Any: - """ - Parses result from the chat input - Args: - result (dict): result contains type and value - Raises: - ValueError: if result is not a dictionary with valid key - - Returns: - Any: Returns depending on the user input - """ - if not isinstance(result, dict) or any( - key not in result for key in ["type", "value"] - ): - raise ValueError("Unsupported result format") - - if result["type"] == "plot": - return self.format_plot(result) - else: - return result["value"] - - def format_plot(self, result: dict) -> Any: - """ - Display matplotlib plot against a user query. - - If `open_charts` option set to `False`, the chart won't be displayed. 
- - Args: - result (dict): result contains type and value - Returns: - Any: Returns depending on the user input - """ - if ( - self._context._config.open_charts - and isinstance(result["value"], str) - and "data:image/png;base64" not in result["value"] - ): - with Image.open(result["value"]) as img: - img.show() - - return result["value"] diff --git a/pandasai/responses/response_serializer.py b/pandasai/responses/response_serializer.py deleted file mode 100644 index 86f6f5f1e..000000000 --- a/pandasai/responses/response_serializer.py +++ /dev/null @@ -1,46 +0,0 @@ -import base64 -import json - -import pandas as pd -from pandasai.responses.response_type import ResponseType - - -class ResponseSerializer: - @staticmethod - def serialize_dataframe(df: pd.DataFrame): - json_data = json.loads(df.to_json(orient="split", date_format="iso")) - return {"headers": json_data["columns"], "rows": json_data["data"]} - - @staticmethod - def serialize(result: ResponseType) -> ResponseType: - """ - Format output response - Args: - result (ResponseType): response returned after execution - - Returns: - ResponseType: formatted response output - """ - if result["type"] == "dataframe": - if isinstance(result["value"], pd.Series): - result["value"] = result["value"].to_frame() - df_dict = ResponseSerializer.serialize_dataframe(result["value"]) - return {"type": result["type"], "value": df_dict} - - elif result["type"] == "plot" and isinstance(result["value"], str): - # check if already in base64 str return - if "data:image/png;base64" in result["value"]: - return result - - with open(result["value"], "rb") as image_file: - image_data = image_file.read() - # Encode the image data to Base64 - base64_image = ( - f"data:image/png;base64,{base64.b64encode(image_data).decode()}" - ) - return { - "type": result["type"], - "value": base64_image, - } - else: - return result diff --git a/pandasai/responses/response_type.py b/pandasai/responses/response_type.py deleted file mode 100644 index c3ab44e45..000000000 --- a/pandasai/responses/response_type.py +++ /dev/null @@ -1,6 +0,0 @@ -from typing import Any, TypedDict - - -class ResponseType(TypedDict): - type: str - value: Any diff --git a/pandasai/schemas/df_config.py b/pandasai/schemas/df_config.py deleted file mode 100644 index dffee8a14..000000000 --- a/pandasai/schemas/df_config.py +++ /dev/null @@ -1,56 +0,0 @@ -from typing import Any, List, Optional, Dict -from pydantic import BaseModel, Field, field_validator, ConfigDict - -from pandasai.constants import DEFAULT_CHART_DIRECTORY -from pandasai.helpers.dataframe_serializer import DataframeSerializerType - -from ..llm import LLM, BambooLLM -from importlib.util import find_spec - - -class LogServerConfig(BaseModel): - server_url: str - api_key: str - - -class Config(BaseModel): - save_logs: bool = True - verbose: bool = False - enforce_privacy: bool = False - enable_cache: bool = True - use_error_correction_framework: bool = True - open_charts: bool = True - save_charts: bool = False - save_charts_path: str = DEFAULT_CHART_DIRECTORY - custom_whitelisted_dependencies: List[str] = Field(default_factory=list) - max_retries: int = 3 - response_parser: Any = None - llm: LLM = Field( - default_factory=lambda: BambooLLM( - api_key="dummy_key_for_testing", endpoint_url=None - ) - ) - data_viz_library: Optional[str] = "" - log_server: Optional[LogServerConfig] = None - direct_sql: bool = False - dataframe_serializer: DataframeSerializerType = DataframeSerializerType.CSV - - model_config = 
ConfigDict(arbitrary_types_allowed=True) - - @field_validator("llm", mode="before") - @classmethod - def validate_llm(cls, v: Any) -> LLM: - if v is None: - return BambooLLM(api_key="dummy_key_for_testing", endpoint_url=None) - if find_spec("pandasai_langchain") is not None: - from pandasai_langchain.langchain import LangchainLLM - - if not isinstance(v, (LLM, LangchainLLM)): - return BambooLLM(api_key="dummy_key_for_testing", endpoint_url=None) - elif not isinstance(v, LLM): - return BambooLLM(api_key="dummy_key_for_testing", endpoint_url=None) - return v - - @classmethod - def from_dict(cls, config: Dict[str, Any]) -> "Config": - return cls(**config) diff --git a/pandasai/vectorstores/bamboo_vectorstore.py b/pandasai/vectorstores/bamboo_vectorstore.py index b950f0078..d7421db09 100644 --- a/pandasai/vectorstores/bamboo_vectorstore.py +++ b/pandasai/vectorstores/bamboo_vectorstore.py @@ -1,3 +1,4 @@ +import logging from typing import Iterable, List, Optional, Union from pandasai.helpers.logger import Logger @@ -60,7 +61,7 @@ def get_relevant_qa_documents(self, question: str, k: int = None) -> List[dict]: ) return docs["docs"] except Exception: - self._logger.log("Querying without using training data.") + self._logger.log("Querying without using training data.", logging.ERROR) return [] def get_relevant_docs_documents( @@ -79,5 +80,5 @@ def get_relevant_docs_documents( ) return docs["docs"] except Exception: - self._logger.log("Querying without using training docs.") + self._logger.log("Querying without using training docs.", logging.ERROR) return []
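Taken together, the new `RESTRICTED_LIBS` constant in `pandasai/constants.py` and the `MaliciousCodeGenerated` exception in `pandasai/exceptions.py` point at an import guard for generated code. The enforcement site is not part of this diff, so the sketch below is hypothetical wiring (the helper name `assert_no_restricted_imports` is invented for illustration); it uses only names this diff introduces plus the standard library:

```python
# Hypothetical guard wiring RESTRICTED_LIBS to MaliciousCodeGenerated;
# the actual enforcement code is not shown in this diff.
import ast

from pandasai.constants import RESTRICTED_LIBS
from pandasai.exceptions import MaliciousCodeGenerated


def assert_no_restricted_imports(code: str) -> None:
    """Raise MaliciousCodeGenerated if `code` imports any restricted library."""
    for node in ast.walk(ast.parse(code)):
        if isinstance(node, ast.Import):
            modules = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom) and node.module:
            modules = [node.module]
        else:
            continue
        for module in modules:
            if module.split(".")[0] in RESTRICTED_LIBS:
                raise MaliciousCodeGenerated(f"Restricted import: {module}")


assert_no_restricted_imports("import pandas as pd")  # passes silently
# assert_no_restricted_imports("import os")  # would raise MaliciousCodeGenerated
```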