diff --git a/src/stim/circuit/circuit_instruction.cc b/src/stim/circuit/circuit_instruction.cc index e762624a..e4f42c9a 100644 --- a/src/stim/circuit/circuit_instruction.cc +++ b/src/stim/circuit/circuit_instruction.cc @@ -228,9 +228,17 @@ void CircuitInstruction::validate() const { valid_target_mask |= TARGET_RECORD_BIT | TARGET_SWEEP_BIT; } if (gate.flags & GATE_ONLY_TARGETS_MEASUREMENT_RECORD) { - for (GateTarget q : targets) { - if (!(q.data & TARGET_RECORD_BIT)) { - throw std::invalid_argument("Gate " + std::string(gate.name) + " only takes rec[-k] targets."); + if (gate.flags & GATE_TARGETS_PAULI_STRING) { + for (GateTarget q : targets) { + if (!q.is_measurement_record_target() && !q.is_pauli_target()) { + throw std::invalid_argument("Gate " + std::string(gate.name) + " only takes measurement record targets and Pauli targets (rec[-k], Xk, Yk, Zk)."); + } + } + } else { + for (GateTarget q : targets) { + if (!q.is_measurement_record_target()) { + throw std::invalid_argument("Gate " + std::string(gate.name) + " only takes measurement record targets (rec[-k])."); + } } } } else if (gate.flags & GATE_TARGETS_PAULI_STRING) { diff --git a/src/stim/gates/gate_data_annotations.cc b/src/stim/gates/gate_data_annotations.cc index c95027a0..2198880b 100644 --- a/src/stim/gates/gate_data_annotations.cc +++ b/src/stim/gates/gate_data_annotations.cc @@ -95,7 +95,7 @@ Parens Arguments: .id = GateType::OBSERVABLE_INCLUDE, .best_candidate_inverse_id = GateType::OBSERVABLE_INCLUDE, .arg_count = 1, - .flags = (GateFlags)(GATE_ONLY_TARGETS_MEASUREMENT_RECORD | GATE_IS_NOT_FUSABLE | + .flags = (GateFlags)(GATE_ONLY_TARGETS_MEASUREMENT_RECORD | GATE_TARGETS_PAULI_STRING | GATE_IS_NOT_FUSABLE | GATE_ARGS_ARE_UNSIGNED_INTEGERS | GATE_HAS_NO_EFFECT_ON_QUBITS), .category = "Z_Annotations", .help = R"MARKDOWN( diff --git a/src/stim/gates/gates.h b/src/stim/gates/gates.h index ad145b75..7448358a 100644 --- a/src/stim/gates/gates.h +++ b/src/stim/gates/gates.h @@ -184,8 +184,11 @@ enum GateFlags : uint16_t { // Controls validation code checking for arguments coming in pairs. GATE_TARGETS_PAIRS = 1 << 6, // Controls instructions like CORRELATED_ERROR taking Pauli product targets ("X1 Y2 Z3"). + // Note that this enables the Pauli terms but not the combine terms like X1*Y2. GATE_TARGETS_PAULI_STRING = 1 << 7, // Controls instructions like DETECTOR taking measurement record targets ("rec[-1]"). + // The "ONLY" refers to the fact that this flag switches the default behavior to not allowing qubit targets. + // Further flags can then override that default. GATE_ONLY_TARGETS_MEASUREMENT_RECORD = 1 << 8, // Controls instructions like CX operating allowing measurement record targets and sweep bit targets. GATE_CAN_TARGET_BITS = 1 << 9, diff --git a/src/stim/mem/sparse_xor_vec.h b/src/stim/mem/sparse_xor_vec.h index 787bef07..6f73330f 100644 --- a/src/stim/mem/sparse_xor_vec.h +++ b/src/stim/mem/sparse_xor_vec.h @@ -224,6 +224,14 @@ struct SparseXorVec { return sorted_items.data() + size(); } + bool operator==(const std::vector &other) const { + return sorted_items == other; + } + + bool operator!=(const std::vector &other) const { + return sorted_items != other; + } + bool operator==(const SparseXorVec &other) const { return sorted_items == other.sorted_items; } diff --git a/src/stim/simulators/error_analyzer.test.cc b/src/stim/simulators/error_analyzer.test.cc index a97beb09..88b799db 100644 --- a/src/stim/simulators/error_analyzer.test.cc +++ b/src/stim/simulators/error_analyzer.test.cc @@ -3657,3 +3657,74 @@ TEST(ErrorAnalyzer, heralded_pauli_channel_1) { )DEM"), 1e-6)); } + + +TEST(ErrorAnalyzer, OBS_INCLUDE_PAULIS) { + auto circuit = Circuit(R"CIRCUIT( + OBSERVABLE_INCLUDE(0) X0 + OBSERVABLE_INCLUDE(1) Y0 + OBSERVABLE_INCLUDE(2) Z0 + X_ERROR(0.125) 0 + Y_ERROR(0.25) 0 + Z_ERROR(0.375) 0 + OBSERVABLE_INCLUDE(0) X0 + OBSERVABLE_INCLUDE(1) Y0 + OBSERVABLE_INCLUDE(2) Z0 + )CIRCUIT"); + ASSERT_EQ(circuit_to_dem(circuit), DetectorErrorModel(R"DEM( + error(0.375) L0 L1 + error(0.25) L0 L2 + error(0.125) L1 L2 + )DEM")); + + circuit = Circuit(R"CIRCUIT( + DEPOLARIZE1(0.125) 0 + OBSERVABLE_INCLUDE(0) X0 + OBSERVABLE_INCLUDE(1) Y0 + OBSERVABLE_INCLUDE(2) Z0 + X_ERROR(0.25) 0 + OBSERVABLE_INCLUDE(0) X0 + OBSERVABLE_INCLUDE(1) Y0 + OBSERVABLE_INCLUDE(2) Z0 + DEPOLARIZE1(0.125) 0 + )CIRCUIT"); + ASSERT_EQ(circuit_to_dem(circuit), DetectorErrorModel(R"DEM( + error(0.25) L1 L2 + logical_observable L0 + logical_observable L0 + )DEM")); + + circuit = Circuit(R"CIRCUIT( + DEPOLARIZE1(0.125) 0 + OBSERVABLE_INCLUDE(0) X0 + OBSERVABLE_INCLUDE(1) Y0 + OBSERVABLE_INCLUDE(2) Z0 + Y_ERROR(0.25) 0 + OBSERVABLE_INCLUDE(0) X0 + OBSERVABLE_INCLUDE(1) Y0 + OBSERVABLE_INCLUDE(2) Z0 + DEPOLARIZE1(0.125) 0 + )CIRCUIT"); + ASSERT_EQ(circuit_to_dem(circuit), DetectorErrorModel(R"DEM( + error(0.25) L0 L2 + logical_observable L1 + logical_observable L1 + )DEM")); + + circuit = Circuit(R"CIRCUIT( + DEPOLARIZE1(0.125) 0 + OBSERVABLE_INCLUDE(0) X0 + OBSERVABLE_INCLUDE(1) Y0 + OBSERVABLE_INCLUDE(2) Z0 + Z_ERROR(0.25) 0 + OBSERVABLE_INCLUDE(0) X0 + OBSERVABLE_INCLUDE(1) Y0 + OBSERVABLE_INCLUDE(2) Z0 + DEPOLARIZE1(0.125) 0 + )CIRCUIT"); + ASSERT_EQ(circuit_to_dem(circuit), DetectorErrorModel(R"DEM( + error(0.25) L0 L1 + logical_observable L2 + logical_observable L2 + )DEM")); +} diff --git a/src/stim/simulators/frame_simulator.inl b/src/stim/simulators/frame_simulator.inl index 662362cb..3a3f5bcd 100644 --- a/src/stim/simulators/frame_simulator.inl +++ b/src/stim/simulators/frame_simulator.inl @@ -234,8 +234,19 @@ void FrameSimulator::do_OBSERVABLE_INCLUDE(const CircuitInstruction &inst) { if (keeping_detection_data) { auto r = obs_record[(size_t)inst.args[0]]; for (auto t : inst.targets) { - uint32_t lookback = t.data & TARGET_VALUE_MASK; - r ^= m_record.lookback(lookback); + if (t.is_measurement_record_target()) { + uint32_t lookback = t.data & TARGET_VALUE_MASK; + r ^= m_record.lookback(lookback); + } else if (t.is_pauli_target()) { + if (t.data & TARGET_PAULI_X_BIT) { + r ^= x_table[t.qubit_value()]; + } + if (t.data & TARGET_PAULI_Z_BIT) { + r ^= z_table[t.qubit_value()]; + } + } else { + throw std::invalid_argument("Unexpected target for OBSERVABLE_INCLUDE: " + t.str()); + } } } } diff --git a/src/stim/simulators/sparse_rev_frame_tracker.cc b/src/stim/simulators/sparse_rev_frame_tracker.cc index 1122b415..4a6bb833 100644 --- a/src/stim/simulators/sparse_rev_frame_tracker.cc +++ b/src/stim/simulators/sparse_rev_frame_tracker.cc @@ -810,11 +810,22 @@ void SparseUnsignedRevFrameTracker::undo_DETECTOR(const CircuitInstruction &dat) void SparseUnsignedRevFrameTracker::undo_OBSERVABLE_INCLUDE(const CircuitInstruction &dat) { auto obs = DemTarget::observable_id((int32_t)dat.args[0]); for (auto t : dat.targets) { - int64_t index = t.rec_offset() + (int64_t)num_measurements_in_past; - if (index < 0) { - throw std::invalid_argument("Referred to a measurement result before the beginning of time."); + if (t.is_measurement_record_target()) { + int64_t index = t.rec_offset() + (int64_t)num_measurements_in_past; + if (index < 0) { + throw std::invalid_argument("Referred to a measurement result before the beginning of time."); + } + rec_bits[index].xor_item(obs); + } else if (t.is_pauli_target()) { + if (t.data & TARGET_PAULI_X_BIT) { + xs[t.qubit_value()].xor_item(obs); + } + if (t.data & TARGET_PAULI_Z_BIT) { + zs[t.qubit_value()].xor_item(obs); + } + } else { + throw std::invalid_argument("Unexpected target for OBSERVABLE_INCLUDE: " + t.str()); } - rec_bits[index].xor_item(obs); } } diff --git a/src/stim/simulators/sparse_rev_frame_tracker.test.cc b/src/stim/simulators/sparse_rev_frame_tracker.test.cc index 58de3477..69e5548b 100644 --- a/src/stim/simulators/sparse_rev_frame_tracker.test.cc +++ b/src/stim/simulators/sparse_rev_frame_tracker.test.cc @@ -745,3 +745,17 @@ TEST(SparseUnsignedRevFrameTracker, fail_anticommute) { circuit.count_qubits(), circuit.count_measurements(), circuit.count_detectors(), true); ASSERT_THROW({ rev.undo_circuit(circuit); }, std::invalid_argument); } + +TEST(SparseUnsignedRevFrameTracker, OBS_INCLUDE_PAULIS) { + SparseUnsignedRevFrameTracker rev(4, 4, 4); + + rev.undo_circuit(Circuit("OBSERVABLE_INCLUDE(5) X1 Y2 Z3 rec[-1]")); + ASSERT_TRUE(rev.xs[0].empty()); + ASSERT_TRUE(rev.zs[0].empty()); + ASSERT_EQ(rev.xs[1], (std::vector{DemTarget::observable_id(5)})); + ASSERT_TRUE(rev.zs[1].empty()); + ASSERT_EQ(rev.xs[2], (std::vector{DemTarget::observable_id(5)})); + ASSERT_EQ(rev.zs[2], (std::vector{DemTarget::observable_id(5)})); + ASSERT_TRUE(rev.xs[3].empty()); + ASSERT_EQ(rev.zs[3], (std::vector{DemTarget::observable_id(5)})); +}