Skip to content

Commit

Permalink
Add QPU information endpoints to ZMQ server
Browse files Browse the repository at this point in the history
  • Loading branch information
jfriel-oqc committed Jul 21, 2024
1 parent 2d43f85 commit 75d64c9
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 22 deletions.
89 changes: 77 additions & 12 deletions src/QAT_RPC/qat_rpc/zmq/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
from enum import Enum
from time import time
from typing import Union
from importlib.metadata import version

import zmq
from qat.purr.backends.echo import get_default_echo_hardware
from qat.purr.compiler.config import CompilerConfig
from qat.purr.compiler.hardware_models import QuantumHardwareModel
from qat.purr.compiler.runtime import get_runtime
from qat.purr.integrations.features import OpenPulseFeatures
from qat.qat import execute_with_metrics

class Messages(Enum):
PROGRAM = "program"
VERSION = "version"
COUPLINGS = "couplings"
QUBIT_INFO = "qubit_info"
QPU_INFO = "qpu_info"


class ZMQBase:
def __init__(self, socket_type: zmq.SocketType):
Expand Down Expand Up @@ -65,17 +75,56 @@ def __init__(self, hardware: QuantumHardwareModel = None):
@property
def address(self):
return f"{self._protocol}://*:{self._port}"

def _program(self, program, config):
program = program
config = CompilerConfig.create_from_json(config)
result, metrics = execute_with_metrics(program, self._engine, config)
return {"results": result, "execution_metrics": metrics}

def _version(self):
return {"qat_rpc_version": str(version('qat_rpc'))}

def _couplings(self):
coupling = [coupled.direction for coupled in self._hardware.qubit_direction_couplings]
return {"couplings": coupling}

def _qubit_info(self):
raise NotImplementedError("Individual qubit information not implented, pending hardware model changes.")

def _qpu_info(self):
features = OpenPulseFeatures()
features.for_hardware(self._hardware)
qpu_info = features.to_json_dict()
return {"qpu_info": qpu_info}

def _interpret_message(self, message):
match message[0]:
case Messages.PROGRAM.value:
print(message)
if len(message) != 3:
raise ValueError(f"Program message should be of length 3, not {len(message)}")
return self._program(message[1], message[2])
case Messages.VERSION.value:
return self._version()
case Messages.COUPLINGS.value:
return self._couplings()
case Messages.QUBIT_INFO:
return self._qubit_info()
case Messages.QPU_INFO.value:
return self._qpu_info()

case _:
return self._program(message[0], message[1])


def run(self):
self._running = True
while self._running and not self._socket.closed:
msg = self._check_recieved()
if msg is not None:
try:
program = msg[0]
config = CompilerConfig.create_from_json(msg[1])
result, metrics = execute_with_metrics(program, self._engine, config)
reply = {"results": result, "execution_metrics": metrics}
reply = self._interpret_message(message=msg)
except Exception as e:
reply = {"Exception": repr(e)}
self._send(reply)
Expand All @@ -89,18 +138,34 @@ def __init__(self):
super().__init__(zmq.REQ)
self._socket.setsockopt(zmq.LINGER, 0)
self._socket.connect(self.address)

def _send(self, message):
super()._send(message=message)
return self._await_results()

def _await_results(self):
result = None
while result is None:
result = self._check_recieved()
return result

def execute_task(self, program: str, config: Union[CompilerConfig, str] = None):
self.result = None
if isinstance(config, str):
# Verify config string is valid before submitting.
config = CompilerConfig.create_from_json(config)
cfg = config or CompilerConfig()
self._send((program, cfg.to_json()))
return self._await_results()

def _await_results(self):
result = None
while result is None:
result = self._check_recieved()
return result
return self._send((Messages.PROGRAM.value, program, cfg.to_json()))

def api_version(self):
return self._send((Messages.VERSION.value,))

def qpu_couplings(self):
return self._send((Messages.COUPLINGS.value,))

def qubit_info(self):
return self._send((Messages.QUBIT_INFO.value,))

def qpu_info(self):
return self._send((Messages.QPU_INFO.value,))

65 changes: 55 additions & 10 deletions src/tests/zmq/test_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
import threading

from importlib.metadata import version

import pytest
from qat.purr.compiler.config import CompilerConfig
from qat.purr.backends.echo import get_default_echo_hardware, add_direction_couplings_to_hardware


from qat_rpc.zmq.wrappers import ZMQClient, ZMQServer, Messages

from qat_rpc.zmq.wrappers import ZMQClient, ZMQServer
qubit_count = 8
qpu_coupligns = [(i, j) for i in range(qubit_count) for j in range(qubit_count)]


@pytest.fixture(scope="module", autouse=True)
def server():
server = ZMQServer()
hardware = get_default_echo_hardware(qubit_count=qubit_count)
hardware = add_direction_couplings_to_hardware(
hardware, qpu_coupligns
)
# server = ZMQServer()
server = ZMQServer(hardware=hardware)
server_thread = threading.Thread(target=server.run, daemon=True)
server_thread.start()

Expand All @@ -23,15 +35,7 @@ def server():
"""


def test_zmq_flow():
client = ZMQClient()

config = CompilerConfig()
config.results_format.binary_count()
config.repeats = 100

response = client.execute_task(program, config)
assert response["results"]["c"]["00"] == 100


def test_zmq_exception():
Expand Down Expand Up @@ -80,3 +84,44 @@ def test_two_zmq_clients():
thread00.join()
thread01.join()
thread10.join()

def test_program():
client = ZMQClient()

config = CompilerConfig()
config.results_format.binary_count()
config.repeats = 100

response = client.execute_task(program, config)
assert response["results"]["c"]["00"] == 100

def test_program_backwards_compatible():
client = ZMQClient()

config = CompilerConfig()
config.results_format.binary_count()
config.repeats = 100

response = client._send((program, config.to_json()))
print(response)
assert response["results"]["c"]["00"] == 100

def test_api_version():
client = ZMQClient()
api_version = client.api_version()
assert api_version["qat_rpc_version"] == version("qat_rpc")

def test_couplings():
client = ZMQClient()
couplings = client.qpu_couplings()
assert couplings["couplings"] == qpu_coupligns

def test_qubit_info():
client = ZMQClient()
qubit_info = client.qubit_info()
assert qubit_info["Exception"] is not None

def test_qpu_info():
client = ZMQClient()
qpu_info = client.qpu_info()
assert qpu_info["qpu_info"] is not None

0 comments on commit 75d64c9

Please sign in to comment.