diff --git a/pandasai/connectors/sql.py b/pandasai/connectors/sql.py index f8c2afff6..bf604a575 100644 --- a/pandasai/connectors/sql.py +++ b/pandasai/connectors/sql.py @@ -391,3 +391,25 @@ def __init__(self, config: ConnectorConfig): config["password"] = os.getenv("POSTGRESQL_PASSWORD") super().__init__(config) + + @cache + def head(self): + """ + Return the head of the data source that the connector is connected to. + This information is passed to the LLM to provide the schema of the data source. + + Returns: + DataFrame: The head of the data source. + """ + + if self.logger: + self.logger.log( + f"Getting head of {self._config.table} " + f"using dialect {self._config.dialect}" + ) + + # Run a SQL query to get all the columns names and 5 random rows + query = self._build_query(limit=5, order="RANDOM()") + + # Return the head of the data source + return pd.read_sql(query, self._connection)