Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENG-1272: Added Agent Response, and edited unit tests #347

Open
wants to merge 7 commits into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 48 additions & 20 deletions aixplain/modules/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
from aixplain.modules.agent.tool import Tool
from aixplain.modules.agent.tool.model_tool import ModelTool
from aixplain.modules.agent.tool.pipeline_tool import PipelineTool
from aixplain.modules.agent.agent_response import AgentResponse
from aixplain.modules.agent.agent_response_data import AgentResponseData
from aixplain.enums import ResponseStatus
from aixplain.modules.agent.utils import process_variables
from typing import Dict, List, Text, Optional, Union
from urllib.parse import urljoin
Expand Down Expand Up @@ -130,7 +133,7 @@ def run(
max_tokens: int = 2048,
max_iterations: int = 10,
output_format: OutputFormat = OutputFormat.TEXT,
) -> Dict:
) -> AgentResponse:
"""Runs an agent call.

Args:
Expand Down Expand Up @@ -163,19 +166,39 @@ def run(
max_iterations=max_iterations,
output_format=output_format,
)
if response["status"] == "FAILED":
if response["status"] == ResponseStatus.FAILED:
end = time.time()
response["elapsed_time"] = end - start
return response
poll_url = response["url"]
end = time.time()
response = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time)
return response
result = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time)
result_data= result.data
xainaz marked this conversation as resolved.
Show resolved Hide resolved
return AgentResponse(
xainaz marked this conversation as resolved.
Show resolved Hide resolved
status=ResponseStatus.SUCCESS,
completed=True,
data=AgentResponseData(
input=result_data.get("input"),
output=result_data.get("output"),
session_id=session_id,
xainaz marked this conversation as resolved.
Show resolved Hide resolved
),
used_credits=result_data.get("usedCredits", 0.0),
run_time=result_data.get("runTime", end - start),
)
except Exception as e:
msg = f"Error in request for {name} - {traceback.format_exc()}"
logging.error(f"Agent Run: Error in running for {name}: {e}")
end = time.time()
return {"status": "FAILED", "error": msg, "elapsed_time": end - start}
return AgentResponse(
status=ResponseStatus.FAILED,
data=AgentResponseData(
input=data,
output=None,
run_time=end - start,
session_id=session_id,
),
error=msg,
)

def run_async(
self,
Expand All @@ -189,7 +212,7 @@ def run_async(
max_tokens: int = 2048,
max_iterations: int = 10,
output_format: OutputFormat = OutputFormat.TEXT,
) -> Dict:
) -> AgentResponse:
"""Runs asynchronously an agent call.

