diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index 60cf5ec8c..91da4c384 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -366,6 +366,8 @@ def chat(self, query: str, output_type: Optional[str] = None): self._memory.add(query, True) + result_is_valid = False + try: output_type_helper = output_type_factory(output_type, logger=self.logger) viz_lib_helper = viz_lib_type_factory(self._viz_lib, logger=self.logger) @@ -464,29 +466,27 @@ def chat(self, query: str, output_type: Optional[str] = None): self._retry_run_code, code, traceback_error ) - if result is not None: - if isinstance(result, dict): - validation_ok, validation_logs = output_type_helper.validate(result) - if not validation_ok: - self.logger.log( - "\n".join(validation_logs), level=logging.WARNING - ) - self._query_exec_tracker.add_step( - { - "type": "Validating Output", - "success": False, - "message": "Output Validation Failed", - } - ) - else: - self._query_exec_tracker.add_step( - { - "type": "Validating Output", - "success": True, - "message": "Output Validation Successful", - } - ) + if isinstance(result, dict): + result_is_valid, validation_logs = output_type_helper.validate(result) + if result_is_valid: + self._query_exec_tracker.add_step( + { + "type": "Validating Output", + "success": True, + "message": "Output Validation Successful", + } + ) + else: + self.logger.log("\n".join(validation_logs), level=logging.WARNING) + self._query_exec_tracker.add_step( + { + "type": "Validating Output", + "success": False, + "message": "Output Validation Failed", + } + ) + if result is not None: self.last_result = result self.logger.log(f"Answer: {result}") @@ -505,7 +505,13 @@ def chat(self, query: str, output_type: Optional[str] = None): f"Executed in: {self._query_exec_tracker.get_execution_time()}s" ) - self._add_result_to_memory(result) + if result_is_valid: + self._add_result_to_memory(result) + else: + self.logger.log( + "The result will not be memorized since it has failed the " + "corresponding validation" + ) result = self._query_exec_tracker.execute_func( self._response_parser.parse, result @@ -524,9 +530,6 @@ def _add_result_to_memory(self, result: dict): Args: result (dict): The result to add to the memory """ - if result is None: - return - if result["type"] in ["string", "number"]: self._memory.add(result["value"], False) elif result["type"] in ["dataframe", "plot"]: