Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Allow OBSERVABLE_INCLUDE to target Pauli terms #853

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/stim/circuit/circuit_instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,17 @@ void CircuitInstruction::validate() const {
valid_target_mask |= TARGET_RECORD_BIT | TARGET_SWEEP_BIT;
}
if (gate.flags & GATE_ONLY_TARGETS_MEASUREMENT_RECORD) {
for (GateTarget q : targets) {
if (!(q.data & TARGET_RECORD_BIT)) {
throw std::invalid_argument("Gate " + std::string(gate.name) + " only takes rec[-k] targets.");
if (gate.flags & GATE_TARGETS_PAULI_STRING) {
for (GateTarget q : targets) {
if (!q.is_measurement_record_target() && !q.is_pauli_target()) {
throw std::invalid_argument("Gate " + std::string(gate.name) + " only takes measurement record targets and Pauli targets (rec[-k], Xk, Yk, Zk).");
}
}
} else {
for (GateTarget q : targets) {
if (!q.is_measurement_record_target()) {
throw std::invalid_argument("Gate " + std::string(gate.name) + " only takes measurement record targets (rec[-k]).");
}
}
}
} else if (gate.flags & GATE_TARGETS_PAULI_STRING) {
Expand Down
2 changes: 1 addition & 1 deletion src/stim/gates/gate_data_annotations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ Parens Arguments:
.id = GateType::OBSERVABLE_INCLUDE,
.best_candidate_inverse_id = GateType::OBSERVABLE_INCLUDE,
.arg_count = 1,
.flags = (GateFlags)(GATE_ONLY_TARGETS_MEASUREMENT_RECORD | GATE_IS_NOT_FUSABLE |
.flags = (GateFlags)(GATE_ONLY_TARGETS_MEASUREMENT_RECORD | GATE_TARGETS_PAULI_STRING | GATE_IS_NOT_FUSABLE |
GATE_ARGS_ARE_UNSIGNED_INTEGERS | GATE_HAS_NO_EFFECT_ON_QUBITS),
.category = "Z_Annotations",
.help = R"MARKDOWN(
Expand Down
3 changes: 3 additions & 0 deletions src/stim/gates/gates.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,11 @@ enum GateFlags : uint16_t {
// Controls validation code checking for arguments coming in pairs.
GATE_TARGETS_PAIRS = 1 << 6,
// Controls instructions like CORRELATED_ERROR taking Pauli product targets ("X1 Y2 Z3").
// Note that this enables the Pauli terms but not the combine terms like X1*Y2.
GATE_TARGETS_PAULI_STRING = 1 << 7,
// Controls instructions like DETECTOR taking measurement record targets ("rec[-1]").
// The "ONLY" refers to the fact that this flag switches the default behavior to not allowing qubit targets.
// Further flags can then override that default.
GATE_ONLY_TARGETS_MEASUREMENT_RECORD = 1 << 8,
// Controls instructions like CX operating allowing measurement record targets and sweep bit targets.
GATE_CAN_TARGET_BITS = 1 << 9,
Expand Down
8 changes: 8 additions & 0 deletions src/stim/mem/sparse_xor_vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,14 @@ struct SparseXorVec {
return sorted_items.data() + size();
}

bool operator==(const std::vector<T> &other) const {
return sorted_items == other;
}

bool operator!=(const std::vector<T> &other) const {
return sorted_items != other;
}

bool operator==(const SparseXorVec &other) const {
return sorted_items == other.sorted_items;
}
Expand Down
71 changes: 71 additions & 0 deletions src/stim/simulators/error_analyzer.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3657,3 +3657,74 @@ TEST(ErrorAnalyzer, heralded_pauli_channel_1) {
)DEM"),
1e-6));
}


TEST(ErrorAnalyzer, OBS_INCLUDE_PAULIS) {
auto circuit = Circuit(R"CIRCUIT(
OBSERVABLE_INCLUDE(0) X0
OBSERVABLE_INCLUDE(1) Y0
OBSERVABLE_INCLUDE(2) Z0
X_ERROR(0.125) 0
Y_ERROR(0.25) 0
Z_ERROR(0.375) 0
OBSERVABLE_INCLUDE(0) X0
OBSERVABLE_INCLUDE(1) Y0
OBSERVABLE_INCLUDE(2) Z0
)CIRCUIT");
ASSERT_EQ(circuit_to_dem(circuit), DetectorErrorModel(R"DEM(
error(0.375) L0 L1
error(0.25) L0 L2
error(0.125) L1 L2
)DEM"));

circuit = Circuit(R"CIRCUIT(
DEPOLARIZE1(0.125) 0
OBSERVABLE_INCLUDE(0) X0
OBSERVABLE_INCLUDE(1) Y0
OBSERVABLE_INCLUDE(2) Z0
X_ERROR(0.25) 0
OBSERVABLE_INCLUDE(0) X0
OBSERVABLE_INCLUDE(1) Y0
OBSERVABLE_INCLUDE(2) Z0
DEPOLARIZE1(0.125) 0
)CIRCUIT");
ASSERT_EQ(circuit_to_dem(circuit), DetectorErrorModel(R"DEM(
error(0.25) L1 L2
logical_observable L0
logical_observable L0
)DEM"));

circuit = Circuit(R"CIRCUIT(
DEPOLARIZE1(0.125) 0
OBSERVABLE_INCLUDE(0) X0
OBSERVABLE_INCLUDE(1) Y0
OBSERVABLE_INCLUDE(2) Z0
Y_ERROR(0.25) 0
OBSERVABLE_INCLUDE(0) X0
OBSERVABLE_INCLUDE(1) Y0
OBSERVABLE_INCLUDE(2) Z0
DEPOLARIZE1(0.125) 0
)CIRCUIT");
ASSERT_EQ(circuit_to_dem(circuit), DetectorErrorModel(R"DEM(
error(0.25) L0 L2
logical_observable L1
logical_observable L1
)DEM"));

circuit = Circuit(R"CIRCUIT(
DEPOLARIZE1(0.125) 0
OBSERVABLE_INCLUDE(0) X0
OBSERVABLE_INCLUDE(1) Y0
OBSERVABLE_INCLUDE(2) Z0
Z_ERROR(0.25) 0
OBSERVABLE_INCLUDE(0) X0
OBSERVABLE_INCLUDE(1) Y0
OBSERVABLE_INCLUDE(2) Z0
DEPOLARIZE1(0.125) 0
)CIRCUIT");
ASSERT_EQ(circuit_to_dem(circuit), DetectorErrorModel(R"DEM(
error(0.25) L0 L1
logical_observable L2
logical_observable L2
)DEM"));
}
15 changes: 13 additions & 2 deletions src/stim/simulators/frame_simulator.inl
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,19 @@ void FrameSimulator<W>::do_OBSERVABLE_INCLUDE(const CircuitInstruction &inst) {
if (keeping_detection_data) {
auto r = obs_record[(size_t)inst.args[0]];
for (auto t : inst.targets) {
uint32_t lookback = t.data & TARGET_VALUE_MASK;
r ^= m_record.lookback(lookback);
if (t.is_measurement_record_target()) {
uint32_t lookback = t.data & TARGET_VALUE_MASK;
r ^= m_record.lookback(lookback);
} else if (t.is_pauli_target()) {
if (t.data & TARGET_PAULI_X_BIT) {
r ^= x_table[t.qubit_value()];
}
if (t.data & TARGET_PAULI_Z_BIT) {
r ^= z_table[t.qubit_value()];
}
} else {
throw std::invalid_argument("Unexpected target for OBSERVABLE_INCLUDE: " + t.str());
}
}
}
}
Expand Down
19 changes: 15 additions & 4 deletions src/stim/simulators/sparse_rev_frame_tracker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -810,11 +810,22 @@ void SparseUnsignedRevFrameTracker::undo_DETECTOR(const CircuitInstruction &dat)
void SparseUnsignedRevFrameTracker::undo_OBSERVABLE_INCLUDE(const CircuitInstruction &dat) {
auto obs = DemTarget::observable_id((int32_t)dat.args[0]);
for (auto t : dat.targets) {
int64_t index = t.rec_offset() + (int64_t)num_measurements_in_past;
if (index < 0) {
throw std::invalid_argument("Referred to a measurement result before the beginning of time.");
if (t.is_measurement_record_target()) {
int64_t index = t.rec_offset() + (int64_t)num_measurements_in_past;
if (index < 0) {
throw std::invalid_argument("Referred to a measurement result before the beginning of time.");
}
rec_bits[index].xor_item(obs);
} else if (t.is_pauli_target()) {
if (t.data & TARGET_PAULI_X_BIT) {
xs[t.qubit_value()].xor_item(obs);
}
if (t.data & TARGET_PAULI_Z_BIT) {
zs[t.qubit_value()].xor_item(obs);
}
} else {
throw std::invalid_argument("Unexpected target for OBSERVABLE_INCLUDE: " + t.str());
}
rec_bits[index].xor_item(obs);
}
}

Expand Down
14 changes: 14 additions & 0 deletions src/stim/simulators/sparse_rev_frame_tracker.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -745,3 +745,17 @@ TEST(SparseUnsignedRevFrameTracker, fail_anticommute) {
circuit.count_qubits(), circuit.count_measurements(), circuit.count_detectors(), true);
ASSERT_THROW({ rev.undo_circuit(circuit); }, std::invalid_argument);
}

TEST(SparseUnsignedRevFrameTracker, OBS_INCLUDE_PAULIS) {
SparseUnsignedRevFrameTracker rev(4, 4, 4);

rev.undo_circuit(Circuit("OBSERVABLE_INCLUDE(5) X1 Y2 Z3 rec[-1]"));
ASSERT_TRUE(rev.xs[0].empty());
ASSERT_TRUE(rev.zs[0].empty());
ASSERT_EQ(rev.xs[1], (std::vector<DemTarget>{DemTarget::observable_id(5)}));
ASSERT_TRUE(rev.zs[1].empty());
ASSERT_EQ(rev.xs[2], (std::vector<DemTarget>{DemTarget::observable_id(5)}));
ASSERT_EQ(rev.zs[2], (std::vector<DemTarget>{DemTarget::observable_id(5)}));
ASSERT_TRUE(rev.xs[3].empty());
ASSERT_EQ(rev.zs[3], (std::vector<DemTarget>{DemTarget::observable_id(5)}));
}
Loading