Skip to content

Commit

Permalink
Merge pull request #2 from oqc-community/feature/ker/basic_zmq_rpc
Browse files Browse the repository at this point in the history
Feature/ker/basic zmq rpc
  • Loading branch information
keriksson-rosenqvist authored Jul 3, 2024
2 parents 3fc3261 + e2c80e3 commit 47ee731
Show file tree
Hide file tree
Showing 10 changed files with 1,884 additions and 1 deletion.
1,625 changes: 1,624 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@ description = "RPC tooling for OQC QAT."
authors = ["Kajsa Eriksson Rosenqvist <[email protected]>"]
readme = "README.md"
license = "BSD-3-Clause"
packages = [
{ include = "qat_rpc", from = "src/QAT_RPC/" }
]

[tool.poetry.dependencies]
python = ">=3.8.1,<3.11"
qat-compiler = "^1.1.0"
pyzmq = "^25.1.0"

[tool.poetry.group.dev.dependencies]
coverage = "^6.3.2"
Expand All @@ -19,6 +24,9 @@ optional = true
[tool.poetry.group.licenses.dependencies]
pip-licenses = "^3.5.3"

[tool.poetry.scripts]
qat_comexe="qat_rpc.zmq.qat_commands:qat_run"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
Empty file added src/QAT_RPC/qat_rpc/__init__.py
Empty file.
Empty file.
22 changes: 22 additions & 0 deletions src/QAT_RPC/qat_rpc/zmq/qat_commands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import argparse
from pathlib import Path

from qat_rpc.zmq.wrappers import ZMQClient


parser = argparse.ArgumentParser(prog="QAT submission service", description="Submit your QASM or QIR program to QAT.")
parser.add_argument("program", type=str, help="Program string or path to program file.")
parser.add_argument("--config", type=str, help="Serialised CompilerConfig json")


def qat_run():
args = parser.parse_args()
program = args.program
config = args.config
if Path(program).is_file():
program = Path(program).read_text()
if config is not None and Path(config).is_file():
config = Path(config).read_text()
zmq_client = ZMQClient()
results = zmq_client.execute_task(program, config)
print(results)
41 changes: 41 additions & 0 deletions src/QAT_RPC/qat_rpc/zmq/receiver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os
from pathlib import Path
from signal import SIGINT, SIGTERM, signal

from qat.purr.compiler.devices import Calibratable
from qat.purr.utils.logger import get_default_logger

from qat_rpc.zmq.wrappers import ZMQServer

log = get_default_logger()


class GracefulKill:
def __init__(self, receiver: ZMQServer):
signal(SIGINT, self._sigint)
signal(SIGTERM, self._sigterm)
self.receiver = receiver

def _sigint(self, *args):
self.receiver.stop()

def _sigterm(self, *args):
self.receiver.stop()


if __name__ == "__main__":
hw = None
if (calibration_file := os.getenv("TOSHIKO_CAL")) is not None:
calibration_file = Path(calibration_file)
if not calibration_file.is_absolute() and not calibration_file.is_file():
calibration_file = Path(Path(__file__).parent, calibration_file)
if not calibration_file.is_file():
raise ValueError(f"No such file: {calibration_file}")
log.info(f"Loading: {calibration_file} ")
hw = Calibratable.load_calibration_from_file(str(calibration_file))
log.debug("Loaded")
receiver = ZMQServer(hardware=hw)
gk = GracefulKill(receiver)

log.info(f"Starting receiver with {type(receiver._hardware)} hardware.")
receiver.run()
107 changes: 107 additions & 0 deletions src/QAT_RPC/qat_rpc/zmq/wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from time import time
from typing import Union

import zmq

from qat.purr.backends.echo import get_default_echo_hardware
from qat.purr.compiler.hardware_models import QuantumHardwareModel
from qat.purr.compiler.config import CompilerConfig
from qat.purr.compiler.runtime import get_runtime
from qat.qat import execute_with_metrics


class ZMQBase:
def __init__(self, socket_type: zmq.SocketType):
self._context = zmq.Context()
self._socket = self._context.socket(socket_type)
self._timeout = 30.0
self._protocol = "tcp"
self._ip_address = "127.0.0.1"
self._port = "5556"

@property
def address(self):
return f"{self._protocol}://{self._ip_address}:{self._port}"

