Skip to content

Commit

Permalink
refact: Use the Python SDK in the CLI implementation (#365)
Browse files Browse the repository at this point in the history
* migrate deploy cmd

* change testing technique

* migrate run

* migrate status

* removed unused module

* make the test case more robust

* simplify imports

* restore code coverage
  • Loading branch information
masci authored Nov 14, 2024
1 parent 4449d9f commit 4d6dcf7
Show file tree
Hide file tree
Showing 9 changed files with 148 additions and 122 deletions.
18 changes: 8 additions & 10 deletions llama_deploy/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import click

from llama_deploy.cli.deploy import deploy
from llama_deploy.cli.run import run
from llama_deploy.cli.status import status
from .deploy import deploy as deploy_cmd
from .run import run as run_cmd
from .status import status as status_cmd


@click.group(
Expand All @@ -21,19 +21,17 @@
@click.option(
"-t",
"--timeout",
default=None,
default=120.0,
type=float,
help="Timeout on apiserver HTTP requests",
)
@click.pass_context
def llamactl(
ctx: click.Context, server: str, insecure: bool, timeout: float | None
) -> None:
def llamactl(ctx: click.Context, server: str, insecure: bool, timeout: float) -> None:
ctx.obj = server, insecure, timeout
if ctx.invoked_subcommand is None:
click.echo(ctx.get_help()) # show the help if no subcommand was provided


llamactl.add_command(deploy)
llamactl.add_command(run)
llamactl.add_command(status)
llamactl.add_command(deploy_cmd)
llamactl.add_command(run_cmd)
llamactl.add_command(status_cmd)
21 changes: 7 additions & 14 deletions llama_deploy/cli/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,19 @@

import click

from llama_deploy.cli.utils import request
from llama_deploy import Client


@click.command()
@click.pass_obj # global_config
@click.argument("deployment_config_file", type=click.File("rb"))
def deploy(global_config: tuple, deployment_config_file: IO) -> None:
server_url, disable_ssl, timeout = global_config
deploy_url = f"{server_url}/deployments/create"
client = Client(api_server_url=server_url, disable_ssl=disable_ssl, timeout=timeout)

files = {"config_file": deployment_config_file.read()}
resp = request(
"POST", deploy_url, files=files, verify=not disable_ssl, timeout=timeout
)
try:
deployment = client.sync.apiserver.deployments.create(deployment_config_file)
except Exception as e:
raise click.ClickException(str(e))

if resp.status_code >= 400:
try:
raise click.ClickException(resp.json().get("detail", resp.text))
except ValueError:
raise click.ClickException(resp.text)

else:
click.echo(f"Deployment successful: {resp.json().get('name')}")
click.echo(f"Deployment successful: {deployment.id}")
17 changes: 10 additions & 7 deletions llama_deploy/cli/run.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json

import click
import httpx

from llama_deploy import Client
from llama_deploy.types import TaskDefinition


@click.command()
Expand All @@ -27,14 +29,15 @@ def run(
service: str,
) -> None:
server_url, disable_ssl, timeout = global_config
deploy_url = f"{server_url}/deployments/{deployment}/tasks/run"
client = Client(api_server_url=server_url, disable_ssl=disable_ssl, timeout=timeout)

payload = {"input": json.dumps(dict(arg))}
if service:
payload["agent_id"] = service

resp = httpx.post(deploy_url, verify=not disable_ssl, json=payload, timeout=timeout)
try:
result = client.sync.apiserver.deployments.tasks.run(TaskDefinition(**payload))
except Exception as e:
raise click.ClickException(str(e))

if resp.status_code >= 400:
raise click.ClickException(resp.json().get("detail"))
else:
click.echo(resp.json())
click.echo(result)
33 changes: 16 additions & 17 deletions llama_deploy/cli/status.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,27 @@
import click


from llama_deploy.cli.utils import request
from llama_deploy import Client
from llama_deploy.types.apiserver import StatusEnum


@click.command()
@click.pass_obj # global_config
def status(global_config: tuple) -> None:
server_url, disable_ssl, timeout = global_config
status_url = f"{server_url}/status/"
client = Client(api_server_url=server_url, disable_ssl=disable_ssl, timeout=timeout)

r = request("GET", status_url, verify=not disable_ssl, timeout=timeout)
if r.status_code >= 400:
body = r.json()
click.echo(
f"Llama Deploy is unhealthy: [{r.status_code}] {r.json().get('detail')}"
)
return
try:
status = client.sync.apiserver.status()
except Exception as e:
raise click.ClickException(str(e))

click.echo("Llama Deploy is up and running.")
body = r.json()
if deployments := body.get("deployments"):
click.echo("\nActive deployments:")
for d in deployments:
click.echo(f"- {d}")
if status.status == StatusEnum.HEALTHY:
click.echo("Llama Deploy is up and running.")
if status.deployments:
click.echo("\nActive deployments:")
for d in status.deployments:
click.echo(f"- {d}")
else:
click.echo("\nCurrently there are no active deployments")
else:
click.echo("\nCurrently there are no active deployments")
click.echo(f"Llama Deploy is unhealthy: {status.status_message}")
18 changes: 0 additions & 18 deletions llama_deploy/cli/utils.py

This file was deleted.

9 changes: 8 additions & 1 deletion llama_deploy/client/models/apiserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,14 @@ class DeploymentCollection(Collection):
"""A model representing a collection of deployments currently active."""

async def create(self, config: TextIO) -> Deployment:
"""Creates a new deployment from a deployment file."""
"""Creates a new deployment from a deployment file.
Example:
```
with open("deployment.yml") as f:
await client.apiserver.deployments.create(f)
```
"""
create_url = f"{self.client.api_server_url}/deployments/create"

files = {"config_file": config.read()}
Expand Down
38 changes: 22 additions & 16 deletions tests/cli/test_deploy.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,42 @@
from pathlib import Path
from unittest import mock

import httpx
from click.testing import CliRunner

from llama_deploy.cli import llamactl


def test_deploy(runner: CliRunner, data_path: Path) -> None:
test_config_file = data_path / "deployment.yaml"
mocked_response = mock.MagicMock(status_code=200, json=lambda: {})
with mock.patch("llama_deploy.cli.utils.httpx.request") as mocked_httpx:
mocked_httpx.return_value = mocked_response
mocked_result = mock.MagicMock(id="test_deployment")
with mock.patch("llama_deploy.cli.deploy.Client") as mocked_client:
mocked_client.return_value.sync.apiserver.deployments.create.return_value = (
mocked_result
)

result = runner.invoke(llamactl, ["-t", "5.0", "deploy", str(test_config_file)])

assert result.exit_code == 0
with open(test_config_file, "rb") as f:
mocked_httpx.assert_called_with(
"POST",
"http://localhost:4501/deployments/create",
files={"config_file": f.read()},
verify=True,
timeout=5,
)
assert result.output == "Deployment successful: test_deployment\n"
mocked_client.assert_called_with(
api_server_url="http://localhost:4501", disable_ssl=False, timeout=5.0
)
file_arg = (
mocked_client.return_value.sync.apiserver.deployments.create.call_args
)
assert str(test_config_file) == file_arg.args[0].name


def test_deploy_failed(runner: CliRunner, data_path: Path) -> None:
test_config_file = data_path / "deployment.yaml"
mocked_response = mock.MagicMock(
status_code=401, json=lambda: {"detail": "Unauthorized!"}
)
with mock.patch("llama_deploy.cli.utils.httpx.request") as mocked_httpx:
mocked_httpx.return_value = mocked_response
with mock.patch("llama_deploy.cli.deploy.Client") as mocked_client:
mocked_client.return_value.sync.apiserver.deployments.create.side_effect = (
httpx.HTTPStatusError(
"Unauthorized!", response=mock.MagicMock(), request=mock.MagicMock()
)
)

result = runner.invoke(llamactl, ["deploy", str(test_config_file)])
assert result.exit_code == 1
assert result.output == "Error: Unauthorized!\n"
66 changes: 42 additions & 24 deletions tests/cli/test_run.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,58 @@
from unittest import mock

import httpx
from click.testing import CliRunner

from llama_deploy.cli import llamactl
from llama_deploy.types import TaskDefinition


def test_run(runner: CliRunner) -> None:
mocked_response = mock.MagicMock(status_code=200, json=lambda: {})
with mock.patch("httpx.post") as mocked_post:
mocked_post.return_value = mocked_response
mocked_result = mock.MagicMock(id="test_deployment")
with mock.patch("llama_deploy.cli.run.Client") as mocked_client:
mocked_client.return_value.sync.apiserver.deployments.tasks.run.return_value = (
mocked_result
)

result = runner.invoke(
llamactl, ["run", "-d", "deployment_name", "-s", "service_name"]
llamactl,
["run", "-d", "deployment_name", "-s", "service_name"],
)
mocked_post.assert_called_with(
"http://localhost:4501/deployments/deployment_name/tasks/run",
verify=True,
json={"input": "{}", "agent_id": "service_name"},
timeout=None,

mocked_client.assert_called_with(
api_server_url="http://localhost:4501", disable_ssl=False, timeout=120.0
)

args = mocked_client.return_value.sync.apiserver.deployments.tasks.run.call_args
actual = args[0][0]
expected = TaskDefinition(agent_id="service_name", input="{}")
assert expected.input == actual.input
assert expected.agent_id == actual.agent_id
assert actual.session_id is None
assert result.exit_code == 0


def test_run_error(runner: CliRunner) -> None:
mocked_response = mock.MagicMock(
status_code=500, json=lambda: {"detail": "test error"}
)
with mock.patch("httpx.post") as mocked_post:
mocked_post.return_value = mocked_response
with mock.patch("llama_deploy.cli.run.Client") as mocked_client:
mocked_client.return_value.sync.apiserver.deployments.tasks.run.side_effect = (
httpx.HTTPStatusError(
"test error", response=mock.MagicMock(), request=mock.MagicMock()
)
)

result = runner.invoke(llamactl, ["run", "-d", "deployment_name"])

assert result.exit_code == 1
assert result.output == "Error: test error\n"


def test_run_args(runner: CliRunner) -> None:
mocked_response = mock.MagicMock(status_code=200, json=lambda: {})
with mock.patch("httpx.post") as mocked_post:
mocked_post.return_value = mocked_response
mocked_result = mock.MagicMock(id="test_deployment")
with mock.patch("llama_deploy.cli.run.Client") as mocked_client:
mocked_client.return_value.sync.apiserver.deployments.tasks.run.return_value = (
mocked_result
)

result = runner.invoke(
llamactl,
[
Expand All @@ -50,12 +67,13 @@ def test_run_args(runner: CliRunner) -> None:
'"second value with spaces"',
],
)
mocked_post.assert_called_with(
"http://localhost:4501/deployments/deployment_name/tasks/run",
verify=True,
json={
"input": '{"first_arg": "first_value", "second_arg": "\\"second value with spaces\\""}',
},
timeout=None,

args = mocked_client.return_value.sync.apiserver.deployments.tasks.run.call_args
actual = args[0][0]
expected = TaskDefinition(
input='{"first_arg": "first_value", "second_arg": "\\"second value with spaces\\""}',
)
assert expected.input == actual.input
assert expected.agent_id == actual.agent_id
assert actual.session_id is None
assert result.exit_code == 0
Loading

0 comments on commit 4d6dcf7

Please sign in to comment.