From 10b155ecb089019aa872fe7888628787658a37f4 Mon Sep 17 00:00:00 2001 From: xainaz Date: Mon, 23 Dec 2024 16:27:52 +0300 Subject: [PATCH 1/5] Added Agent Response, and edited unit tests --- aixplain/modules/agent/__init__.py | 51 +++++++++++++---- aixplain/modules/agent/agent_response.py | 72 ++++++++++++++++++++++++ aixplain/modules/model/response.py | 16 ++++++ tests/unit/agent_test.py | 4 ++ 4 files changed, 132 insertions(+), 11 deletions(-) create mode 100644 aixplain/modules/agent/agent_response.py diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index d6d6d77d..840cb909 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -36,6 +36,8 @@ from aixplain.modules.agent.tool import Tool from aixplain.modules.agent.tool.model_tool import ModelTool from aixplain.modules.agent.tool.pipeline_tool import PipelineTool +from aixplain.modules.agent.agent_response import AgentResponse +from aixplain.enums import ResponseStatus from aixplain.modules.agent.utils import process_variables from typing import Dict, List, Text, Optional, Union from urllib.parse import urljoin @@ -130,7 +132,7 @@ def run( max_tokens: int = 2048, max_iterations: int = 10, output_format: OutputFormat = OutputFormat.TEXT, - ) -> Dict: + ) -> AgentResponse: """Runs an agent call. Args: @@ -169,13 +171,26 @@ def run( return response poll_url = response["url"] end = time.time() - response = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) - return response + response = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time).to_dict() + return AgentResponse( + status=ResponseStatus(response["status"]), + data=response["data"], + output=response["data"]["output"], + used_credits=response["data"]["usedCredits"], + input=response["data"]["input"], + session_id=response["data"]["session_id"], + run_time= response["data"]["runTime"] + ) except Exception as e: msg = f"Error in request for {name} - {traceback.format_exc()}" logging.error(f"Agent Run: Error in running for {name}: {e}") end = time.time() - return {"status": "FAILED", "error": msg, "elapsed_time": end - start} + return AgentResponse( + status=ResponseStatus.FAILED, + data=response["data"], + run_time= end-start, + error=msg + ) def run_async( self, @@ -189,7 +204,7 @@ def run_async( max_tokens: int = 2048, max_iterations: int = 10, output_format: OutputFormat = OutputFormat.TEXT, - ) -> Dict: + ) -> AgentResponse: """Runs asynchronously an agent call. Args: @@ -265,15 +280,29 @@ def run_async( resp = r.json() logging.info(f"Result of request for {name} - {r.status_code} - {resp}") - poll_url = resp["data"] - response = {"status": "IN_PROGRESS", "url": poll_url} + poll_url = resp.get("data") + execution_stats = resp.get("executionStats") + used_credits = resp.get("usedCredits", 0.0) + run_time = resp.get("runTime", 0.0) + + return AgentResponse( + status=ResponseStatus.IN_PROGRESS, + url=poll_url, + input=input_data, + session_id=session_id or "", + execution_stats=execution_stats, + used_credits=used_credits, + run_time=run_time, + ) except Exception: - response = {"status": "FAILED"} msg = f"Error in request for {name} - {traceback.format_exc()}" logging.error(f"Agent Run Async: Error in running for {name}: {resp}") - if resp is not None: - response["error"] = msg - return response + return AgentResponse( + status=ResponseStatus.FAILED, + error=msg + ) + + def to_dict(self) -> Dict: return { diff --git a/aixplain/modules/agent/agent_response.py b/aixplain/modules/agent/agent_response.py new file mode 100644 index 00000000..33b48353 --- /dev/null +++ b/aixplain/modules/agent/agent_response.py @@ -0,0 +1,72 @@ +from aixplain.enums import ResponseStatus +from typing import Any, Dict, Optional, Text, Union + +class AgentResponse: + + def __init__( + self, + status: ResponseStatus = ResponseStatus.FAILED, + data: Dict={}, + input: Union[Text, Dict[str, Any]] = None, + output: Any = None, + url: Text = "", + session_id: Text = "", + run_time: float = 0.0, + used_credits: float = 0.0, + execution_stats: Optional[Dict[str, Any]] = None, + error: Optional[Text] = None, + ): + + self.status = status + self.data=data + self.input = self.validate_input(input) + self.output = self.validate_output(output) + self.url = url + self.session_id = session_id + self.run_time = run_time + self.used_credits = used_credits + self.execution_stats = execution_stats + self.error=error + + @staticmethod + def validate_input(value): + if isinstance(value, list): + return [str(row) if type(row) not in [dict, list, str, int, float, bool] else row for row in value] + elif isinstance(value, dict): + return {key: str(val) if type(val) not in [dict, list, str, int, float, bool] else val for key, val in value.items()} + elif type(value) not in [dict, list, str, int, float, bool]: + return str(value) + return value + + @staticmethod + def validate_output(value): + if isinstance(value, list): + return [str(row) if type(row) not in [dict, list, str, int, float, bool] else row for row in value] + elif isinstance(value, dict): + return {key: str(val) if type(val) not in [dict, list, str, int, float, bool] else val for key, val in value.items()} + elif type(value) not in [dict, list, str, int, float, bool]: + return str(value) + return value + + def __getitem__(self, item): + return getattr(self, item, None) + + def __setitem__(self, item, value): + if hasattr(self, item): + setattr(self, item, value) + else: + raise KeyError(f"Key '{item}' not found in AgentResponse.") + + def to_dict(self): + return { + "status": self.status, + "input": self.input, + "data": self.data, + "output": self.output, + "url":self.url, + "session_id": self.session_id, + "run_time": self.run_time, + "used_credits": self.used_credits, + "execution_stats": self.execution_stats, + "error": self.error + } \ No newline at end of file diff --git a/aixplain/modules/model/response.py b/aixplain/modules/model/response.py index 1576c1f4..bcba7a34 100644 --- a/aixplain/modules/model/response.py +++ b/aixplain/modules/model/response.py @@ -88,3 +88,19 @@ def __contains__(self, key: Text) -> bool: return True except KeyError: return False + def to_dict(self) -> Dict[Text, Any]: + base_dict = { + "status": self.status, + "data": self.data, + "details": self.details, + "completed": self.completed, + "error_message": self.error_message, + "used_credits": self.used_credits, + "run_time": self.run_time, + "usage": self.usage, + "url": self.url, + } + if self.additional_fields: + base_dict.update(self.additional_fields) + return base_dict + diff --git a/tests/unit/agent_test.py b/tests/unit/agent_test.py index cf217919..1eca046d 100644 --- a/tests/unit/agent_test.py +++ b/tests/unit/agent_test.py @@ -10,6 +10,7 @@ from urllib.parse import urljoin import warnings from aixplain.enums.function import Function +from aixplain.modules.agent.agent_response import AgentResponse @@ -34,6 +35,7 @@ def test_fail_query_as_text_when_content_not_empty(): data={"query": "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav"}, content=["https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav"], ) + assert str(exc_info.value) == "When providing 'content', query must be text." @@ -68,6 +70,7 @@ def test_success_query_content(): mock.post(url, headers=headers, json=ref_response) response = agent.run_async(data={"query": "Translate the text: {{input1}}"}, content={"input1": "Hello, how are you?"}) + assert isinstance(response, AgentResponse) assert response["status"] == ref_response["status"] assert response["url"] == ref_response["data"] @@ -310,6 +313,7 @@ def test_run_success(): response = agent.run_async( data={"query": "Hello, how are you?"}, max_iterations=10, output_format=OutputFormat.MARKDOWN ) + assert isinstance(response, AgentResponse) assert response["status"] == "IN_PROGRESS" assert response["url"] == ref_response["data"] From 3e90d5947a13e9fe0f64fb7e9cc25d0f29db30e3 Mon Sep 17 00:00:00 2001 From: xainaz Date: Tue, 24 Dec 2024 23:54:10 +0300 Subject: [PATCH 2/5] Made changes --- aixplain/modules/agent/__init__.py | 60 +++++----- aixplain/modules/agent/agent_response.py | 106 ++++++++---------- aixplain/modules/agent/agent_response_data.py | 46 ++++++++ tests/unit/agent_test.py | 3 + 4 files changed, 123 insertions(+), 92 deletions(-) create mode 100644 aixplain/modules/agent/agent_response_data.py diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 840cb909..9ae8aba6 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -37,6 +37,7 @@ from aixplain.modules.agent.tool.model_tool import ModelTool from aixplain.modules.agent.tool.pipeline_tool import PipelineTool from aixplain.modules.agent.agent_response import AgentResponse +from aixplain.modules.agent.agent_response_data import AgentResponseData from aixplain.enums import ResponseStatus from aixplain.modules.agent.utils import process_variables from typing import Dict, List, Text, Optional, Union @@ -165,31 +166,37 @@ def run( max_iterations=max_iterations, output_format=output_format, ) - if response["status"] == "FAILED": + if response["status"] == ResponseStatus.FAILED: end = time.time() response["elapsed_time"] = end - start return response poll_url = response["url"] end = time.time() - response = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time).to_dict() + result = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) return AgentResponse( - status=ResponseStatus(response["status"]), - data=response["data"], - output=response["data"]["output"], - used_credits=response["data"]["usedCredits"], - input=response["data"]["input"], - session_id=response["data"]["session_id"], - run_time= response["data"]["runTime"] + status=ResponseStatus.SUCCESS, + data=AgentResponseData( + input=result.get("input"), + output=result.get("output"), + execution_stats=result.get("executionStats"), + run_time=result.get("runTime", end - start), + used_credits=result.get("usedCredits", 0.0), + session_id=session_id, + ), ) except Exception as e: msg = f"Error in request for {name} - {traceback.format_exc()}" logging.error(f"Agent Run: Error in running for {name}: {e}") end = time.time() return AgentResponse( - status=ResponseStatus.FAILED, - data=response["data"], - run_time= end-start, - error=msg + status=ResponseStatus.FAILED, + data=AgentResponseData( + input=data, + output=None, + run_time=end - start, + session_id=session_id, + ), + error=msg, ) def run_async( @@ -272,34 +279,25 @@ def run_async( payload.update(parameters) payload = json.dumps(payload) - r = _request_with_retry("post", self.url, headers=headers, data=payload) - logging.info(f"Agent Run Async: Start service for {name} - {self.url} - {payload} - {headers}") - - resp = None try: + r = _request_with_retry("post", self.url, headers=headers, data=payload) resp = r.json() - logging.info(f"Result of request for {name} - {r.status_code} - {resp}") - poll_url = resp.get("data") - execution_stats = resp.get("executionStats") - used_credits = resp.get("usedCredits", 0.0) - run_time = resp.get("runTime", 0.0) - return AgentResponse( status=ResponseStatus.IN_PROGRESS, url=poll_url, - input=input_data, - session_id=session_id or "", - execution_stats=execution_stats, - used_credits=used_credits, - run_time=run_time, + data=AgentResponseData( + input=input_data, + run_time=0.0, + used_credits=0.0, + ), ) - except Exception: + except Exception as e: msg = f"Error in request for {name} - {traceback.format_exc()}" - logging.error(f"Agent Run Async: Error in running for {name}: {resp}") + logging.error(f"Agent Run Async: Error in running for {name}: {e}") return AgentResponse( status=ResponseStatus.FAILED, - error=msg + error=msg, ) diff --git a/aixplain/modules/agent/agent_response.py b/aixplain/modules/agent/agent_response.py index 33b48353..b4b9215d 100644 --- a/aixplain/modules/agent/agent_response.py +++ b/aixplain/modules/agent/agent_response.py @@ -1,72 +1,56 @@ from aixplain.enums import ResponseStatus -from typing import Any, Dict, Optional, Text, Union +from typing import Any, Dict, Optional, Text, Union, List +from aixplain.modules.agent.agent_response_data import AgentResponseData +from aixplain.modules.model.response import ModelResponse -class AgentResponse: +class AgentResponse(ModelResponse): def __init__( self, status: ResponseStatus = ResponseStatus.FAILED, - data: Dict={}, - input: Union[Text, Dict[str, Any]] = None, - output: Any = None, - url: Text = "", - session_id: Text = "", - run_time: float = 0.0, + data: Optional[AgentResponseData] = None, + details: Optional[Union[Dict, List]] = {}, + completed: bool = False, + error_message: Text = "", used_credits: float = 0.0, - execution_stats: Optional[Dict[str, Any]] = None, - error: Optional[Text] = None, + run_time: float = 0.0, + usage: Optional[Dict] = None, + url: Optional[Text] = None, + **kwargs, ): - self.status = status - self.data=data - self.input = self.validate_input(input) - self.output = self.validate_output(output) - self.url = url - self.session_id = session_id - self.run_time = run_time - self.used_credits = used_credits - self.execution_stats = execution_stats - self.error=error - - @staticmethod - def validate_input(value): - if isinstance(value, list): - return [str(row) if type(row) not in [dict, list, str, int, float, bool] else row for row in value] - elif isinstance(value, dict): - return {key: str(val) if type(val) not in [dict, list, str, int, float, bool] else val for key, val in value.items()} - elif type(value) not in [dict, list, str, int, float, bool]: - return str(value) - return value - - @staticmethod - def validate_output(value): - if isinstance(value, list): - return [str(row) if type(row) not in [dict, list, str, int, float, bool] else row for row in value] - elif isinstance(value, dict): - return {key: str(val) if type(val) not in [dict, list, str, int, float, bool] else val for key, val in value.items()} - elif type(value) not in [dict, list, str, int, float, bool]: - return str(value) - return value - - def __getitem__(self, item): - return getattr(self, item, None) - - def __setitem__(self, item, value): - if hasattr(self, item): - setattr(self, item, value) + super().__init__( + status=status, + data="", + details=details, + completed=completed, + error_message=error_message, + used_credits=used_credits, + run_time=run_time, + usage=usage, + url=url, + **kwargs, + ) + self.data = data or AgentResponseData() + + def __getitem__(self, key: Text) -> Any: + if key == "data": + return self.data.to_dict() + return super().__getitem__(key) + + def __setitem__(self, key: Text, value: Any) -> None: + if key == "data" and isinstance(value, Dict): + self.data = AgentResponseData.from_dict(value) + elif key == "data" and isinstance(value, AgentResponseData): + self.data = value else: - raise KeyError(f"Key '{item}' not found in AgentResponse.") + super().__setitem__(key, value) + + def to_dict(self) -> Dict[Text, Any]: + base_dict = super().to_dict() + base_dict["data"] = self.data.to_dict() + return base_dict - def to_dict(self): - return { - "status": self.status, - "input": self.input, - "data": self.data, - "output": self.output, - "url":self.url, - "session_id": self.session_id, - "run_time": self.run_time, - "used_credits": self.used_credits, - "execution_stats": self.execution_stats, - "error": self.error - } \ No newline at end of file + def __repr__(self) -> str: + fields = super().__repr__().strip("ModelResponse(").rstrip(")") + return f"AgentResponse({fields})" \ No newline at end of file diff --git a/aixplain/modules/agent/agent_response_data.py b/aixplain/modules/agent/agent_response_data.py new file mode 100644 index 00000000..2dba7427 --- /dev/null +++ b/aixplain/modules/agent/agent_response_data.py @@ -0,0 +1,46 @@ +from typing import List, Dict, Any, Optional + +class AgentResponseData: + def __init__( + self, + input: Optional[Any] = None, + output: Optional[Any] = None, + session_id: str = "", + intermediate_steps: Optional[List[Any]] = None, + run_time: float = 0.0, + used_credits: float = 0.0, + execution_stats: Optional[Dict[str, Any]] = None, + ): + self.input = input + self.output = output + self.session_id = session_id + self.intermediate_steps = intermediate_steps or [] + self.run_time = run_time + self.used_credits = used_credits + self.execution_stats = execution_stats + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AgentResponseData": + return cls( + input=data.get("input"), + output=data.get("output"), + session_id=data.get("session_id", ""), + intermediate_steps=data.get("intermediate_steps", []), + run_time=data.get("runTime", 0.0), + used_credits=data.get("usedCredits", 0.0), + execution_stats=data.get("executionStats"), + ) + + def to_dict(self) -> Dict[str, Any]: + return { + "input": self.input, + "output": self.output, + "session_id": self.session_id, + "intermediate_steps": self.intermediate_steps, + "runTime": self.run_time, + "usedCredits": self.used_credits, + "executionStats": self.execution_stats, + } + + def __getitem__(self, key): + return getattr(self, key, None) \ No newline at end of file diff --git a/tests/unit/agent_test.py b/tests/unit/agent_test.py index 1eca046d..a047cdc4 100644 --- a/tests/unit/agent_test.py +++ b/tests/unit/agent_test.py @@ -11,6 +11,8 @@ import warnings from aixplain.enums.function import Function from aixplain.modules.agent.agent_response import AgentResponse +from aixplain.modules.agent.agent_response_data import AgentResponseData + @@ -72,6 +74,7 @@ def test_success_query_content(): response = agent.run_async(data={"query": "Translate the text: {{input1}}"}, content={"input1": "Hello, how are you?"}) assert isinstance(response, AgentResponse) assert response["status"] == ref_response["status"] + assert isinstance(response.data, AgentResponseData) assert response["url"] == ref_response["data"] From 17a57bbcf36babd7915f3eb87b21a662be5d7455 Mon Sep 17 00:00:00 2001 From: xainaz Date: Fri, 3 Jan 2025 13:21:03 +0300 Subject: [PATCH 3/5] Added user credits and runtime to agent response --- aixplain/modules/agent/__init__.py | 17 ++++++----- aixplain/modules/agent/agent_response_data.py | 8 ----- test.py | 30 +++++++++++++++++++ 3 files changed, 39 insertions(+), 16 deletions(-) create mode 100644 test.py diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 9ae8aba6..00abe51e 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -173,16 +173,17 @@ def run( poll_url = response["url"] end = time.time() result = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) + result_data= result.data return AgentResponse( status=ResponseStatus.SUCCESS, + completed=True, data=AgentResponseData( - input=result.get("input"), - output=result.get("output"), - execution_stats=result.get("executionStats"), - run_time=result.get("runTime", end - start), - used_credits=result.get("usedCredits", 0.0), + input=result_data.get("input"), + output=result_data.get("output"), session_id=session_id, ), + used_credits=result_data.get("usedCredits", 0.0), + run_time=result_data.get("runTime", end - start), ) except Exception as e: msg = f"Error in request for {name} - {traceback.format_exc()}" @@ -287,10 +288,10 @@ def run_async( status=ResponseStatus.IN_PROGRESS, url=poll_url, data=AgentResponseData( - input=input_data, - run_time=0.0, - used_credits=0.0, + input=input_data ), + run_time=0.0, + used_credits=0.0, ) except Exception as e: msg = f"Error in request for {name} - {traceback.format_exc()}" diff --git a/aixplain/modules/agent/agent_response_data.py b/aixplain/modules/agent/agent_response_data.py index 2dba7427..6baa21a0 100644 --- a/aixplain/modules/agent/agent_response_data.py +++ b/aixplain/modules/agent/agent_response_data.py @@ -7,16 +7,12 @@ def __init__( output: Optional[Any] = None, session_id: str = "", intermediate_steps: Optional[List[Any]] = None, - run_time: float = 0.0, - used_credits: float = 0.0, execution_stats: Optional[Dict[str, Any]] = None, ): self.input = input self.output = output self.session_id = session_id self.intermediate_steps = intermediate_steps or [] - self.run_time = run_time - self.used_credits = used_credits self.execution_stats = execution_stats @classmethod @@ -26,8 +22,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "AgentResponseData": output=data.get("output"), session_id=data.get("session_id", ""), intermediate_steps=data.get("intermediate_steps", []), - run_time=data.get("runTime", 0.0), - used_credits=data.get("usedCredits", 0.0), execution_stats=data.get("executionStats"), ) @@ -37,8 +31,6 @@ def to_dict(self) -> Dict[str, Any]: "output": self.output, "session_id": self.session_id, "intermediate_steps": self.intermediate_steps, - "runTime": self.run_time, - "usedCredits": self.used_credits, "executionStats": self.execution_stats, } diff --git a/test.py b/test.py new file mode 100644 index 00000000..a4049f2b --- /dev/null +++ b/test.py @@ -0,0 +1,30 @@ +import os +os.environ["TEAM_API_KEY"] = "f8dcf228a8a0d2b85a800eabe8f73b9af89f571c668b7524ffe82fca83a95096" + + +Name = "agent test" +Task = "Answer the questions" + + +Tool = "640b517694bf816d35a59125" + +from aixplain.factories import AgentFactory +from aixplain.modules.agent import ModelTool + +agent = AgentFactory.create( + name=Name, + description=Task, + tools=[ + ModelTool(model=Tool), + ], + llm_id="66b2708c6eb5635d1c71f611" +) +print("agent defined") + + + + +Query = "Hello" + +agent_response = agent.run(Query) +print(vars(agent_response)) \ No newline at end of file From 62f86b83a6c01cb488ae312f281f20c148dc50de Mon Sep 17 00:00:00 2001 From: xainaz Date: Fri, 3 Jan 2025 13:23:12 +0300 Subject: [PATCH 4/5] remove test --- test.py | 30 ------------------------------ 1 file changed, 30 deletions(-) delete mode 100644 test.py diff --git a/test.py b/test.py deleted file mode 100644 index a4049f2b..00000000 --- a/test.py +++ /dev/null @@ -1,30 +0,0 @@ -import os -os.environ["TEAM_API_KEY"] = "f8dcf228a8a0d2b85a800eabe8f73b9af89f571c668b7524ffe82fca83a95096" - - -Name = "agent test" -Task = "Answer the questions" - - -Tool = "640b517694bf816d35a59125" - -from aixplain.factories import AgentFactory -from aixplain.modules.agent import ModelTool - -agent = AgentFactory.create( - name=Name, - description=Task, - tools=[ - ModelTool(model=Tool), - ], - llm_id="66b2708c6eb5635d1c71f611" -) -print("agent defined") - - - - -Query = "Hello" - -agent_response = agent.run(Query) -print(vars(agent_response)) \ No newline at end of file From 0a7ff714e2eb1ad8c66e0ef962239faecf794516 Mon Sep 17 00:00:00 2001 From: xainaz Date: Fri, 3 Jan 2025 18:04:40 +0300 Subject: [PATCH 5/5] added repr, setitem, fixed format, added rest of attributes --- aixplain/modules/agent/__init__.py | 33 +++++++++---------- aixplain/modules/agent/agent_response_data.py | 21 ++++++++++-- 2 files changed, 34 insertions(+), 20 deletions(-) diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 00abe51e..330ff52d 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -34,8 +34,6 @@ from aixplain.modules.model import Model from aixplain.modules.agent.output_format import OutputFormat from aixplain.modules.agent.tool import Tool -from aixplain.modules.agent.tool.model_tool import ModelTool -from aixplain.modules.agent.tool.pipeline_tool import PipelineTool from aixplain.modules.agent.agent_response import AgentResponse from aixplain.modules.agent.agent_response_data import AgentResponseData from aixplain.enums import ResponseStatus @@ -173,17 +171,19 @@ def run( poll_url = response["url"] end = time.time() result = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) - result_data= result.data + result_data = result.data return AgentResponse( status=ResponseStatus.SUCCESS, completed=True, data=AgentResponseData( input=result_data.get("input"), output=result_data.get("output"), - session_id=session_id, + session_id=result_data.get("session_id"), + intermediate_steps=result_data.get("intermediateSteps"), + execution_stats=result_data.get("executionStats"), ), used_credits=result_data.get("usedCredits", 0.0), - run_time=result_data.get("runTime", end - start), + run_time=result_data.get("runTime", end - start), ) except Exception as e: msg = f"Error in request for {name} - {traceback.format_exc()}" @@ -194,8 +194,10 @@ def run( data=AgentResponseData( input=data, output=None, - run_time=end - start, + session_id=result_data.get("session_id"), session_id=session_id, + intermediate_steps=result_data.get("intermediateSteps"), + execution_stats=result_data.get("executionStats"), ), error=msg, ) @@ -287,9 +289,7 @@ def run_async( return AgentResponse( status=ResponseStatus.IN_PROGRESS, url=poll_url, - data=AgentResponseData( - input=input_data - ), + data=AgentResponseData(input=input_data), run_time=0.0, used_credits=0.0, ) @@ -301,8 +301,6 @@ def run_async( error=msg, ) - - def to_dict(self) -> Dict: return { "id": self.id, @@ -333,19 +331,19 @@ def delete(self) -> None: message = f"Agent Deletion Error (HTTP {r.status_code}): There was an error in deleting the agent." logging.error(message) raise Exception(f"{message}") - + def update(self) -> None: """Update agent.""" import warnings import inspect + # Get the current call stack stack = inspect.stack() - if len(stack) > 2 and stack[1].function != 'save': + if len(stack) > 2 and stack[1].function != "save": warnings.warn( - "update() is deprecated and will be removed in a future version. " - "Please use save() instead.", + "update() is deprecated and will be removed in a future version. " "Please use save() instead.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) from aixplain.factories.agent_factory.utils import build_agent @@ -369,10 +367,9 @@ def update(self) -> None: error_msg = f"Agent Update Error (HTTP {r.status_code}): {resp}" raise Exception(error_msg) - def save(self) -> None: """Save the Agent.""" - self.update() + self.update() def deploy(self) -> None: assert self.status == AssetStatus.DRAFT, "Agent must be in draft status to be deployed." diff --git a/aixplain/modules/agent/agent_response_data.py b/aixplain/modules/agent/agent_response_data.py index 6baa21a0..6a08ccfb 100644 --- a/aixplain/modules/agent/agent_response_data.py +++ b/aixplain/modules/agent/agent_response_data.py @@ -1,5 +1,6 @@ from typing import List, Dict, Any, Optional + class AgentResponseData: def __init__( self, @@ -33,6 +34,22 @@ def to_dict(self) -> Dict[str, Any]: "intermediate_steps": self.intermediate_steps, "executionStats": self.execution_stats, } - + def __getitem__(self, key): - return getattr(self, key, None) \ No newline at end of file + return getattr(self, key, None) + + def __setitem__(self, key, value): + if hasattr(self, key): + setattr(self, key, value) + else: + raise KeyError(f"{key} is not a valid attribute of {self.__class__.__name__}") + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"input={self.input}, " + f"output={self.output}, " + f"session_id='{self.session_id}', " + f"intermediate_steps={self.intermediate_steps}, " + f"execution_stats={self.execution_stats})" + )