def _check_recieved(self):
try:
msg = self._socket.recv_pyobj(zmq.NOBLOCK)
return msg
except zmq.ZMQError:
return None

def _send(self, message) -> None:
sent = False
t0 = time()
while not sent:
try:
self._socket.send_pyobj(message, zmq.NOBLOCK)
sent = True
except zmq.ZMQError as e:
if time() > t0 + self._timeout:
raise TimeoutError(
"Sending %s on %s timedout" % (message, self.address)
)
return

def close(self):
"""Disconnect the link to the socket."""
if self._socket.closed:
return
self._socket.close()
self._context.destroy()

def __del__(self):
self.close()


class ZMQServer(ZMQBase):
def __init__(self, hardware: QuantumHardwareModel=None):
super().__init__(zmq.REP)
self._socket.bind(self.address)
self._hardware = hardware or get_default_echo_hardware(qubit_count=32)
self._engine = get_runtime(self._hardware).engine
self._running = False

@property
def address(self):
return f"{self._protocol}://*:{self._port}"

def run(self):
self._running = True
while self._running and not self._socket.closed:
msg = self._check_recieved()
if msg is not None:
try:
program = msg[0]
config = CompilerConfig.create_from_json(msg[1])
result, metrics = execute_with_metrics(program, self._engine, config)
reply = {"results": result, "execution_metrics": metrics}
except Exception as e:
reply = {"Exception": repr(e)}
self._send(reply)

def stop(self):
self._running = False


class ZMQClient(ZMQBase):
def __init__(self):
super().__init__(zmq.REQ)
self._socket.setsockopt(zmq.LINGER, 0)
self._socket.connect(self.address)

def execute_task(self, program: str, config: Union[CompilerConfig, str]=None):
self.result = None
if isinstance(config, str):
# Verify config string is valid before submitting.
config = CompilerConfig.create_from_json(config)
cfg = config or CompilerConfig()
self._send((program, cfg.to_json()))
return self._await_results()

def _await_results(self):
result = None
while result is None:
result = self._check_recieved()
return result
Empty file added src/tests/__init__.py
Empty file.
Empty file added src/tests/zmq/__init__.py
Empty file.
82 changes: 82 additions & 0 deletions src/tests/zmq/test_wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import pytest
import threading

from qat.purr.compiler.config import CompilerConfig

from qat_rpc.zmq.wrappers import ZMQClient, ZMQServer


@pytest.fixture(scope="module", autouse=True)
def server():
server = ZMQServer()
server_thread = threading.Thread(target=server.run, daemon=True)
server_thread.start()


program = """
OPENQASM 2.0;
include "qelib1.inc";
qreg q[2];
h q;
creg c[2];
measure q->c;
"""


def test_zmq_flow():
client = ZMQClient()

config = CompilerConfig()
config.results_format.binary_count()
config.repeats = 100

response = client.execute_task(program, config)
assert response["results"]["c"]["00"] == 100


def test_zmq_exception():
client = ZMQClient()

config = CompilerConfig()
config.results_format.binary_count()
config.repeats = 100

response = client.execute_task([4, 5, 6], config)
assert response["Exception"] == "TypeError('expected string or buffer')"


def execute_and_check_result(client, program, config, result):
response = client.execute_task(program, config)
assert response["results"] == result


@pytest.mark.filterwarnings("error::pytest.PytestUnhandledThreadExceptionWarning")
def test_two_zmq_clients():
"""Verify the results are returned to the correct client."""
client0 = ZMQClient()
client1 = ZMQClient()

config0 = CompilerConfig()
config0.results_format.binary_count()
config0.repeats = 100
thread00 = threading.Thread(
target=execute_and_check_result,
args=(client0, program, config0, {"c": { "00": 100}}),
)
thread01 = threading.Thread(
target=execute_and_check_result,
args=(client0, program, config0, {"c": { "00": 100}}),
)
config1 = CompilerConfig()
config1.results_format.binary_count()
config1.repeats = 1000
thread10 = threading.Thread(
target=execute_and_check_result,
args=(client1, program, config1, {"c": { "00": 1000}}),
)
thread00.start()
thread01.start()
thread10.start()
thread00.join()
thread01.join()
thread10.join()

0 comments on commit 47ee731

Please sign in to comment.