Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cloud integration features to qadence #592

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ dependencies = [
"pytest-cov",
"pytest-mypy",
"pytest-xdist",
"pytest-mock",
"types-PyYAML",
"ipykernel",
"pre-commit",
Expand Down
109 changes: 109 additions & 0 deletions qadence/pasqal_cloud_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from __future__ import annotations

import time
from dataclasses import dataclass
from enum import Enum

from pasqal_cloud import SDK
from pasqal_cloud import Workload as WorkloadResult

from qadence import BackendName, QuantumModel


class ResultType(Enum):
RUN = "run"
SAMPLE = "sample"
EXPECTATION = "expectation"


@dataclass(frozen=True)
class WorkloadSpec:
model: QuantumModel
result_types: list[ResultType]


class WorkloadType(Enum):
# TODO Add other workload types supported by Qadence (Pulser, Emulator)
QADENCE_CIRCUIT = "qadence_circuit"


@dataclass(frozen=True)
class WorkloadSpecJSON:
workload_type: WorkloadType
backend_type: BackendName
config: str


def workload_spec_to_json(workload: WorkloadSpec) -> WorkloadSpecJSON:
# TODO Implement this function correctly
return WorkloadSpecJSON(WorkloadType.QADENCE_CIRCUIT, BackendName.PYQTORCH, "hello world!")


def upload_workload(connection: SDK, workload: WorkloadSpec) -> str:
"""Uploads a workload to Pasqal's Cloud and returns the created workload ID."""
workload_json = workload_spec_to_json(workload)
remote_workload = connection.create_workload(
workload_json.workload_type, workload_json.backend_type, workload_json.config
)
workload_id: str = remote_workload.id
return workload_id


class WorkloadNotDoneError(Exception):
pimvenderbosch marked this conversation as resolved.
Show resolved Hide resolved
"""Is raised if a workload is not yet finished running on remote."""

pass


class WorkloadStoppedError(Exception):
"""Is raised when a workload has stopped running on remote for some reason."""

pass


def check_status(connection: SDK, workload_id: str) -> WorkloadResult:
"""Checks if the workload is succesfully finished on remote connection.
pimvenderbosch marked this conversation as resolved.
Show resolved Hide resolved

Returns the `WorkloadResult`
Raises `WorkloadNotDoneError` when the workload status is "PENDING", "RUNNING"
or "PAUSED".
Raises `WorkloadStoppedError` when the workload status is "CANCELED", "TIMED_OUT"
or "ERROR".
"""
# TODO Make the function return a "nice" result object
result = connection.get_workload(workload_id)
match result.status:
case "PENDING" | "RUNNING" | "PAUSED":
raise WorkloadNotDoneError(
f"Workload with id {workload_id} is not yet finished, the status is {result.status}"
)
case "DONE":
return result
case "CANCELED" | "TIMED_OUT" | "ERROR":
raise WorkloadStoppedError(
f"Workload with id {workload_id} couldn't finish, the status is {result.status}"
)
case _:
raise ValueError(
f"Undefined workload status ({result.status}) was returned for "
+ f"workload ({result.id})"
pimvenderbosch marked this conversation as resolved.
Show resolved Hide resolved
)


def get_result(
connection: SDK, workload_id: str, timeout: float = 60.0, refresh_time: float = 1.0
) -> WorkloadResult:
"""Repeatedly checks if a workload has finished and returns the result.

Raises `WorkloadStoppedError` when the workload has stopped running on remote
Raises `TimeoutError` when the workload is not finished after `timeout` seconds
"""
max_refresh_count = int(timeout // refresh_time)
for _ in range(max_refresh_count):
try:
result = check_status(connection, workload_id)
except WorkloadNotDoneError:
time.sleep(refresh_time)
continue
return result
raise TimeoutError("Request timed out because it wasn't finished in the specified time. ")
58 changes: 58 additions & 0 deletions tests/test_pasqal_cloud_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from __future__ import annotations

from typing import Any

import pytest

from qadence.pasqal_cloud_connection import (
QuantumModel,
ResultType,
WorkloadNotDoneError,
WorkloadSpec,
WorkloadStoppedError,
check_status,
upload_workload,
)


def test_upload_workload(mocker: Any, BasicQuantumModel: QuantumModel) -> None:
expected_workload_id = "my-workload"
mock_connection_return = mocker.Mock()
mock_connection_return.id = expected_workload_id
mock_connection = mocker.Mock()
mock_connection.create_workload.return_value = mock_connection_return
model = BasicQuantumModel
result_types = [ResultType.RUN, ResultType.SAMPLE]
workload = WorkloadSpec(model, result_types)
result = upload_workload(mock_connection, workload)
assert result == expected_workload_id


def test_check_status_done(mocker: Any) -> None:
mock_workload_result = mocker.Mock()
mock_workload_result.status = "DONE"
mock_connection = mocker.Mock()
mock_connection.get_workload.return_value = mock_workload_result
result = check_status(mock_connection, "my-workload")
assert result is mock_workload_result


@pytest.mark.parametrize(
"status,expected_error",
[
("PENDING", WorkloadNotDoneError),
("RUNNING", WorkloadNotDoneError),
("PAUSED", WorkloadNotDoneError),
("CANCELED", WorkloadStoppedError),
("TIMED_OUT", WorkloadStoppedError),
("ERROR", WorkloadStoppedError),
("weird-status", ValueError),
],
)
def test_check_status(mocker: Any, status: str, expected_error: Exception) -> None:
mock_workload_result = mocker.Mock()
mock_workload_result.status = status
mock_connection = mocker.Mock()
mock_connection.get_workload.return_value = mock_workload_result
with pytest.raises(expected_error):
check_status(mock_connection, "my-workload")
Loading