diff --git a/docs/API/llms.md b/docs/API/llms.md index 5983dfe51..e05f5c5c0 100644 --- a/docs/API/llms.md +++ b/docs/API/llms.md @@ -18,18 +18,22 @@ OpenAI API wrapper extended through BaseOpenAI class. options: show_root_heading: true -### Starcoder +### Starcoder (deprecated) Starcoder wrapper extended through Base HuggingFace Class +- Note: Starcoder is deprecated and will be removed in future versions. Please use another LLM. + ::: pandasai.llm.starcoder options: show_root_heading: true -### Falcon +### Falcon (deprecated) Falcon wrapper extended through Base HuggingFace Class +- Note: Falcon is deprecated and will be removed in future versions. Please use another LLM. + ::: pandasai.llm.falcon options: show_root_heading: true diff --git a/docs/LLMs/llms.md b/docs/LLMs/llms.md index e469947bd..5abe5fb2b 100644 --- a/docs/LLMs/llms.md +++ b/docs/LLMs/llms.md @@ -2,6 +2,8 @@ PandasAI supports several large language models (LLMs). LLMs are used to generate code from natural language queries. The generated code is then executed to produce the result. +[![Choose the LLM](https://cdn.loom.com/sessions/thumbnails/5496c9c07ee04f69bfef1bc2359cd591-00001.jpg)](https://www.loom.com/share/5496c9c07ee04f69bfef1bc2359cd591 "Choose the LLM") + You can either choose a LLM by instantiating one and passing it to the `SmartDataFrame` or `SmartDatalake` constructor, or you can specify one in the `pandasai.json` file. If the model expects one or more parameters, you can pass them to the constructor or specify them in the `pandasai.json` file, in the `llm_options` param, as it follows: @@ -15,8 +17,6 @@ If the model expects one or more parameters, you can pass them to the constructo } ``` -## OpenAI models - In order to use OpenAI models, you need to have an OpenAI API key. You can get one [here](https://platform.openai.com/account/api-keys). Once you have an API key, you can use it to instantiate an OpenAI object: diff --git a/mkdocs.yml b/mkdocs.yml index d35692c6d..a006a0370 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -37,7 +37,7 @@ nav: - Documents Building: building_docs.md - License: license.md extra: - version: "1.1" + version: "1.1.2" plugins: - search - mkdocstrings: diff --git a/pandasai/__init__.py b/pandasai/__init__.py index f3539f89d..071ccbf1f 100644 --- a/pandasai/__init__.py +++ b/pandasai/__init__.py @@ -98,22 +98,22 @@ class PandasAI: """ _dl: SmartDatalake = None - _config: Config + _config: Union[Config, dict] def __init__( - self, - llm=None, - conversational=False, - verbose=False, - enforce_privacy=False, - save_charts=False, - save_charts_path="", - enable_cache=True, - middlewares=None, - custom_whitelisted_dependencies=None, - enable_logging=True, - non_default_prompts: Optional[Dict[str, Type[Prompt]]] = None, - callback: Optional[BaseCallback] = None, + self, + llm=None, + conversational=False, + verbose=False, + enforce_privacy=False, + save_charts=False, + save_charts_path="", + enable_cache=True, + middlewares=None, + custom_whitelisted_dependencies=None, + enable_logging=True, + non_default_prompts: Optional[Dict[str, Type[Prompt]]] = None, + callback: Optional[BaseCallback] = None, ): """ __init__ method of the Class PandasAI @@ -142,8 +142,10 @@ def __init__( # noinspection PyArgumentList # https://stackoverflow.com/questions/61226587/pycharm-does-not-recognize-logging-basicconfig-handlers-argument - warnings.warn("`PandasAI` (class) is deprecated since v1.0 and will be removed " - "in a future release. Please use `SmartDataframe` instead.") + warnings.warn( + "`PandasAI` (class) is deprecated since v1.0 and will be removed " + "in a future release. Please use `SmartDataframe` instead." + ) self._config = Config( conversational=conversational, @@ -161,12 +163,12 @@ def __init__( ) def run( - self, - data_frame: Union[pd.DataFrame, List[pd.DataFrame]], - prompt: str, - show_code: bool = False, - anonymize_df: bool = True, - use_error_correction_framework: bool = True, + self, + data_frame: Union[pd.DataFrame, List[pd.DataFrame]], + prompt: str, + show_code: bool = False, + anonymize_df: bool = True, + use_error_correction_framework: bool = True, ) -> Union[str, pd.DataFrame]: """ Run the PandasAI to make Dataframes Conversational. @@ -198,12 +200,12 @@ def run( return self._dl.chat(prompt) def __call__( - self, - data_frame: Union[pd.DataFrame, List[pd.DataFrame]], - prompt: str, - show_code: bool = False, - anonymize_df: bool = True, - use_error_correction_framework: bool = True, + self, + data_frame: Union[pd.DataFrame, List[pd.DataFrame]], + prompt: str, + show_code: bool = False, + anonymize_df: bool = True, + use_error_correction_framework: bool = True, ) -> Union[str, pd.DataFrame]: """ __call__ method of PandasAI class. It calls the `run` method. diff --git a/pandasai/config.py b/pandasai/config.py index dd1234d6e..f946e02f6 100644 --- a/pandasai/config.py +++ b/pandasai/config.py @@ -1,10 +1,15 @@ import json +import logging +from typing import Optional, Union + from . import llm, middlewares, callbacks from .helpers.path import find_closest from .schemas.df_config import Config +logger = logging.getLogger(__name__) + -def load_config(override_config: Config = None): +def load_config(override_config: Optional[Union[Config, dict]] = None): config = {} if override_config is None: @@ -27,11 +32,9 @@ def load_config(override_config: Config = None): if config.get("callback") and not override_config.get("callback"): config["callback"] = getattr(callbacks, config["callback"])() except Exception: - pass + logger.error("Could not load configuration", exc_info=True) if override_config: config.update(override_config) - config = Config(**config) - return config diff --git a/pandasai/helpers/code_manager.py b/pandasai/helpers/code_manager.py index fd43df3ab..80aaf6c38 100644 --- a/pandasai/helpers/code_manager.py +++ b/pandasai/helpers/code_manager.py @@ -25,7 +25,7 @@ class CodeManager: _dfs: List _middlewares: List[Middleware] = [ChartsMiddleware()] - _config: Config + _config: Union[Config, dict] _logger: Logger = None _additional_dependencies: List[dict] = [] _ast_comparatos_map: dict = { @@ -46,12 +46,12 @@ class CodeManager: def __init__( self, dfs: List, - config: Config, + config: Union[Config, dict], logger: Logger, ): """ Args: - config (Config, optional): Config to be used. Defaults to None. + config (Union[Config, dict], optional): Config to be used. Defaults to None. logger (Logger, optional): Logger to be used. Defaults to None. """ diff --git a/pandasai/helpers/df_info.py b/pandasai/helpers/df_info.py index ac3de91f3..139bd609b 100644 --- a/pandasai/helpers/df_info.py +++ b/pandasai/helpers/df_info.py @@ -21,6 +21,10 @@ def df_type(df: DataFrameType) -> str: Returns: str: Type of the dataframe """ + print("*" * 100) + print(df) + print("*" * 100) + if polars_imported and isinstance(df, pl.DataFrame): return "polars" elif isinstance(df, pd.DataFrame): diff --git a/pandasai/helpers/df_validator.py b/pandasai/helpers/df_validator.py new file mode 100644 index 000000000..03c9bf999 --- /dev/null +++ b/pandasai/helpers/df_validator.py @@ -0,0 +1,126 @@ +from typing import List, Dict +from pydantic import ValidationError +from pydantic import BaseModel +from pandasai.helpers.df_info import DataFrameType, df_type + + +class DfValidationResult: + """ + Validation results for a dataframe. + + Attributes: + passed: Whether the validation passed or not. + errors: List of errors if the validation failed. + """ + + _passed: bool + _errors: List[Dict] + + def __init__(self, passed: bool = True, errors: List[Dict] = None): + """ + Args: + passed: Whether the validation passed or not. + errors: List of errors if the validation failed. + """ + if errors is None: + errors = [] + self._passed = passed + self._errors = errors + + @property + def passed(self): + return self._passed + + def errors(self) -> List[Dict]: + return self._errors + + def add_error(self, error_message: str): + """ + Add an error message to the validation results. + + Args: + error_message: Error message to add. + """ + self._passed = False + self._errors.append(error_message) + + def __bool__(self) -> bool: + """ + Define the truthiness of ValidationResults. + """ + return self.passed + + +class DfValidator: + """ + Validate a dataframe using a Pydantic schema. + + Attributes: + df: dataframe to be validated + """ + + _df: DataFrameType + + def __init__(self, df: DataFrameType): + """ + Args: + df: dataframe to be validated + """ + self._df = df + + def _validate_batch(self, schema, df_json: List[Dict]): + """ + Args: + schema: Pydantic schema + batch_df: dataframe batch + + Returns: + list of errors + """ + try: + # Create a Pydantic Validator to validate rows of dataframe + class PdVal(BaseModel): + df: List[schema] + + PdVal(df=df_json) + return [] + + except ValidationError as e: + return e.errors() + + def _df_to_list_of_dict(self, df: DataFrameType, dataframe_type: str) -> List[Dict]: + """ + Create list of dict of dataframe rows on basis of dataframe type + Supports only polars and pandas dataframe + + Args: + df: dataframe to be converted + dataframe_type: type of dataframe + + Returns: + list of dict of dataframe rows + """ + if dataframe_type == "pandas": + return df.to_dict(orient="records") + elif dataframe_type == "polars": + return df.to_dicts() + else: + return [] + + def validate(self, schema: BaseModel) -> DfValidationResult: + """ + Args: + schema: Pydantic schema to be validated for the dataframe row + + Returns: + Validation results + """ + dataframe_type = df_type(self._df) + if dataframe_type is None: + raise ValueError("Unsupported DataFrame") + + df_json: List[Dict] = self._df_to_list_of_dict(self._df, dataframe_type) + + errors = self._validate_batch(schema, df_json) + + return DfValidationResult(len(errors) == 0, errors) diff --git a/pandasai/llm/falcon.py b/pandasai/llm/falcon.py index 9fabd01c3..e722f4663 100644 --- a/pandasai/llm/falcon.py +++ b/pandasai/llm/falcon.py @@ -8,7 +8,7 @@ >>> from pandasai.llm.falcon import Falcon """ - +import warnings from ..helpers import load_dotenv from .base import HuggingFaceLLM @@ -17,12 +17,7 @@ class Falcon(HuggingFaceLLM): - - """Falcon LLM API - - A base HuggingFaceLLM class is extended to use Falcon model. - - """ + """Falcon LLM API (Deprecated: Kept for backwards compatibility)""" api_token: str _api_url: str = ( @@ -30,6 +25,15 @@ class Falcon(HuggingFaceLLM): ) _max_retries: int = 30 + def __init__(self, **kwargs): + warnings.warn( + """Falcon is deprecated and will be removed in a future release. + Please use langchain.llms.HuggingFaceHub instead, although please be + aware that it may perform poorly. + """ + ) + super().__init__(**kwargs) + @property def type(self) -> str: return "falcon" diff --git a/pandasai/llm/starcoder.py b/pandasai/llm/starcoder.py index 3546116c1..c84d227b6 100644 --- a/pandasai/llm/starcoder.py +++ b/pandasai/llm/starcoder.py @@ -8,7 +8,7 @@ >>> from pandasai.llm.starcoder import Starcoder """ - +import warnings from ..helpers import load_dotenv from .base import HuggingFaceLLM @@ -18,16 +18,21 @@ class Starcoder(HuggingFaceLLM): - """Starcoder LLM API - - A base HuggingFaceLLM class is extended to use Starcoder model. - - """ + """Starcoder LLM API (Deprecated: Kept for backwards compatibility)""" api_token: str _api_url: str = "https://api-inference.huggingface.co/models/bigcode/starcoder" _max_retries: int = 30 + def __init__(self, **kwargs): + warnings.warn( + """Starcoder is deprecated and will be removed in a future release. + Please use langchain.llms.HuggingFaceHub instead, although please be + aware that it may perform poorly. + """ + ) + super().__init__(**kwargs) + @property def type(self) -> str: return "starcoder" diff --git a/pandasai/prompts/generate_python_code.py b/pandasai/prompts/generate_python_code.py index bff66e142..902f2c841 100644 --- a/pandasai/prompts/generate_python_code.py +++ b/pandasai/prompts/generate_python_code.py @@ -51,7 +51,7 @@ class GeneratePythonCodePrompt(Prompt): # Analyze the data # 1. Prepare: Preprocessing and cleaning data if necessary # 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) -# 3. Analyze: Conducting the actual analysis (if the user asks to create a chart save it to an image in exports/charts/temp_chart.png and do not show the chart.) +# 3. Analyze: Conducting the actual analysis (if the user asks to create a chart save it to an image in {save_charts_path}/temp_chart.png and do not show the chart.) # 4. Output: return a dictionary of: # - type (possible values "text", "number", "dataframe", "plot") # - value (can be a string, a dataframe or the path of the plot, NOT a dictionary) diff --git a/pandasai/smart_dataframe/__init__.py b/pandasai/smart_dataframe/__init__.py index f9bed68ec..6e72fd14a 100644 --- a/pandasai/smart_dataframe/__init__.py +++ b/pandasai/smart_dataframe/__init__.py @@ -23,6 +23,10 @@ import pandas as pd from functools import cached_property +import pydantic + +from pandasai.helpers.df_validator import DfValidator + from ..smart_datalake import SmartDatalake from ..schemas.df_config import Config from ..helpers.data_sampler import DataSampler @@ -436,6 +440,16 @@ def connector(self, connector: BaseConnector): connector.logger = self.logger self._core.connector = connector + def validate(self, schema: pydantic.BaseModel): + """ + Validates Dataframe rows on the basis Pydantic schema input + (Args): + schema: Pydantic schema class + verbose: Print Errors + """ + df_validator = DfValidator(self.dataframe) + return df_validator.validate(schema) + @property def lake(self) -> SmartDatalake: return self._lake diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index be0c9c514..879f46a27 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -44,7 +44,7 @@ class SmartDatalake: _dfs: List[DataFrameType] - _config: Config + _config: Union[Config, dict] _llm: LLM _cache: Cache = None _logger: Logger @@ -60,14 +60,14 @@ class SmartDatalake: def __init__( self, dfs: List[Union[DataFrameType, Any]], - config: Config = None, + config: Optional[Union[Config, dict]] = None, logger: Logger = None, memory: Memory = None, ): """ Args: dfs (List[Union[DataFrameType, Any]]): List of dataframes to be used - config (Config, optional): Config to be used. Defaults to None. + config (Union[Config, dict], optional): Config to be used. Defaults to None. logger (Logger, optional): Logger to be used. Defaults to None. """ @@ -135,18 +135,21 @@ def _load_dfs(self, dfs: List[Union[DataFrameType, Any]]): smart_dfs.append(df) self._dfs = smart_dfs - def _load_config(self, config: Config): + def _load_config(self, config: Union[Config, dict]): """ Load a config to be used to run the queries. Args: - config (Config): Config to be used + config (Union[Config, dict]): Config to be used """ - self._config = load_config(config) + config = load_config(config) - if self._config.llm: - self._load_llm(self._config.llm) + if config.get("llm"): + self._load_llm(config["llm"]) + config["llm"] = self._llm + + self._config = Config(**config) def _load_llm(self, llm: LLM): """ @@ -161,10 +164,7 @@ def _load_llm(self, llm: LLM): BadImportError: If the LLM is a Langchain LLM but the langchain package is not installed """ - - try: - llm.is_pandasai_llm() - except AttributeError: + if hasattr(llm, "_llm_type"): llm = LangchainLLM(llm) self._llm = llm @@ -283,6 +283,7 @@ def chat(self, query: str): default_values = { # TODO: find a better way to determine the engine, "engine": self._dfs[0].engine, + "save_charts_path": self._config.save_charts_path.rstrip("/"), } generate_python_code_instruction = self._get_prompt( "generate_python_code", diff --git a/pyproject.toml b/pyproject.toml index dd8a3998c..979526144 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pandasai" -version = "1.1" +version = "1.1.2" description = "PandasAI is a Python library that integrates generative artificial intelligence capabilities into Pandas, making dataframes conversational." authors = ["Gabriele Venturi"] license = "MIT" diff --git a/tests/prompts/test_generate_python_code_prompt.py b/tests/prompts/test_generate_python_code_prompt.py index e5597ba14..889723ddf 100644 --- a/tests/prompts/test_generate_python_code_prompt.py +++ b/tests/prompts/test_generate_python_code_prompt.py @@ -23,6 +23,7 @@ def test_str_with_args(self): prompt = GeneratePythonCodePrompt() prompt.set_var("dfs", dfs) prompt.set_var("conversation", "Question") + prompt.set_var("save_charts_path", "exports/charts") assert ( prompt.to_string() == """ @@ -51,6 +52,61 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: # Code goes here (do not add comments) +# Declare a result variable +result = analyze_data(dfs) +``` + +Using the provided dataframes (`dfs`), update the python code based on the last user question: +Question + +Updated code: +""" # noqa: E501 + ) + + def test_str_with_custom_save_charts_path(self): + """Test that the __str__ method is implemented""" + + llm = FakeLLM("plt.show()") + dfs = [ + SmartDataframe( + pd.DataFrame({"a": [1], "b": [4]}), + config={"llm": llm}, + ) + ] + + prompt = GeneratePythonCodePrompt() + prompt.set_var("dfs", dfs) + prompt.set_var("conversation", "Question") + prompt.set_var("save_charts_path", "custom_path") + + assert ( + prompt.to_string() + == """ +You are provided with the following pandas DataFrames with the following metadata: + +Dataframe dfs[0], with 1 rows and 2 columns. +This is the metadata of the dataframe dfs[0]: +a,b +1,4 + + +This is the initial python code to be updated: +```python +# TODO import all the dependencies required +import pandas as pd + +# Analyze the data +# 1. Prepare: Preprocessing and cleaning data if necessary +# 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) +# 3. Analyze: Conducting the actual analysis (if the user asks to create a chart save it to an image in custom_path/temp_chart.png and do not show the chart.) +# 4. Output: return a dictionary of: +# - type (possible values "text", "number", "dataframe", "plot") +# - value (can be a string, a dataframe or the path of the plot, NOT a dictionary) +# Example output: { "type": "text", "value": "The average loan amount is $15,000." } +def analyze_data(dfs: list[pd.DataFrame]) -> dict: + # Code goes here (do not add comments) + + # Declare a result variable result = analyze_data(dfs) ``` diff --git a/tests/test_smartdataframe.py b/tests/test_smartdataframe.py index fb8dc0f3d..1f2c8a0fc 100644 --- a/tests/test_smartdataframe.py +++ b/tests/test_smartdataframe.py @@ -8,6 +8,7 @@ import pandas as pd import polars as pl +from pydantic import BaseModel, Field import pytest from pandasai import SmartDataframe @@ -123,7 +124,7 @@ def test_init(self, smart_dataframe): def test_init_without_llm(self, sample_df): with pytest.raises(LLMNotFoundError): - SmartDataframe(sample_df) + SmartDataframe(sample_df, config={"llm": None}) def test_run(self, smart_dataframe: SmartDataframe, llm): llm._output = ( @@ -240,6 +241,71 @@ def test_getters_are_accessible(self, smart_dataframe: SmartDataframe, llm): == "def analyze_data(dfs):\n return {'type': 'number', 'value': 1}" ) + def test_save_chart_non_default_dir( + self, smart_dataframe: SmartDataframe, llm, sample_df + ): + """ + Test chat with `SmartDataframe` with custom `save_charts_path`. + + Script: + 1) Ask `SmartDataframe` to build a chart and save it in + a custom directory; + 2) Check if substring representing the directory present in + `llm.last_prompt`. + 3) Check if the code has had a call of `plt.savefig()` passing + the custom directory. + + Notes: + 1) Mock `import_dependency()` util-function to avoid the + actual calls to `matplotlib.pyplot`. + 2) The `analyze_data()` function in the code fixture must have + `"type": None` in the result dict. Otherwise, if it had + `"type": "plot"` (like it has in practice), `_format_results()` + method from `SmartDatalake` object would try to read the image + with `matplotlib.image.imread()` and this test would fail. + Those calls to `matplotlib.image` are unmockable because of + imports inside the function scope, not in the top of a module. + @TODO: figure out if we can just move the imports beyond to + make it possible to mock out `matplotlib.image` + """ + llm._output = """ +import pandas as pd +import matplotlib.pyplot as plt +def analyze_data(dfs: list[pd.DataFrame]) -> dict: + df = dfs[0].nlargest(5, 'happiness_index') + + plt.figure(figsize=(8, 6)) + plt.pie(df['happiness_index'], labels=df['country'], autopct='%1.1f%%') + plt.title('Happiness Index for the 5 Happiest Countries') + plt.savefig('custom-dir/output_charts/temp_chart.png') + plt.close() + + return {"type": None, "value": "custom-dir/output_charts/temp_chart.png"} +result = analyze_data(dfs) +""" + with patch( + "pandasai.helpers.code_manager.import_dependency" + ) as import_dependency_mock: + smart_dataframe = SmartDataframe( + sample_df, + config={ + "llm": llm, + "enable_cache": False, + "save_charts": True, + "save_charts_path": "custom-dir/output_charts/", + }, + ) + + smart_dataframe.chat("Plot pie-chart the 5 happiest countries") + + assert "custom-dir/output_charts/temp_chart.png" in llm.last_prompt + plt_mock = getattr(import_dependency_mock.return_value, "matplotlib.pyplot") + assert plt_mock.savefig.called + assert ( + plt_mock.savefig.call_args.args[0] + == "custom-dir/output_charts/temp_chart.png" + ) + def test_add_middlewares(self, smart_dataframe: SmartDataframe, custom_middleware): middleware = custom_middleware() smart_dataframe.add_middlewares(middleware) @@ -606,7 +672,7 @@ def test_save_pandas_dataframe_duplicate_name(self, llm): # Create a sample DataFrame df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) - # Create instances of YourDataFrameClass + # Create instances of SmartDataframe df_object1 = SmartDataframe( df, name="df_duplicate", @@ -638,7 +704,7 @@ def test_save_pandas_no_name(self, llm): # Create a sample DataFrame df = pd.DataFrame({"A": [1, 2, 3, 4], "B": [5, 6, 7, 8]}) - # Create an instance of YourDataFrameClass without a name + # Create an instance of SmartDataframe without a name df_object = SmartDataframe( df, description="No Name", config={"llm": llm, "enable_cache": False} ) @@ -662,3 +728,101 @@ def test_save_pandas_no_name(self, llm): # Recover file for next test case with open("pandasai.json", "w") as json_file: json_file.write(backup_pandasai) + + def test_pydantic_validate(self, llm): + # Create a sample DataFrame + df = pd.DataFrame({"A": [1, 2, 3, 4], "B": [5, 6, 7, 8]}) + + # Create an instance of SmartDataframe without a name + df_object = SmartDataframe( + df, description="Name", config={"llm": llm, "enable_cache": False} + ) + + # Pydantic Schema + class TestSchema(BaseModel): + A: int + B: int + + validation_result = df_object.validate(TestSchema) + + assert validation_result.passed is True + + def test_pydantic_validate_false(self, llm): + # Create a sample DataFrame + df = pd.DataFrame({"A": ["Test", "Test2", "Test3", "Test4"], "B": [5, 6, 7, 8]}) + + # Create an instance of SmartDataframe without a name + df_object = SmartDataframe( + df, description="Name", config={"llm": llm, "enable_cache": False} + ) + + # Pydantic Schema + class TestSchema(BaseModel): + A: int + B: int + + validation_result = df_object.validate(TestSchema) + + assert validation_result.passed is False + + def test_pydantic_validate_polars(self, llm): + # Create a sample DataFrame + df = pl.DataFrame({"A": [1, 2, 3, 4], "B": [5, 6, 7, 8]}) + + # Create an instance of SmartDataframe without a name + df_object = SmartDataframe( + df, description="Name", config={"llm": llm, "enable_cache": False} + ) + + # Pydantic Schema + class TestSchema(BaseModel): + A: int + B: int + + validation_result = df_object.validate(TestSchema) + assert validation_result.passed is True + + def test_pydantic_validate_false_one_record(self, llm): + # Create a sample DataFrame + df = pd.DataFrame({"A": [1, "test", 3, 4], "B": [5, 6, 7, 8]}) + + # Create an instance of SmartDataframe without a name + df_object = SmartDataframe( + df, description="Name", config={"llm": llm, "enable_cache": False} + ) + + # Pydantic Schema + class TestSchema(BaseModel): + A: int + B: int + + validation_result = df_object.validate(TestSchema) + assert ( + validation_result.passed is False and len(validation_result.errors()) == 1 + ) + + def test_pydantic_validate_complex_schema(self, llm): + # Create a sample DataFrame + df = pd.DataFrame({"A": [1, 2, 3, 4], "B": [5, 6, 7, 8]}) + + # Create an instance of SmartDataframe without a name + df_object = SmartDataframe( + df, description="Name", config={"llm": llm, "enable_cache": False} + ) + + # Pydantic Schema + class TestSchema(BaseModel): + A: int = Field(..., gt=5) + B: int + + validation_result = df_object.validate(TestSchema) + + assert validation_result.passed is False + + class TestSchema(BaseModel): + A: int = Field(..., lt=5) + B: int + + validation_result = df_object.validate(TestSchema) + + assert validation_result.passed is True