Args:
Expand Down Expand Up @@ -257,23 +280,28 @@ def run_async(
payload.update(parameters)
payload = json.dumps(payload)

r = _request_with_retry("post", self.url, headers=headers, data=payload)
logging.info(f"Agent Run Async: Start service for {name} - {self.url} - {payload} - {headers}")

resp = None
try:
r = _request_with_retry("post", self.url, headers=headers, data=payload)
resp = r.json()
logging.info(f"Result of request for {name} - {r.status_code} - {resp}")

poll_url = resp["data"]
response = {"status": "IN_PROGRESS", "url": poll_url}
except Exception:
response = {"status": "FAILED"}
poll_url = resp.get("data")
return AgentResponse(
status=ResponseStatus.IN_PROGRESS,
url=poll_url,
data=AgentResponseData(
input=input_data
),
run_time=0.0,
used_credits=0.0,
)
except Exception as e:
msg = f"Error in request for {name} - {traceback.format_exc()}"
logging.error(f"Agent Run Async: Error in running for {name}: {resp}")
if resp is not None:
response["error"] = msg
return response
logging.error(f"Agent Run Async: Error in running for {name}: {e}")
return AgentResponse(
status=ResponseStatus.FAILED,
error=msg,
)



def to_dict(self) -> Dict:
return {
Expand Down
56 changes: 56 additions & 0 deletions aixplain/modules/agent/agent_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from aixplain.enums import ResponseStatus
from typing import Any, Dict, Optional, Text, Union, List
from aixplain.modules.agent.agent_response_data import AgentResponseData
from aixplain.modules.model.response import ModelResponse

class AgentResponse(ModelResponse):

def __init__(
self,
status: ResponseStatus = ResponseStatus.FAILED,
data: Optional[AgentResponseData] = None,
details: Optional[Union[Dict, List]] = {},
completed: bool = False,
error_message: Text = "",
used_credits: float = 0.0,
run_time: float = 0.0,
usage: Optional[Dict] = None,
url: Optional[Text] = None,
**kwargs,
):

super().__init__(
status=status,
data="",
details=details,
completed=completed,
error_message=error_message,
used_credits=used_credits,
run_time=run_time,
usage=usage,
url=url,
**kwargs,
)
self.data = data or AgentResponseData()

def __getitem__(self, key: Text) -> Any:
if key == "data":
return self.data.to_dict()
return super().__getitem__(key)

def __setitem__(self, key: Text, value: Any) -> None:
if key == "data" and isinstance(value, Dict):
self.data = AgentResponseData.from_dict(value)
elif key == "data" and isinstance(value, AgentResponseData):
self.data = value
else:
super().__setitem__(key, value)

def to_dict(self) -> Dict[Text, Any]:
base_dict = super().to_dict()
base_dict["data"] = self.data.to_dict()
return base_dict

def __repr__(self) -> str:
fields = super().__repr__().strip("ModelResponse(").rstrip(")")
return f"AgentResponse({fields})"
38 changes: 38 additions & 0 deletions aixplain/modules/agent/agent_response_data.py
xainaz marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import List, Dict, Any, Optional

class AgentResponseData:
def __init__(
self,
input: Optional[Any] = None,
output: Optional[Any] = None,
session_id: str = "",
intermediate_steps: Optional[List[Any]] = None,
execution_stats: Optional[Dict[str, Any]] = None,
):
self.input = input
self.output = output
self.session_id = session_id
self.intermediate_steps = intermediate_steps or []
self.execution_stats = execution_stats

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AgentResponseData":
return cls(
input=data.get("input"),
output=data.get("output"),
session_id=data.get("session_id", ""),
intermediate_steps=data.get("intermediate_steps", []),
execution_stats=data.get("executionStats"),
)

def to_dict(self) -> Dict[str, Any]:
return {
"input": self.input,
"output": self.output,
"session_id": self.session_id,
"intermediate_steps": self.intermediate_steps,
"executionStats": self.execution_stats,
}

def __getitem__(self, key):
return getattr(self, key, None)
16 changes: 16 additions & 0 deletions aixplain/modules/model/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,19 @@ def __contains__(self, key: Text) -> bool:
return True
except KeyError:
return False
def to_dict(self) -> Dict[Text, Any]:
base_dict = {
"status": self.status,
"data": self.data,
"details": self.details,
"completed": self.completed,
"error_message": self.error_message,
"used_credits": self.used_credits,
"run_time": self.run_time,
"usage": self.usage,
"url": self.url,
}
if self.additional_fields:
base_dict.update(self.additional_fields)
return base_dict

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

7 changes: 7 additions & 0 deletions tests/unit/agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from urllib.parse import urljoin
import warnings
from aixplain.enums.function import Function
from aixplain.modules.agent.agent_response import AgentResponse
from aixplain.modules.agent.agent_response_data import AgentResponseData




Expand All @@ -34,6 +37,7 @@ def test_fail_query_as_text_when_content_not_empty():
data={"query": "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav"},
content=["https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav"],
)

assert str(exc_info.value) == "When providing 'content', query must be text."


Expand Down Expand Up @@ -68,7 +72,9 @@ def test_success_query_content():
mock.post(url, headers=headers, json=ref_response)

response = agent.run_async(data={"query": "Translate the text: {{input1}}"}, content={"input1": "Hello, how are you?"})
assert isinstance(response, AgentResponse)
assert response["status"] == ref_response["status"]
assert isinstance(response.data, AgentResponseData)
assert response["url"] == ref_response["data"]


Expand Down Expand Up @@ -310,6 +316,7 @@ def test_run_success():
response = agent.run_async(
data={"query": "Hello, how are you?"}, max_iterations=10, output_format=OutputFormat.MARKDOWN
)
assert isinstance(response, AgentResponse)
assert response["status"] == "IN_PROGRESS"
assert response["url"] == ref_response["data"]

Expand Down