Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Stim backend to support conditionals and mid-circuit measurements #2270

Merged
merged 6 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion python/tests/kernel/test_kernel_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,15 @@ def kernel(theta: float):
assert np.isclose(want_exp, -1.13, atol=1e-2)


def test_dynamic_circuit():
@pytest.mark.parametrize('target', ['default', 'stim'])
def test_dynamic_circuit(target):
"""Test that we correctly sample circuits with
mid-circuit measurements and conditionals."""

if target == 'stim':
save_target = cudaq.get_target()
cudaq.set_target('stim')

@cudaq.kernel
def simple():
q = cudaq.qvector(2)
Expand Down Expand Up @@ -297,6 +302,9 @@ def simple():
assert '0' in c0 and '1' in c0
assert '00' in counts and '11' in counts

if target == 'stim':
cudaq.set_target(save_target)


def test_teleport():

Expand Down
154 changes: 124 additions & 30 deletions runtime/nvqir/stim/StimCircuitSimulator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,88 @@ namespace nvqir {
/// https://github.com/quantumlib/Stim.
class StimCircuitSimulator : public nvqir::CircuitSimulatorBase<double> {
protected:
stim::Circuit stimCircuit;
// Follow Stim naming convention (W) for bit width (required for templates).
static constexpr std::size_t W = stim::MAX_BITWORD_WIDTH;

/// @brief Number of measurements performed so far.
std::size_t num_measurements = 0;

/// @brief Top-level random engine. Stim simulator RNGs are based off of this
/// engine.
std::mt19937_64 randomEngine;

/// @brief Stim Tableau simulator (noiseless)
std::unique_ptr<stim::TableauSimulator<W>> tableau;

/// @brief Stim Frame/Flip simulator (used to generate multiple shots)
std::unique_ptr<stim::FrameSimulator<W>> sampleSim;

/// @brief Grow the state vector by one qubit.
void addQubitToState() override { addQubitsToState(1); }

/// @brief Get the batch size to use for the Stim sample simulator.
std::size_t getBatchSize() {
// Default to single shot
std::size_t batch_size = 1;
if (getExecutionContext() && getExecutionContext()->name == "sample" &&
!getExecutionContext()->hasConditionalsOnMeasureResults)
batch_size = getExecutionContext()->shots;
return batch_size;
}

/// @brief Override the default sized allocation of qubits
/// here to be a bit more efficient than the default implementation
void addQubitsToState(std::size_t qubitCount,
const void *stateDataIn = nullptr) override {
if (stateDataIn)
throw std::runtime_error("The Stim simulator does not support "
"initialization of qubits from state data.");
return;

if (!tableau) {
cudaq::info("Creating new Stim Tableau simulator");
// Bump the randomEngine before cloning and giving to the Tableau
// simulator.
randomEngine.discard(
std::uniform_int_distribution<int>(1, 30)(randomEngine));
tableau = std::make_unique<stim::TableauSimulator<W>>(
std::mt19937_64(randomEngine), /*num_qubits=*/0, /*sign_bias=*/+0);
}
if (!sampleSim) {
auto batch_size = getBatchSize();
cudaq::info("Creating new Stim frame simulator with batch size {}",
batch_size);
// Bump the randomEngine before cloning and giving to the sample
// simulator.
randomEngine.discard(
std::uniform_int_distribution<int>(1, 30)(randomEngine));
sampleSim = std::make_unique<stim::FrameSimulator<W>>(
stim::CircuitStats(),
stim::FrameSimulatorMode::STORE_MEASUREMENTS_TO_MEMORY, batch_size,
std::mt19937_64(randomEngine));
sampleSim->reset_all();
}
}

/// @brief Reset the qubit state.
void deallocateStateImpl() override { stimCircuit.clear(); }
void deallocateStateImpl() override {
tableau.reset();
// Update the randomEngine so that future invocations will use the updated
// RNG state.
if (sampleSim)
randomEngine = std::move(sampleSim->rng);
sampleSim.reset();
num_measurements = 0;
}

/// @brief Apply operation to all Stim simulators.
void applyOpToSims(const std::string &gate_name,
const std::vector<uint32_t> &targets) {
stim::Circuit tempCircuit;
cudaq::info("Calling applyOpToSims {} - {}", gate_name, targets);
tempCircuit.safe_append_u(gate_name, targets);
tableau->safe_do_circuit(tempCircuit);
sampleSim->safe_do_circuit(tempCircuit);
}

/// @brief Apply the noise channel on \p qubits
void applyNoiseChannel(const std::string_view gateName,
Expand Down Expand Up @@ -78,19 +142,21 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase<double> {
cudaq::info("Applying {} kraus channels to qubits {}", krausChannels.size(),
stimTargets);

stim::Circuit noiseOps;
for (auto &channel : krausChannels) {
if (channel.noise_type == cudaq::noise_model_type::bit_flip_channel)
stimCircuit.safe_append_ua("X_ERROR", stimTargets,
channel.parameters[0]);
noiseOps.safe_append_ua("X_ERROR", stimTargets, channel.parameters[0]);
else if (channel.noise_type ==
cudaq::noise_model_type::phase_flip_channel)
stimCircuit.safe_append_ua("Z_ERROR", stimTargets,
channel.parameters[0]);
noiseOps.safe_append_ua("Z_ERROR", stimTargets, channel.parameters[0]);
else if (channel.noise_type ==
cudaq::noise_model_type::depolarization_channel)
stimCircuit.safe_append_ua("DEPOLARIZE1", stimTargets,
channel.parameters[0]);
noiseOps.safe_append_ua("DEPOLARIZE1", stimTargets,
channel.parameters[0]);
}
// Only apply the noise operations to the sample simulator (not the Tableau
// simulator).
sampleSim->safe_do_circuit(noiseOps);
}

void applyGate(const GateApplicationTask &task) override {
Expand Down Expand Up @@ -119,7 +185,7 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase<double> {
for (auto t : task.targets)
stimTargets.push_back(t);
try {
stimCircuit.safe_append_u(gateName, stimTargets);
applyOpToSims(gateName, stimTargets);
} catch (std::out_of_range &e) {
throw std::runtime_error(
fmt::format("Gate not supported by Stim simulator: {}. Note that "
Expand All @@ -137,14 +203,31 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase<double> {
return 0;
}

/// @brief Measure the qubit and return the result. Collapse the
/// state vector.
bool measureQubit(const std::size_t index) override { return false; }
/// @brief Measure the qubit and return the result.
bool measureQubit(const std::size_t index) override {
// Perform measurement
applyOpToSims(
"M", std::vector<std::uint32_t>{static_cast<std::uint32_t>(index)});
num_measurements++;

// Get the tableau bit that was just generated.
const std::vector<bool> &v = tableau->measurement_record.storage;
const bool tableauBit = *v.crbegin();

// Get the mid-circuit sample to be XOR-ed with tableauBit.
bool sampleSimBit =
sampleSim->m_record.storage[num_measurements - 1][/*shot=*/0];

// Calculate the result.
bool result = tableauBit ^ sampleSimBit;

return result;
}

QubitOrdering getQubitOrdering() const override { return QubitOrdering::msb; }

public:
StimCircuitSimulator() {
StimCircuitSimulator() : randomEngine(std::random_device{}()) {
// Populate the correct name so it is printed correctly during
// deconstructor.
summaryData.name = name();
Expand All @@ -162,26 +245,38 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase<double> {
void resetQubit(const std::size_t index) override {
flushGateQueue();
flushAnySamplingTasks();
stimCircuit.safe_append_u(
applyOpToSims(
"R", std::vector<std::uint32_t>{static_cast<std::uint32_t>(index)});
}

/// @brief Sample the multi-qubit state.
cudaq::ExecutionResult sample(const std::vector<std::size_t> &qubits,
const int shots) override {
assert(shots <= sampleSim->batch_size);
std::vector<std::uint32_t> stimTargetQubits(qubits.begin(), qubits.end());
stimCircuit.safe_append_u("M", stimTargetQubits);
if (false) {
std::stringstream ss;
ss << stimCircuit << '\n';
cudaq::log("Stim circuit is\n{}", ss.str());
}
auto ref_sample = stim::TableauSimulator<
stim::MAX_BITWORD_WIDTH>::reference_sample_circuit(stimCircuit);
stim::simd_bit_table<stim::MAX_BITWORD_WIDTH> sample =
stim::sample_batch_measurements(stimCircuit, ref_sample, shots,
randomEngine, false);
size_t bits_per_sample = stimCircuit.count_measurements();
applyOpToSims("M", stimTargetQubits);
num_measurements += stimTargetQubits.size();

// Generate a reference sample
const std::vector<bool> &v = tableau->measurement_record.storage;
stim::simd_bits<W> ref(v.size());
for (size_t k = 0; k < v.size(); k++)
ref[k] ^= v[k];

// Now XOR results on a per-shot basis
stim::simd_bit_table<W> sample = sampleSim->m_record.storage;
auto nShots = sampleSim->batch_size;

// This is a slightly modified version of `sample_batch_measurements`, where
// we already have the `sample` from the frame simulator. It also places the
// `sample` in a layout amenable to the order of the loops below (shot
// major).
sample = sample.transposed();
if (ref.not_zero())
for (size_t s = 0; s < nShots; s++)
sample[s].word_range_ref(0, ref.num_simd_words) ^= ref;

size_t bits_per_sample = num_measurements;
std::vector<std::string> sequentialData;
// Only retain the final "qubits.size()" measurements. All other
// measurements were mid-circuit measurements that have been previously
Expand All @@ -191,9 +286,8 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase<double> {
CountsDictionary counts;
for (std::size_t shot = 0; shot < shots; shot++) {
std::string aShot(qubits.size(), '0');
for (std::size_t b = first_bit_to_save; b < bits_per_sample; b++) {
aShot[b - first_bit_to_save] = sample[b][shot] ? '1' : '0';
}
for (std::size_t b = first_bit_to_save; b < bits_per_sample; b++)
aShot[b - first_bit_to_save] = sample[shot][b] ? '1' : '0';
counts[aShot]++;
sequentialData.push_back(std::move(aShot));
}
Expand Down
1 change: 1 addition & 0 deletions targettests/execution/qir_cond_for_break.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
******************************************************************************/

// RUN: nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: CUDAQ_DEFAULT_SIMULATOR=stim nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: nvq++ -std=c++17 --enable-mlir %s -o %t

#include <cudaq.h>
Expand Down
1 change: 1 addition & 0 deletions targettests/execution/qir_cond_for_loop-1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

// clang-format off
// RUN: nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: CUDAQ_DEFAULT_SIMULATOR=stim nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: nvq++ -std=c++17 --enable-mlir %s -o %t
// clang-format on

Expand Down
1 change: 1 addition & 0 deletions targettests/execution/qir_cond_for_loop-2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

// clang-format off
// RUN: nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: CUDAQ_DEFAULT_SIMULATOR=stim nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: nvq++ -std=c++17 --enable-mlir %s -o %t
// clang-format on

Expand Down
1 change: 1 addition & 0 deletions targettests/execution/qir_cond_for_loop-3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

// clang-format off
// RUN: nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: CUDAQ_DEFAULT_SIMULATOR=stim nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: nvq++ -std=c++17 --enable-mlir %s -o %t
// clang-format on

Expand Down
1 change: 1 addition & 0 deletions targettests/execution/qir_cond_for_loop-4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

// clang-format off
// RUN: nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: CUDAQ_DEFAULT_SIMULATOR=stim nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: nvq++ -std=c++17 --enable-mlir %s -o %t
// clang-format on

Expand Down
1 change: 1 addition & 0 deletions targettests/execution/qir_cond_for_loop-5.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

// clang-format off
// RUN: nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: CUDAQ_DEFAULT_SIMULATOR=stim nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: nvq++ -std=c++17 --enable-mlir %s -o %t
// clang-format on

Expand Down
1 change: 1 addition & 0 deletions targettests/execution/qir_cond_for_loop-6.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

// clang-format off
// RUN: nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: CUDAQ_DEFAULT_SIMULATOR=stim nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: nvq++ -std=c++17 --enable-mlir %s -o %t
// clang-format on

Expand Down
1 change: 1 addition & 0 deletions targettests/execution/qir_simple_cond-1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

// clang-format off
// RUN: nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s
// RUN: nvq++ %cpp_std --target stim --enable-mlir %s -o %t && %t | FileCheck %s
// RUN: nvq++ -std=c++17 --enable-mlir %s -o %t
// clang-format on

Expand Down
24 changes: 22 additions & 2 deletions unittests/integration/builder_tester.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ CUDAQ_TEST(BuilderTester, checkSwap) {

// Conditional execution on the tensornet backend is slow for a large number of
// shots.
#if !defined(CUDAQ_BACKEND_TENSORNET) && !defined(CUDAQ_BACKEND_STIM)
#if !defined(CUDAQ_BACKEND_TENSORNET)
CUDAQ_TEST(BuilderTester, checkConditional) {
{
cudaq::set_random_seed(13);
Expand Down Expand Up @@ -985,7 +985,6 @@ CUDAQ_TEST(BuilderTester, checkMidCircuitMeasure) {

EXPECT_EQ(counts.count("0", "c1"), 1000);
EXPECT_EQ(counts.count("1", "c0"), 1000);
return;
}

{
Expand All @@ -1005,6 +1004,27 @@ CUDAQ_TEST(BuilderTester, checkMidCircuitMeasure) {
EXPECT_EQ(counts.count("1", "hello2"), 0);
EXPECT_EQ(counts.count("0", "hello2"), 1000);
}

{
// Force conditional sample
auto entryPoint = cudaq::make_kernel();
auto q = entryPoint.qalloc(2);
entryPoint.h(q[0]);
auto mres = entryPoint.mz(q[0], "res0");
entryPoint.c_if(mres, [&]() { entryPoint.x(q[1]); });
entryPoint.mz(q, "final");

printf("%s\n", entryPoint.to_quake().c_str());
auto counts = cudaq::sample(entryPoint);
counts.dump();

EXPECT_GT(counts.count("0", "res0"), 0);
EXPECT_GT(counts.count("1", "res0"), 0);
EXPECT_GT(counts.count("00", "final"), 0);
EXPECT_EQ(counts.count("01", "final"), 0);
EXPECT_EQ(counts.count("10", "final"), 0);
EXPECT_GT(counts.count("11", "final"), 0);
}
}
#endif

Expand Down
Loading