-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for sampling from Quokka devices
- This sampler converts circuits to QASM and then sends them to a quokka endpoint for simulation. - Any parameterized circuits are resolved and sent point by point to the device.
- Loading branch information
1 parent
616ac6d
commit 558c7a8
Showing
2 changed files
with
269 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
# Copyright 2024 The Unitary Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Simulation using a "Quokka" device.""" | ||
|
||
from typing import Any, Callable, Dict, Optional, Sequence | ||
|
||
import cirq | ||
import numpy as np | ||
import json | ||
|
||
_REQUEST_ENDPOINT = "http://{}.quokkacomputing.com/qsim/qasm" | ||
_DEFAULT_QUOKKA_NAME = "quokka1" | ||
|
||
JSON_TYPE = Dict[str, Any] | ||
_RESULT_KEY = "result" | ||
_ERROR_CODE_KEY = "error_code" | ||
_RESULT_KEY = "result" | ||
_SCRIPT_KEY = "script" | ||
_REPETITION_KEY = "count" | ||
|
||
|
||
class QuokkaSampler(cirq.Sampler): | ||
"""Sampler for querying a Quokka quantum simulation device. | ||
See https://www.quokkacomputing.com/ for more information.a | ||
Args: | ||
name: name of your quokka device | ||
endpoint: HTTP url endpoint to post queries to. | ||
post_function: used only for testing to override default | ||
behavior to connect to internet URLs. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
name: str = _DEFAULT_QUOKKA_NAME, | ||
endpoint: Optional[str] = None, | ||
post_function: Optional[Callable[[JSON_TYPE], JSON_TYPE]] = None, | ||
): | ||
self.quokka_name = name | ||
self.endpoint = endpoint | ||
self.post_function = post_function | ||
|
||
if self.endpoint is None: | ||
self.endpoint = _REQUEST_ENDPOINT.format(self.quokka_name) | ||
if self.post_function is None: | ||
self.post_function = self._post | ||
|
||
def _post(self, json_request: JSON_TYPE) -> JSON_TYPE: | ||
"""Sends POST queries to quokka endpoint.""" | ||
try: | ||
import requests | ||
except ImportError as e: | ||
print( | ||
"Please install requests library to use Quokka" | ||
"(e.g. pip install requests)" | ||
) | ||
raise e | ||
result = requests.post(self.endpoint, json=json_request) | ||
return json.loads(result.content) | ||
|
||
def run_sweep( | ||
self, | ||
program: "cirq.AbstractCircuit", | ||
params: "cirq.Sweepable", | ||
repetitions: int = 1, | ||
) -> Sequence["cirq.Result"]: | ||
"""Samples from the given Circuit. | ||
This allows for sweeping over different parameter values, | ||
unlike the `run` method. The `params` argument will provide a | ||
mapping from `sympy.Symbol`s used within the circuit to a set of | ||
values. Unlike the `run` method, which specifies a single | ||
mapping from symbol to value, this method allows a "sweep" of | ||
values. This allows a user to specify execution of a family of | ||
related circuits efficiently. | ||
Args: | ||
program: The circuit to sample from. | ||
params: Parameters to run with the program. | ||
repetitions: The number of times to sample. | ||
Returns: | ||
Result list for this run; one for each possible parameter resolver. | ||
""" | ||
rtn_results = [] | ||
qubits = sorted(program.all_qubits()) | ||
measure_keys = {} | ||
register_names = {} | ||
meas_i = 0 | ||
|
||
# Find all measurements in the circuit and record keys | ||
# so that we can later translate between circuit and QASM registers. | ||
for op in program.all_operations(): | ||
if isinstance(op.gate, cirq.MeasurementGate): | ||
key = cirq.measurement_key_name(op) | ||
if key in measure_keys: | ||
print( | ||
"Warning! Keys can only be measured once in Quokka simulator" | ||
) | ||
print("Key {key} will only contain the last measured value") | ||
measure_keys[key] = op.qubits | ||
if cirq.QasmOutput.valid_id_re.match(key): | ||
register_names[key] = f"m_{key}" | ||
else: | ||
register_names[key] = f"m{meas_i}" | ||
meas_i += 1 | ||
|
||
# QASM 2.0 does not support parameter sweeps, | ||
# so resolve any symbolic functions to a concrete circuit. | ||
for param_resolver in cirq.to_resolvers(params): | ||
circuit = cirq.resolve_parameters(program, param_resolver) | ||
qasm = cirq.qasm(circuit) | ||
|
||
# Hack to change sqrt-X gates into rx 0.5 gates: | ||
# Since quokka does not support sx or sxdg gates | ||
qasm = qasm.replace("\nsx ", "\nrx(pi*0.5) ").replace( | ||
"\nsxdg ", "\nrx(pi*-0.5) " | ||
) | ||
|
||
# Send data to quokka endpoint | ||
data = {_SCRIPT_KEY: qasm, _REPETITION_KEY: repetitions} | ||
json_results = self.post_function(data) | ||
|
||
if _ERROR_CODE_KEY in json_results and json_results[_ERROR_CODE_KEY] != 0: | ||
raise RuntimeError(f"Quokka returned an error: {json_results}") | ||
|
||
if _RESULT_KEY not in json_results: | ||
raise RuntimeError(f"Quokka did not return any results: {json_results}") | ||
|
||
# Associate results from json response to measurement keys. | ||
result_measurements = {} | ||
for key in measure_keys: | ||
register_name = register_names[key] | ||
if register_name not in json_results[_RESULT_KEY]: | ||
raise RuntimeError( | ||
f"Quokka did not measure key {key}: {json_results}" | ||
) | ||
result_measurements[key] = np.asarray( | ||
json_results[_RESULT_KEY][register_name], dtype=np.dtype("int8") | ||
) | ||
|
||
# Append measurements to eventual result. | ||
rtn_results.append( | ||
cirq.ResultDict( | ||
params=param_resolver, | ||
measurements=result_measurements, | ||
) | ||
) | ||
return rtn_results |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# Copyright 2024 The Unitary Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Iterable | ||
import pytest | ||
import cirq | ||
import sympy | ||
|
||
import unitary.alpha.quokka_sampler as quokka_sampler | ||
|
||
# Qubits for testing | ||
_Q = cirq.LineQubit.range(10) | ||
|
||
|
||
class FakeQuokkaEndpoint: | ||
def __init__(self, responses: Iterable[quokka_sampler.JSON_TYPE]): | ||
self.responses = list(responses) | ||
self.requests = [] | ||
|
||
def _post(self, json_request: quokka_sampler.JSON_TYPE) -> quokka_sampler.JSON_TYPE: | ||
self.requests.append(json_request) | ||
return self.responses.pop() | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"circuit,json_result", | ||
[ | ||
( | ||
cirq.Circuit(cirq.X(_Q[0]), cirq.measure(_Q[0], key="mmm")), | ||
{"m_mmm": [[1], [1], [1], [1], [1]]}, | ||
), | ||
( | ||
cirq.Circuit(cirq.X(_Q[0]), cirq.measure(_Q[0])), | ||
{"m0": [[1], [1], [1], [1], [1]]}, | ||
), | ||
( | ||
cirq.Circuit( | ||
cirq.X(_Q[0]), cirq.X(_Q[1]), cirq.measure(_Q[0]), cirq.measure(_Q[1]) | ||
), | ||
{"m0": [[1], [1], [1], [1], [1]], "m1": [[1], [1], [1], [1], [1]]}, | ||
), | ||
( | ||
cirq.Circuit( | ||
cirq.X(_Q[0]), | ||
cirq.CNOT(_Q[0], _Q[1]), | ||
cirq.measure(_Q[0]), | ||
cirq.measure(_Q[1]), | ||
), | ||
{"m0": [[1], [1], [1], [1], [1]], "m1": [[1], [1], [1], [1], [1]]}, | ||
), | ||
( | ||
cirq.Circuit( | ||
cirq.X(_Q[0]), | ||
cirq.CNOT(_Q[0], _Q[1]), | ||
cirq.measure(_Q[0], _Q[1], key="m2"), | ||
), | ||
{"m_m2": [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1]]}, | ||
), | ||
], | ||
) | ||
def test_quokka_deterministic_examples(circuit, json_result): | ||
sim = cirq.Simulator() | ||
expected_results = sim.run(circuit, repetitions=5) | ||
json_response = {"error": "no error", "error_code": 0, "result": json_result} | ||
endpoint = FakeQuokkaEndpoint([json_response]) | ||
quokka = quokka_sampler.QuokkaSampler( | ||
name="test_mctesterface", post_function=endpoint._post | ||
) | ||
quokka_results = quokka.run(circuit, repetitions=5) | ||
assert quokka_results == expected_results | ||
|
||
|
||
def test_quokka_run_sweep(): | ||
sim = cirq.Simulator() | ||
circuit = cirq.Circuit( | ||
cirq.X(_Q[0]), | ||
cirq.X(_Q[1]) ** sympy.Symbol("X_1"), | ||
cirq.measure(_Q[0], _Q[1], key="m2"), | ||
) | ||
sweep = cirq.Points("X_1", [0, 1]) | ||
expected_results = sim.run_sweep(circuit, sweep, repetitions=5) | ||
json_response = { | ||
"error": "no error", | ||
"error_code": 0, | ||
"result": {"m_m2": [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1]]}, | ||
} | ||
json_response2 = { | ||
"error": "no error", | ||
"error_code": 0, | ||
"result": {"m_m2": [[1, 0], [1, 0], [1, 0], [1, 0], [1, 0]]}, | ||
} | ||
endpoint = FakeQuokkaEndpoint([json_response, json_response2]) | ||
quokka = quokka_sampler.QuokkaSampler( | ||
name="test_mctesterface", post_function=endpoint._post | ||
) | ||
quokka_results = quokka.run_sweep(circuit, sweep, repetitions=5) | ||
assert quokka_results[0] == expected_results[0] |