From 1408570865921801012d4d4b79412cb65d626dd6 Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Sun, 10 Nov 2024 16:36:59 -0800 Subject: [PATCH] Allow `OBSERVABLE_INCLUDE` to target Pauli terms There are two use cases driving this change: 1. Magic state injection needing requiring partially deterministic observables 2. Simulating all observables of a code without needing to add noiseless ancilla qubits to the circuit This change allows observables to be split into pieces (e.g. obs 1 for the first half of the circuit and obs 2 for the second half, with the "true" observable their xor). The flip of each individual piece can even be recovered when using flip simulation, if stabilizer randomization is disabled. This change allows finer control over the logical labels that appear in the detector error model. For example: ``` import stim assert stim.Circuit(""" OBSERVABLE_INCLUDE(0) X0 OBSERVABLE_INCLUDE(1) Y0 OBSERVABLE_INCLUDE(2) Z0 X_ERROR(0.125) 0 Y_ERROR(0.25) 0 Z_ERROR(0.375) 0 OBSERVABLE_INCLUDE(0) X0 OBSERVABLE_INCLUDE(1) Y0 OBSERVABLE_INCLUDE(2) Z0 """).detector_error_model() == stim.DetectorErrorModel(""" error(0.375) L0 L1 error(0.25) L0 L2 error(0.125) L1 L2 """) ``` --- src/stim/circuit/circuit_instruction.cc | 14 +++- src/stim/gates/gate_data_annotations.cc | 2 +- src/stim/gates/gates.h | 3 + src/stim/mem/sparse_xor_vec.h | 8 +++ src/stim/simulators/error_analyzer.test.cc | 71 +++++++++++++++++++ src/stim/simulators/frame_simulator.inl | 15 +++- .../simulators/sparse_rev_frame_tracker.cc | 19 +++-- .../sparse_rev_frame_tracker.test.cc | 14 ++++ 8 files changed, 136 insertions(+), 10 deletions(-) diff --git a/src/stim/circuit/circuit_instruction.cc b/src/stim/circuit/circuit_instruction.cc index e762624a..e4f42c9a 100644 --- a/src/stim/circuit/circuit_instruction.cc +++ b/src/stim/circuit/circuit_instruction.cc @@ -228,9 +228,17 @@ void CircuitInstruction::validate() const { valid_target_mask |= TARGET_RECORD_BIT | TARGET_SWEEP_BIT; } if (gate.flags & GATE_ONLY_TARGETS_MEASUREMENT_RECORD) { - for (GateTarget q : targets) { - if (!(q.data & TARGET_RECORD_BIT)) { - throw std::invalid_argument("Gate " + std::string(gate.name) + " only takes rec[-k] targets."); + if (gate.flags & GATE_TARGETS_PAULI_STRING) { + for (GateTarget q : targets) { + if (!q.is_measurement_record_target() && !q.is_pauli_target()) { + throw std::invalid_argument("Gate " + std::string(gate.name) + " only takes measurement record targets and Pauli targets (rec[-k], Xk, Yk, Zk)."); + } + } + } else { + for (GateTarget q : targets) { + if (!q.is_measurement_record_target()) { + throw std::invalid_argument("Gate " + std::string(gate.name) + " only takes measurement record targets (rec[-k])."); + } } } } else if (gate.flags & GATE_TARGETS_PAULI_STRING) { diff --git a/src/stim/gates/gate_data_annotations.cc b/src/stim/gates/gate_data_annotations.cc index c95027a0..2198880b 100644 --- a/src/stim/gates/gate_data_annotations.cc +++ b/src/stim/gates/gate_data_annotations.cc @@ -95,7 +95,7 @@ Parens Arguments: .id = GateType::OBSERVABLE_INCLUDE, .best_candidate_inverse_id = GateType::OBSERVABLE_INCLUDE, .arg_count = 1, - .flags = (GateFlags)(GATE_ONLY_TARGETS_MEASUREMENT_RECORD | GATE_IS_NOT_FUSABLE | + .flags = (GateFlags)(GATE_ONLY_TARGETS_MEASUREMENT_RECORD | GATE_TARGETS_PAULI_STRING | GATE_IS_NOT_FUSABLE | GATE_ARGS_ARE_UNSIGNED_INTEGERS | GATE_HAS_NO_EFFECT_ON_QUBITS), .category = "Z_Annotations", .help = R"MARKDOWN( diff --git a/src/stim/gates/gates.h b/src/stim/gates/gates.h index ad145b75..7448358a 100644 --- a/src/stim/gates/gates.h +++ b/src/stim/gates/gates.h @@ -184,8 +184,11 @@ enum GateFlags : uint16_t { // Controls validation code checking for arguments coming in pairs. GATE_TARGETS_PAIRS = 1 << 6, // Controls instructions like CORRELATED_ERROR taking Pauli product targets ("X1 Y2 Z3"). + // Note that this enables the Pauli terms but not the combine terms like X1*Y2. GATE_TARGETS_PAULI_STRING = 1 << 7, // Controls instructions like DETECTOR taking measurement record targets ("rec[-1]"). + // The "ONLY" refers to the fact that this flag switches the default behavior to not allowing qubit targets. + // Further flags can then override that default. GATE_ONLY_TARGETS_MEASUREMENT_RECORD = 1 << 8, // Controls instructions like CX operating allowing measurement record targets and sweep bit targets. GATE_CAN_TARGET_BITS = 1 << 9, diff --git a/src/stim/mem/sparse_xor_vec.h b/src/stim/mem/sparse_xor_vec.h index 787bef07..6f73330f 100644 --- a/src/stim/mem/sparse_xor_vec.h +++ b/src/stim/mem/sparse_xor_vec.h @@ -224,6 +224,14 @@ struct SparseXorVec { return sorted_items.data() + size(); } + bool operator==(const std::vector &other) const { + return sorted_items == other; + } + + bool operator!=(const std::vector &other) const { + return sorted_items != other; + } + bool operator==(const SparseXorVec &other) const { return sorted_items == other.sorted_items; } diff --git a/src/stim/simulators/error_analyzer.test.cc b/src/stim/simulators/error_analyzer.test.cc index a97beb09..88b799db 100644 --- a/src/stim/simulators/error_analyzer.test.cc +++ b/src/stim/simulators/error_analyzer.test.cc @@ -3657,3 +3657,74 @@ TEST(ErrorAnalyzer, heralded_pauli_channel_1) { )DEM"), 1e-6)); } + + +TEST(ErrorAnalyzer, OBS_INCLUDE_PAULIS) { + auto circuit = Circuit(R"CIRCUIT( + OBSERVABLE_INCLUDE(0) X0 + OBSERVABLE_INCLUDE(1) Y0 + OBSERVABLE_INCLUDE(2) Z0 + X_ERROR(0.125) 0 + Y_ERROR(0.25) 0 + Z_ERROR(0.375) 0 + OBSERVABLE_INCLUDE(0) X0 + OBSERVABLE_INCLUDE(1) Y0 + OBSERVABLE_INCLUDE(2) Z0 + )CIRCUIT"); + ASSERT_EQ(circuit_to_dem(circuit), DetectorErrorModel(R"DEM( + error(0.375) L0 L1 + error(0.25) L0 L2 + error(0.125) L1 L2 + )DEM")); + + circuit = Circuit(R"CIRCUIT( + DEPOLARIZE1(0.125) 0 + OBSERVABLE_INCLUDE(0) X0 + OBSERVABLE_INCLUDE(1) Y0 + OBSERVABLE_INCLUDE(2) Z0 + X_ERROR(0.25) 0 + OBSERVABLE_INCLUDE(0) X0 + OBSERVABLE_INCLUDE(1) Y0 + OBSERVABLE_INCLUDE(2) Z0 + DEPOLARIZE1(0.125) 0 + )CIRCUIT"); + ASSERT_EQ(circuit_to_dem(circuit), DetectorErrorModel(R"DEM( + error(0.25) L1 L2 + logical_observable L0 + logical_observable L0 + )DEM")); + + circuit = Circuit(R"CIRCUIT( + DEPOLARIZE1(0.125) 0 + OBSERVABLE_INCLUDE(0) X0 + OBSERVABLE_INCLUDE(1) Y0 + OBSERVABLE_INCLUDE(2) Z0 + Y_ERROR(0.25) 0 + OBSERVABLE_INCLUDE(0) X0 + OBSERVABLE_INCLUDE(1) Y0 + OBSERVABLE_INCLUDE(2) Z0 + DEPOLARIZE1(0.125) 0 + )CIRCUIT"); + ASSERT_EQ(circuit_to_dem(circuit), DetectorErrorModel(R"DEM( + error(0.25) L0 L2 + logical_observable L1 + logical_observable L1 + )DEM")); + + circuit = Circuit(R"CIRCUIT( + DEPOLARIZE1(0.125) 0 + OBSERVABLE_INCLUDE(0) X0 + OBSERVABLE_INCLUDE(1) Y0 + OBSERVABLE_INCLUDE(2) Z0 + Z_ERROR(0.25) 0 + OBSERVABLE_INCLUDE(0) X0 + OBSERVABLE_INCLUDE(1) Y0 + OBSERVABLE_INCLUDE(2) Z0 + DEPOLARIZE1(0.125) 0 + )CIRCUIT"); + ASSERT_EQ(circuit_to_dem(circuit), DetectorErrorModel(R"DEM( + error(0.25) L0 L1 + logical_observable L2 + logical_observable L2 + )DEM")); +} diff --git a/src/stim/simulators/frame_simulator.inl b/src/stim/simulators/frame_simulator.inl index 662362cb..3a3f5bcd 100644 --- a/src/stim/simulators/frame_simulator.inl +++ b/src/stim/simulators/frame_simulator.inl @@ -234,8 +234,19 @@ void FrameSimulator::do_OBSERVABLE_INCLUDE(const CircuitInstruction &inst) { if (keeping_detection_data) { auto r = obs_record[(size_t)inst.args[0]]; for (auto t : inst.targets) { - uint32_t lookback = t.data & TARGET_VALUE_MASK; - r ^= m_record.lookback(lookback); + if (t.is_measurement_record_target()) { + uint32_t lookback = t.data & TARGET_VALUE_MASK; + r ^= m_record.lookback(lookback); + } else if (t.is_pauli_target()) { + if (t.data & TARGET_PAULI_X_BIT) { + r ^= x_table[t.qubit_value()]; + } + if (t.data & TARGET_PAULI_Z_BIT) { + r ^= z_table[t.qubit_value()]; + } + } else { + throw std::invalid_argument("Unexpected target for OBSERVABLE_INCLUDE: " + t.str()); + } } } } diff --git a/src/stim/simulators/sparse_rev_frame_tracker.cc b/src/stim/simulators/sparse_rev_frame_tracker.cc index 1122b415..4a6bb833 100644 --- a/src/stim/simulators/sparse_rev_frame_tracker.cc +++ b/src/stim/simulators/sparse_rev_frame_tracker.cc @@ -810,11 +810,22 @@ void SparseUnsignedRevFrameTracker::undo_DETECTOR(const CircuitInstruction &dat) void SparseUnsignedRevFrameTracker::undo_OBSERVABLE_INCLUDE(const CircuitInstruction &dat) { auto obs = DemTarget::observable_id((int32_t)dat.args[0]); for (auto t : dat.targets) { - int64_t index = t.rec_offset() + (int64_t)num_measurements_in_past; - if (index < 0) { - throw std::invalid_argument("Referred to a measurement result before the beginning of time."); + if (t.is_measurement_record_target()) { + int64_t index = t.rec_offset() + (int64_t)num_measurements_in_past; + if (index < 0) { + throw std::invalid_argument("Referred to a measurement result before the beginning of time."); + } + rec_bits[index].xor_item(obs); + } else if (t.is_pauli_target()) { + if (t.data & TARGET_PAULI_X_BIT) { + xs[t.qubit_value()].xor_item(obs); + } + if (t.data & TARGET_PAULI_Z_BIT) { + zs[t.qubit_value()].xor_item(obs); + } + } else { + throw std::invalid_argument("Unexpected target for OBSERVABLE_INCLUDE: " + t.str()); } - rec_bits[index].xor_item(obs); } } diff --git a/src/stim/simulators/sparse_rev_frame_tracker.test.cc b/src/stim/simulators/sparse_rev_frame_tracker.test.cc index 58de3477..69e5548b 100644 --- a/src/stim/simulators/sparse_rev_frame_tracker.test.cc +++ b/src/stim/simulators/sparse_rev_frame_tracker.test.cc @@ -745,3 +745,17 @@ TEST(SparseUnsignedRevFrameTracker, fail_anticommute) { circuit.count_qubits(), circuit.count_measurements(), circuit.count_detectors(), true); ASSERT_THROW({ rev.undo_circuit(circuit); }, std::invalid_argument); } + +TEST(SparseUnsignedRevFrameTracker, OBS_INCLUDE_PAULIS) { + SparseUnsignedRevFrameTracker rev(4, 4, 4); + + rev.undo_circuit(Circuit("OBSERVABLE_INCLUDE(5) X1 Y2 Z3 rec[-1]")); + ASSERT_TRUE(rev.xs[0].empty()); + ASSERT_TRUE(rev.zs[0].empty()); + ASSERT_EQ(rev.xs[1], (std::vector{DemTarget::observable_id(5)})); + ASSERT_TRUE(rev.zs[1].empty()); + ASSERT_EQ(rev.xs[2], (std::vector{DemTarget::observable_id(5)})); + ASSERT_EQ(rev.zs[2], (std::vector{DemTarget::observable_id(5)})); + ASSERT_TRUE(rev.xs[3].empty()); + ASSERT_EQ(rev.zs[3], (std::vector{DemTarget::observable_id(5)})); +}