diff --git a/Cargo.toml b/Cargo.toml index a35a87d4d3b1..bfacfcfc08b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,8 +14,8 @@ license = "Apache-2.0" # # Each crate can add on specific features freely as it inherits. [workspace.dependencies] -bytemuck = "1.17" -indexmap.version = "2.4.0" +bytemuck = "1.16" +indexmap.version = "2.3.0" hashbrown.version = "0.14.5" num-bigint = "0.4" num-complex = "0.4" diff --git a/crates/circuit/Cargo.toml b/crates/circuit/Cargo.toml index ed1f849bbf62..16e83658f333 100644 --- a/crates/circuit/Cargo.toml +++ b/crates/circuit/Cargo.toml @@ -10,7 +10,7 @@ name = "qiskit_circuit" doctest = false [dependencies] -rayon.workspace = true +rayon = "1.10" ahash.workspace = true rustworkx-core.workspace = true bytemuck.workspace = true diff --git a/crates/circuit/src/bit_data.rs b/crates/circuit/src/bit_data.rs index 0c0b20a02522..4af7ef083fc6 100644 --- a/crates/circuit/src/bit_data.rs +++ b/crates/circuit/src/bit_data.rs @@ -17,6 +17,7 @@ use pyo3::prelude::*; use pyo3::types::PyList; use std::fmt::Debug; use std::hash::{Hash, Hasher}; +use std::mem::swap; /// Private wrapper for Python-side Bit instances that implements /// [Hash] and [Eq], allowing them to be used in Rust hash-based @@ -220,3 +221,37 @@ where self.bits.clear(); } } + +pub struct Iter<'a, T> { + _data: &'a BitData, + index: usize, +} + +impl<'a, T> Iterator for Iter<'a, T> +where + T: From, +{ + type Item = T; + + fn next(&mut self) -> Option { + let mut index = self.index + 1; + swap(&mut self.index, &mut index); + let index: Option = index.try_into().ok(); + index.map(|i| From::from(i)) + } +} + +impl<'a, T> IntoIterator for &'a BitData +where + T: From, +{ + type Item = T; + type IntoIter = Iter<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + Iter { + _data: self, + index: 0, + } + } +} diff --git a/crates/circuit/src/dag_circuit.rs b/crates/circuit/src/dag_circuit.rs index bb5fff5e343a..8c754d2e6e91 100644 --- a/crates/circuit/src/dag_circuit.rs +++ b/crates/circuit/src/dag_circuit.rs @@ -29,6 +29,7 @@ use crate::rustworkx_core_vnext::isomorphism; use crate::{BitType, Clbit, Qubit, TupleLikeArg}; use hashbrown::{HashMap, HashSet}; +use indexmap::map::Entry; use indexmap::IndexMap; use itertools::Itertools; @@ -39,6 +40,7 @@ use pyo3::types::{ IntoPyDict, PyDict, PyInt, PyIterator, PyList, PySequence, PySet, PyString, PyTuple, PyType, }; +use rustworkx_core::connectivity::connected_components as core_connected_components; use rustworkx_core::dag_algo::layers; use rustworkx_core::err::ContractError; use rustworkx_core::graph_ext::ContractNodesDirected; @@ -48,7 +50,7 @@ use rustworkx_core::petgraph::prelude::*; use rustworkx_core::petgraph::stable_graph::{EdgeReference, NodeIndex}; use rustworkx_core::petgraph::unionfind::UnionFind; use rustworkx_core::petgraph::visit::{ - EdgeIndexable, IntoEdgeReferences, IntoNodeReferences, NodeFiltered, NodeIndexable, + EdgeIndexable, IntoEdgeReferences, IntoNodeReferences, NodeIndexable, }; use rustworkx_core::petgraph::Incoming; use rustworkx_core::traversal::{ @@ -62,13 +64,13 @@ use std::convert::Infallible; use std::f64::consts::PI; #[cfg(feature = "cache_pygates")] -use std::cell::OnceCell; +use std::cell::RefCell; static CONTROL_FLOW_OP_NAMES: [&str; 4] = ["for_loop", "while_loop", "if_else", "switch_case"]; static SEMANTIC_EQ_SYMMETRIC: [&str; 4] = ["barrier", "swap", "break_loop", "continue_loop"]; #[derive(Clone, Debug)] -pub enum NodeType { +pub(crate) enum NodeType { QubitIn(Qubit), QubitOut(Qubit), ClbitIn(Clbit), @@ -79,7 +81,7 @@ pub enum NodeType { } #[derive(Clone, Debug)] -pub enum Wire { +pub(crate) enum Wire { Qubit(Qubit), Clbit(Clbit), Var(PyObject), @@ -98,8 +100,6 @@ impl PartialEq for Wire { } } -impl Eq for Wire {} - impl Hash for Wire { fn hash(&self, state: &mut H) { match self { @@ -110,16 +110,28 @@ impl Hash for Wire { } } -impl Wire { - fn to_pickle(&self, py: Python) -> PyObject { +impl IntoPy for Wire { + fn into_py(self, py: Python) -> PyObject { match self { Self::Qubit(bit) => (0, bit.0.into_py(py)).into_py(py), Self::Clbit(bit) => (1, bit.0.into_py(py)).into_py(py), Self::Var(var) => (2, var.clone_ref(py)).into_py(py), } } +} - fn from_pickle(b: &Bound) -> PyResult { +impl ToPyObject for Wire { + fn to_object(&self, py: Python) -> PyObject { + match self { + Self::Qubit(bit) => (0, bit.0.into_py(py)).to_object(py), + Self::Clbit(bit) => (1, bit.0.into_py(py)).to_object(py), + Self::Var(var) => (2, var.clone_ref(py)).to_object(py), + } + } +} + +impl<'py> FromPyObject<'py> for Wire { + fn extract_bound(b: &Bound<'py, PyAny>) -> Result { let tuple: Bound = b.extract()?; let wire_type: usize = tuple.get_item(0)?.extract()?; if wire_type == 0 { @@ -134,6 +146,8 @@ impl Wire { } } +impl Eq for Wire {} + // TODO: Remove me. // This is a temporary map type used to store a mapping of // Var to NodeIndex to hold us over until Var is ported to @@ -154,55 +168,56 @@ impl _VarIndexMap { } } - pub fn keys(&self, py: Python) -> impl Iterator { - self.dict - .bind(py) - .keys() - .into_iter() - .map(|k| k.unbind()) - .collect::>() - .into_iter() + pub fn keys(&self) -> impl Iterator { + Python::with_gil(|py| { + self.dict + .bind(py) + .keys() + .into_iter() + .map(|k| k.unbind()) + .collect::>() + .into_iter() + }) } - pub fn contains_key(&self, py: Python, key: &PyObject) -> bool { - self.dict.bind(py).contains(key).unwrap() + pub fn contains_key(&self, key: &PyObject) -> bool { + Python::with_gil(|py| self.dict.bind(py).contains(key).unwrap()) } - pub fn get(&self, py: Python, key: &PyObject) -> Option { - self.dict - .bind(py) - .get_item(key) - .unwrap() - .map(|v| NodeIndex::new(v.extract().unwrap())) + pub fn get(&self, key: &PyObject) -> Option { + Python::with_gil(|py| { + self.dict + .bind(py) + .get_item(key) + .unwrap() + .map(|v| NodeIndex::new(v.extract().unwrap())) + }) } - pub fn insert(&mut self, py: Python, key: PyObject, value: NodeIndex) { - self.dict - .bind(py) - .set_item(key, value.index().into_py(py)) - .unwrap() + pub fn insert(&mut self, key: PyObject, value: NodeIndex) { + Python::with_gil(|py| { + self.dict + .bind(py) + .set_item(key, value.index().into_py(py)) + .unwrap() + }) } - pub fn remove(&mut self, py: Python, key: &PyObject) -> Option { - let bound_dict = self.dict.bind(py); - let res = bound_dict - .get_item(key.clone_ref(py)) - .unwrap() - .map(|v| NodeIndex::new(v.extract().unwrap())); - let _del_result = bound_dict.del_item(key); - res + pub fn remove(&mut self, key: &PyObject) -> Option { + Python::with_gil(|py| -> Option { + let bound_dict = self.dict.bind(py); + let res = bound_dict + .get_item(key.clone_ref(py)) + .unwrap() + .map(|v| NodeIndex::new(v.extract().unwrap())); + let _del_result = bound_dict.del_item(key); + res + }) } pub fn values<'py>(&self, py: Python<'py>) -> impl Iterator + 'py { let values = self.dict.bind(py).values(); values.iter().map(|x| NodeIndex::new(x.extract().unwrap())) } - - pub fn iter<'py>(&self, py: Python<'py>) -> impl Iterator + 'py { - self.dict - .bind(py) - .iter() - .map(|(var, index)| (var.unbind(), NodeIndex::new(index.extract().unwrap()))) - } } /// Quantum circuit as a directed acyclic graph. @@ -223,7 +238,7 @@ pub struct DAGCircuit { calibrations: HashMap>, - pub dag: StableDiGraph, + pub(crate) dag: StableDiGraph, #[pyo3(get)] qregs: Py, @@ -235,9 +250,9 @@ pub struct DAGCircuit { /// The cache used to intern instruction cargs. cargs_cache: IndexedInterner>, /// Qubits registered in the circuit. - pub qubits: BitData, + pub(crate) qubits: BitData, /// Clbits registered in the circuit. - pub clbits: BitData, + pub(crate) clbits: BitData, /// Global phase. global_phase: Param, /// Duration. @@ -256,11 +271,15 @@ pub struct DAGCircuit { qubit_locations: Py, clbit_locations: Py, - /// Map from qubit to input and output nodes of the graph. - qubit_io_map: Vec<[NodeIndex; 2]>, + /// Map from qubit to input nodes of the graph. + qubit_input_map: IndexMap, + /// Map from qubit to output nodes of the graph. + qubit_output_map: IndexMap, - /// Map from clbit to input and output nodes of the graph. - clbit_io_map: Vec<[NodeIndex; 2]>, + /// Map from clbit to input nodes of the graph. + clbit_input_map: IndexMap, + /// Map from clbit to output nodes of the graph. + clbit_output_map: IndexMap, // TODO: use IndexMap once Var is ported to Rust /// Map from var to input nodes of the graph. @@ -364,11 +383,13 @@ impl PyVariableMapper { .bind(py) .call_method1(intern!(py, "map_target"), (target,)) } -} -impl IntoPy> for PyVariableMapper { - fn into_py(self, _py: Python<'_>) -> Py { + #[allow(dead_code)] + fn map_expr<'py>(&self, node: &Bound<'py, PyAny>) -> PyResult> { + let py = node.py(); self.mapper + .bind(py) + .call_method1(intern!(py, "map_expr"), (node,)) } } @@ -380,15 +401,27 @@ fn reject_new_register(reg: &Bound) -> PyResult<()> { ))) } +impl IntoPy> for PyVariableMapper { + fn into_py(self, _py: Python<'_>) -> Py { + self.mapper + } +} + #[pyclass(module = "qiskit._accelerate.circuit")] #[derive(Clone, Debug)] struct BitLocations { #[pyo3(get)] - index: usize, + pub index: usize, #[pyo3(get)] registers: Py, } +impl BitLocations { + fn set_index(&mut self, index: usize) { + self.index = index + } +} + #[derive(Copy, Clone, Debug)] enum DAGVarType { Input = 0, @@ -424,8 +457,10 @@ impl DAGCircuit { unit: "dt".to_string(), qubit_locations: PyDict::new_bound(py).unbind(), clbit_locations: PyDict::new_bound(py).unbind(), - qubit_io_map: Vec::new(), - clbit_io_map: Vec::new(), + qubit_input_map: IndexMap::default(), + qubit_output_map: IndexMap::default(), + clbit_input_map: IndexMap::default(), + clbit_output_map: IndexMap::default(), var_input_map: _VarIndexMap::new(py), var_output_map: _VarIndexMap::new(py), op_names: IndexMap::default(), @@ -442,26 +477,16 @@ impl DAGCircuit { #[getter] fn input_map(&self, py: Python) -> PyResult> { let out_dict = PyDict::new_bound(py); - for (qubit, indices) in self - .qubit_io_map - .iter() - .enumerate() - .map(|(idx, indices)| (Qubit(idx as u32), indices)) - { + for (qubit, index) in self.qubit_input_map.iter() { out_dict.set_item( - self.qubits.get(qubit).unwrap().clone_ref(py), - self.get_node(py, indices[0])?, + self.qubits.get(*qubit).unwrap().clone_ref(py), + self.get_node(py, *index)?, )?; } - for (clbit, indices) in self - .clbit_io_map - .iter() - .enumerate() - .map(|(idx, indices)| (Clbit(idx as u32), indices)) - { + for (clbit, index) in self.clbit_input_map.iter() { out_dict.set_item( - self.clbits.get(clbit).unwrap().clone_ref(py), - self.get_node(py, indices[0])?, + self.clbits.get(*clbit).unwrap().clone_ref(py), + self.get_node(py, *index)?, )?; } for (var, index) in self.var_input_map.dict.bind(py).iter() { @@ -476,26 +501,16 @@ impl DAGCircuit { #[getter] fn output_map(&self, py: Python) -> PyResult> { let out_dict = PyDict::new_bound(py); - for (qubit, indices) in self - .qubit_io_map - .iter() - .enumerate() - .map(|(idx, indices)| (Qubit(idx as u32), indices)) - { + for (qubit, index) in self.qubit_output_map.iter() { out_dict.set_item( - self.qubits.get(qubit).unwrap().clone_ref(py), - self.get_node(py, indices[1])?, + self.qubits.get(*qubit).unwrap().clone_ref(py), + self.get_node(py, *index)?, )?; } - for (clbit, indices) in self - .clbit_io_map - .iter() - .enumerate() - .map(|(idx, indices)| (Clbit(idx as u32), indices)) - { + for (clbit, index) in self.clbit_output_map.iter() { out_dict.set_item( - self.clbits.get(clbit).unwrap().clone_ref(py), - self.get_node(py, indices[1])?, + self.clbits.get(*clbit).unwrap().clone_ref(py), + self.get_node(py, *index)?, )?; } for (var, index) in self.var_output_map.dict.bind(py).iter() { @@ -516,19 +531,31 @@ impl DAGCircuit { out_dict.set_item("cregs", self.cregs.clone_ref(py))?; out_dict.set_item("global_phase", self.global_phase.clone())?; out_dict.set_item( - "qubit_io_map", - self.qubit_io_map + "qubit_input_map", + self.qubit_input_map .iter() - .enumerate() - .map(|(k, v)| (k, [v[0].index(), v[1].index()])) + .map(|(k, v)| (k.0, v.index())) .collect::>(), )?; out_dict.set_item( - "clbit_io_map", - self.clbit_io_map + "clbit_input_map", + self.clbit_input_map .iter() - .enumerate() - .map(|(k, v)| (k, [v[0].index(), v[1].index()])) + .map(|(k, v)| (k.0, v.index())) + .collect::>(), + )?; + out_dict.set_item( + "qubit_output_map", + self.qubit_output_map + .iter() + .map(|(k, v)| (k.0, v.index())) + .collect::>(), + )?; + out_dict.set_item( + "clbit_output_map", + self.clbit_output_map + .iter() + .map(|(k, v)| (k.0, v.index())) .collect::>(), )?; out_dict.set_item("var_input_map", self.var_input_map.dict.clone_ref(py))?; @@ -571,12 +598,7 @@ impl DAGCircuit { let edge = match self.dag.edge_weight(idx) { Some(edge_w) => { let endpoints = self.dag.edge_endpoints(idx).unwrap(); - ( - endpoints.0.index(), - endpoints.1.index(), - edge_w.clone().to_pickle(py), - ) - .to_object(py) + (endpoints.0.index(), endpoints.1.index(), edge_w.clone()).to_object(py) } None => py.None(), }; @@ -631,22 +653,37 @@ impl DAGCircuit { for bit in clbits_raw.iter() { self.clbits.add(py, &bit, false)?; } - let binding = dict_state.get_item("qubit_io_map")?.unwrap(); + let binding = dict_state.get_item("qubit_input_map")?.unwrap(); let qubit_index_map_raw = binding.downcast::().unwrap(); - self.qubit_io_map = Vec::with_capacity(qubit_index_map_raw.len()); - for (_k, v) in qubit_index_map_raw.iter() { - let indices: [usize; 2] = v.extract()?; - self.qubit_io_map - .push([NodeIndex::new(indices[0]), NodeIndex::new(indices[1])]); + self.qubit_input_map = + IndexMap::with_capacity_and_hasher(qubit_index_map_raw.len(), RandomState::default()); + for (k, v) in qubit_index_map_raw.iter() { + self.qubit_input_map + .insert(Qubit(k.extract()?), NodeIndex::new(v.extract()?)); } - let binding = dict_state.get_item("clbit_io_map")?.unwrap(); + let binding = dict_state.get_item("clbit_input_map")?.unwrap(); let clbit_index_map_raw = binding.downcast::().unwrap(); - self.clbit_io_map = Vec::with_capacity(clbit_index_map_raw.len()); - - for (_k, v) in clbit_index_map_raw.iter() { - let indices: [usize; 2] = v.extract()?; - self.clbit_io_map - .push([NodeIndex::new(indices[0]), NodeIndex::new(indices[1])]); + self.clbit_input_map = + IndexMap::with_capacity_and_hasher(qubit_index_map_raw.len(), RandomState::default()); + for (k, v) in clbit_index_map_raw.iter() { + self.clbit_input_map + .insert(Clbit(k.extract()?), NodeIndex::new(v.extract()?)); + } + let binding = dict_state.get_item("qubit_output_map")?.unwrap(); + let qubit_index_map_raw = binding.downcast::().unwrap(); + self.qubit_output_map = + IndexMap::with_capacity_and_hasher(qubit_index_map_raw.len(), RandomState::default()); + for (k, v) in qubit_index_map_raw.iter() { + self.qubit_output_map + .insert(Qubit(k.extract()?), NodeIndex::new(v.extract()?)); + } + let binding = dict_state.get_item("clbit_output_map")?.unwrap(); + let clbit_index_map_raw = binding.downcast::().unwrap(); + self.clbit_input_map = + IndexMap::with_capacity_and_hasher(qubit_index_map_raw.len(), RandomState::default()); + for (k, v) in clbit_index_map_raw.iter() { + self.clbit_output_map + .insert(Clbit(k.extract()?), NodeIndex::new(v.extract()?)); } // Rebuild Graph preserving index holes: let binding = dict_state.get_item("nodes")?.unwrap(); @@ -716,7 +753,7 @@ impl DAGCircuit { let triple = item.downcast::().unwrap(); let edge_p: usize = triple.get_item(0).unwrap().extract().unwrap(); let edge_c: usize = triple.get_item(1).unwrap().extract().unwrap(); - let edge_w = Wire::from_pickle(&triple.get_item(2).unwrap())?; + let edge_w = triple.get_item(2).unwrap().extract::()?; self.dag .add_edge(NodeIndex::new(edge_p), NodeIndex::new(edge_c), edge_w); } @@ -725,6 +762,10 @@ impl DAGCircuit { Ok(()) } + pub fn _get_node(&self, py: Python, index: usize) -> PyResult { + self.get_node(py, NodeIndex::new(index)) + } + /// Returns the current sequence of registered :class:`.Qubit` instances as a list. /// /// .. warning:: @@ -756,20 +797,14 @@ impl DAGCircuit { /// Return a list of the wires in order. #[getter] - fn get_wires(&self, py: Python<'_>) -> PyResult> { + fn get_wires(&self, py: Python<'_>) -> Py { let wires: Vec<&PyObject> = self .qubits .bits() .iter() .chain(self.clbits.bits().iter()) .collect(); - let out_list = PyList::new_bound(py, wires); - for var_type_set in &self.vars_by_type { - for var in var_type_set.bind(py).iter() { - out_list.append(var)?; - } - } - Ok(out_list.unbind()) + PyList::new_bound(py, wires).unbind() } /// Returns the number of nodes in the dag. @@ -797,7 +832,7 @@ impl DAGCircuit { Param::ParameterExpression(angle) => { self.global_phase = Param::ParameterExpression(angle); } - Param::Obj(_) => return Err(PyTypeError::new_err("Invalid type for global phase")), + Param::Obj(_) => return Err(PyValueError::new_err("Invalid type for global phase")), } Ok(()) } @@ -964,8 +999,9 @@ def _format(operand): } /// Add individual qubit wires. - fn add_clbits(&mut self, py: Python, clbits: Vec>) -> PyResult<()> { - for bit in clbits.iter() { + fn add_clbits(&mut self, py: Python, clbits: &Bound) -> PyResult<()> { + let bits: Vec> = clbits.extract()?; + for bit in bits.iter() { if !bit.is_instance(imports::CLBIT.get_bound(py))? { return Err(DAGCircuitError::new_err("not a Clbit instance.")); } @@ -978,7 +1014,7 @@ def _format(operand): } } - for bit in clbits.iter() { + for bit in bits.iter() { self.add_clbit_unchecked(py, bit)?; } Ok(()) @@ -1131,7 +1167,7 @@ def _format(operand): let clbits: HashSet = bit_iter.collect(); let mut busy_bits = Vec::new(); for bit in clbits.iter() { - if !self.is_wire_idle(py, &Wire::Clbit(*bit))? { + if !self.is_wire_idle(&Wire::Clbit(*bit))? { busy_bits.push(self.clbits.get(*bit).unwrap()); } } @@ -1158,7 +1194,7 @@ def _format(operand): // Remove DAG in/out nodes etc. for bit in clbits.iter() { - self.remove_idle_wire(py, Wire::Clbit(*bit))?; + self.remove_idle_wire(Wire::Clbit(*bit))?; } // Copy the current clbit mapping so we can use it while remapping @@ -1167,32 +1203,31 @@ def _format(operand): // Remove the clbit indices, which will invalidate our mapping of Clbit to // Python bits throughout the entire DAG. - self.clbits.remove_indices(py, clbits.clone())?; + self.clbits.remove_indices(py, clbits)?; // Update input/output maps to use new Clbits. - let io_mapping: HashMap = self - .clbit_io_map + self.clbit_input_map = self + .clbit_input_map .drain(..) - .enumerate() - .filter_map(|(k, v)| { - let clbit = Clbit(k as u32); - if clbits.contains(&clbit) { - None - } else { - Some(( - self.clbits - .find(old_clbits.get(Clbit(k as u32)).unwrap().bind(py)) - .unwrap(), - v, - )) - } + .map(|(k, v)| { + ( + self.clbits + .find(old_clbits.get(k).unwrap().bind(py)) + .unwrap(), + v, + ) }) .collect(); - - self.clbit_io_map = (0..io_mapping.len()) - .map(|idx| { - let clbit = Clbit(idx as u32); - io_mapping[&clbit] + self.clbit_output_map = self + .clbit_output_map + .drain(..) + .map(|(k, v)| { + ( + self.clbits + .find(old_clbits.get(k).unwrap().bind(py)) + .unwrap(), + v, + ) }) .collect(); @@ -1233,7 +1268,7 @@ def _format(operand): for (i, bit) in self.clbits.bits().iter().enumerate() { let raw_loc = bit_locations.get_item(bit)?.unwrap(); let loc = raw_loc.downcast::().unwrap(); - loc.borrow_mut().index = i; + loc.borrow_mut().set_index(i); bit_locations.set_item(bit, loc)?; } Ok(()) @@ -1341,7 +1376,7 @@ def _format(operand): let mut busy_bits = Vec::new(); for bit in qubits.iter() { - if !self.is_wire_idle(py, &Wire::Qubit(*bit))? { + if !self.is_wire_idle(&Wire::Qubit(*bit))? { busy_bits.push(self.qubits.get(*bit).unwrap()); } } @@ -1368,7 +1403,7 @@ def _format(operand): // Remove DAG in/out nodes etc. for bit in qubits.iter() { - self.remove_idle_wire(py, Wire::Qubit(*bit))?; + self.remove_idle_wire(Wire::Qubit(*bit))?; } // Copy the current qubit mapping so we can use it while remapping @@ -1377,32 +1412,31 @@ def _format(operand): // Remove the qubit indices, which will invalidate our mapping of Qubit to // Python bits throughout the entire DAG. - self.qubits.remove_indices(py, qubits.clone())?; + self.qubits.remove_indices(py, qubits)?; // Update input/output maps to use new Qubits. - let io_mapping: HashMap = self - .qubit_io_map + self.qubit_input_map = self + .qubit_input_map .drain(..) - .enumerate() - .filter_map(|(k, v)| { - let qubit = Qubit(k as u32); - if qubits.contains(&qubit) { - None - } else { - Some(( - self.qubits - .find(old_qubits.get(qubit).unwrap().bind(py)) - .unwrap(), - v, - )) - } + .map(|(k, v)| { + ( + self.qubits + .find(old_qubits.get(k).unwrap().bind(py)) + .unwrap(), + v, + ) }) .collect(); - - self.qubit_io_map = (0..io_mapping.len()) - .map(|idx| { - let qubit = Qubit(idx as u32); - io_mapping[&qubit] + self.qubit_output_map = self + .qubit_output_map + .drain(..) + .map(|(k, v)| { + ( + self.qubits + .find(old_qubits.get(k).unwrap().bind(py)) + .unwrap(), + v, + ) }) .collect(); @@ -1443,7 +1477,7 @@ def _format(operand): for (i, bit) in self.qubits.bits().iter().enumerate() { let raw_loc = bit_locations.get_item(bit)?.unwrap(); let loc = raw_loc.downcast::().unwrap(); - loc.borrow_mut().index = i; + loc.borrow_mut().set_index(i); bit_locations.set_item(bit, loc)?; } Ok(()) @@ -1638,7 +1672,48 @@ def _format(operand): ) -> PyResult<()> { if let NodeType::Operation(inst) = self.pack_into(py, node)? { if check { - self.check_op_addition(py, &inst)?; + if let Some(condition) = inst.condition() { + self._check_condition(py, inst.op.name(), condition.bind(py))?; + } + + for b in self.qargs_cache.intern(inst.qubits) { + if !self.qubit_output_map.contains_key(b) { + return Err(DAGCircuitError::new_err(format!( + "qubit {} not found in output map", + self.qubits.get(*b).unwrap() + ))); + } + } + + for b in self.cargs_cache.intern(inst.clbits) { + if !self.clbit_output_map.contains_key(b) { + return Err(DAGCircuitError::new_err(format!( + "clbit {} not found in output map", + self.clbits.get(*b).unwrap() + ))); + } + } + + if self.may_have_additional_wires(py, &inst) { + let (clbits, vars) = + self.additional_wires(py, inst.op.view(), inst.condition())?; + for b in clbits { + if !self.clbit_output_map.contains_key(&b) { + return Err(DAGCircuitError::new_err(format!( + "clbit {} not found in output map", + self.clbits.get(b).unwrap() + ))); + } + } + for v in vars { + if !self.var_output_map.contains_key(&v) { + return Err(DAGCircuitError::new_err(format!( + "var {} not found in output map", + v + ))); + } + } + } } self.push_back(py, inst)?; @@ -1692,11 +1767,52 @@ def _format(operand): params: (!py_op.params.is_empty()).then(|| Box::new(py_op.params)), extra_attrs: py_op.extra_attrs, #[cfg(feature = "cache_pygates")] - py_op: op.unbind().into(), + py_op: RefCell::new(Some(op.unbind())), }; if check { - self.check_op_addition(py, &instr)?; + if let Some(condition) = instr.condition() { + self._check_condition(py, instr.op.name(), condition.bind(py))?; + } + + for b in self.qargs_cache.intern(instr.qubits) { + if !self.qubit_output_map.contains_key(b) { + return Err(DAGCircuitError::new_err(format!( + "qubit {} not found in output map", + self.qubits.get(*b).unwrap() + ))); + } + } + + for b in self.cargs_cache.intern(instr.clbits) { + if !self.clbit_output_map.contains_key(b) { + return Err(DAGCircuitError::new_err(format!( + "clbit {} not found in output map", + self.clbits.get(*b).unwrap() + ))); + } + } + + if self.may_have_additional_wires(py, &instr) { + let (clbits, vars) = + self.additional_wires(py, instr.op.view(), instr.condition())?; + for b in clbits { + if !self.clbit_output_map.contains_key(&b) { + return Err(DAGCircuitError::new_err(format!( + "clbit {} not found in output map", + self.clbits.get(b).unwrap() + ))); + } + } + for v in vars { + if !self.var_output_map.contains_key(&v) { + return Err(DAGCircuitError::new_err(format!( + "var {} not found in output map", + v + ))); + } + } + } } self.push_back(py, instr)? }; @@ -1748,11 +1864,52 @@ def _format(operand): params: (!py_op.params.is_empty()).then(|| Box::new(py_op.params)), extra_attrs: py_op.extra_attrs, #[cfg(feature = "cache_pygates")] - py_op: op.unbind().into(), + py_op: RefCell::new(Some(op.unbind())), }; if check { - self.check_op_addition(py, &instr)?; + if let Some(condition) = instr.condition() { + self._check_condition(py, instr.op.name(), condition.bind(py))?; + } + + for b in self.qargs_cache.intern(instr.qubits) { + if !self.qubit_output_map.contains_key(b) { + return Err(DAGCircuitError::new_err(format!( + "qubit {} not found in output map", + self.qubits.get(*b).unwrap() + ))); + } + } + + for b in self.cargs_cache.intern(instr.clbits) { + if !self.clbit_output_map.contains_key(b) { + return Err(DAGCircuitError::new_err(format!( + "clbit {} not found in output map", + self.clbits.get(*b).unwrap() + ))); + } + } + + if self.may_have_additional_wires(py, &instr) { + let (clbits, vars) = + self.additional_wires(py, instr.op.view(), instr.condition())?; + for b in clbits { + if !self.clbit_output_map.contains_key(&b) { + return Err(DAGCircuitError::new_err(format!( + "clbit {} not found in output map", + self.clbits.get(b).unwrap() + ))); + } + } + for v in vars { + if !self.var_output_map.contains_key(&v) { + return Err(DAGCircuitError::new_err(format!( + "var {} not found in output map", + v + ))); + } + } + } } self.push_front(py, instr)? }; @@ -1909,7 +2066,7 @@ def _format(operand): Py::new(py, slf.clone())?.into_bound(py).borrow_mut() }; - dag.global_phase = add_global_phase(py, &dag.global_phase, &other.global_phase)?; + dag.global_phase = dag.global_phase.add(py, &other.global_phase); for (gate, cals) in other.calibrations.iter() { let calibrations = match dag.calibrations.get(gate) { @@ -1947,7 +2104,7 @@ def _format(operand): } } for var in other.iter_declared_vars(py)?.bind(py) { - dag.add_declared_var(py, &var?)?; + dag.add_declared_var(&var?)?; } let variable_mapper = PyVariableMapper::new( @@ -1964,9 +2121,8 @@ def _format(operand): let bit = other.qubits.get(*q).unwrap().bind(py); let m_wire = edge_map.get_item(bit)?.unwrap_or_else(|| bit.clone()); let wire_in_dag = dag.qubits.find(&m_wire); - if wire_in_dag.is_none() - || (dag.qubit_io_map.len() - 1 < wire_in_dag.unwrap().0 as usize) + || !dag.qubit_output_map.contains_key(&wire_in_dag.unwrap()) { return Err(DAGCircuitError::new_err(format!( "wire {} not in self", @@ -1980,7 +2136,7 @@ def _format(operand): let m_wire = edge_map.get_item(bit)?.unwrap_or_else(|| bit.clone()); let wire_in_dag = dag.clbits.find(&m_wire); if wire_in_dag.is_none() - || dag.clbit_io_map.len() - 1 < wire_in_dag.unwrap().0 as usize + || !dag.clbit_output_map.contains_key(&wire_in_dag.unwrap()) { return Err(DAGCircuitError::new_err(format!( "wire {} not in self", @@ -2089,10 +2245,13 @@ def _format(operand): /// DAGCircuitError: If the DAG is invalid fn idle_wires(&self, py: Python, ignore: Option<&Bound>) -> PyResult> { let mut result: Vec = Vec::new(); - let wires = (0..self.qubit_io_map.len()) - .map(|idx| Wire::Qubit(Qubit(idx as u32))) - .chain((0..self.clbit_io_map.len()).map(|idx| Wire::Clbit(Clbit(idx as u32)))) - .chain(self.var_input_map.keys(py).map(Wire::Var)); + let wires = self + .qubit_input_map + .keys() + .cloned() + .map(Wire::Qubit) + .chain(self.clbit_input_map.keys().cloned().map(Wire::Clbit)) + .chain(self.var_input_map.keys().map(Wire::Var)); match ignore { Some(ignore) => { // Convert the list to a Rust set. @@ -2101,7 +2260,7 @@ def _format(operand): .map(|s| s.extract()) .collect::>>()?; for wire in wires { - let nodes_found = self.nodes_on_wire(py, &wire, true).into_iter().any(|node| { + let nodes_found = self.nodes_on_wire(&wire, true).into_iter().any(|node| { let weight = self.dag.node_weight(node).unwrap(); if let NodeType::Operation(packed) = weight { !ignore_set.contains(packed.op.name()) @@ -2121,7 +2280,7 @@ def _format(operand): } None => { for wire in wires { - if self.is_wire_idle(py, &wire)? { + if self.is_wire_idle(&wire)? { result.push(match wire { Wire::Qubit(qubit) => self.qubits.get(qubit).unwrap().clone_ref(py), Wire::Clbit(clbit) => self.clbits.get(clbit).unwrap().clone_ref(py), @@ -2152,7 +2311,7 @@ def _format(operand): /// ``recurse=True``, or any control flow is present in a non-recursive call. #[pyo3(signature= (*, recurse=false))] fn size(&self, py: Python, recurse: bool) -> PyResult { - let mut length = self.dag.node_count() - (self.width() * 2); + let mut length = self.dag.node_count() - (self.width() + self.num_vars()) * 2; if !recurse { if CONTROL_FLOW_OP_NAMES .iter() @@ -2232,7 +2391,7 @@ def _format(operand): /// flow is present in a non-recursive call. #[pyo3(signature= (*, recurse=false))] fn depth(&self, py: Python, recurse: bool) -> PyResult { - if self.qubits.is_empty() && self.clbits.is_empty() && self.vars_info.is_empty() { + if self.qubits.is_empty() { return Ok(0); } @@ -2300,28 +2459,24 @@ def _format(operand): /// with the new function DAGCircuit.num_qubits replacing the former /// semantic of DAGCircuit.width(). fn width(&self) -> usize { - self.qubits.len() + self.clbits.len() + self.vars_info.len() + self.qubits.len() + self.clbits.len() } /// Return the total number of qubits used by the circuit. /// num_qubits() replaces former use of width(). /// DAGCircuit.width() now returns qubits + clbits for /// consistency with Circuit.width() [qiskit-terra #2564]. - pub fn num_qubits(&self) -> usize { + fn num_qubits(&self) -> usize { self.qubits.len() } /// Return the total number of classical bits used by the circuit. - pub fn num_clbits(&self) -> usize { + fn num_clbits(&self) -> usize { self.clbits.len() } /// Compute how many components the circuit can decompose into. fn num_tensor_factors(&self) -> usize { - // This function was forked from rustworkx's - // number_weekly_connected_components() function as of 0.15.0: - // https://github.com/Qiskit/rustworkx/blob/0.15.0/src/connectivity/mod.rs#L215-L235 - let mut weak_components = self.dag.node_count(); let mut vertex_sets = UnionFind::new(self.dag.node_bound()); for edge in self.dag.edge_references() { @@ -2341,29 +2496,78 @@ def _format(operand): let phase_is_close = |self_phase: f64, other_phase: f64| -> bool { ((self_phase - other_phase + PI).rem_euclid(2. * PI) - PI).abs() <= 1.0e-10 }; - let normalize_param = |param: &Param| { - if let Param::ParameterExpression(ob) = param { - ob.bind(py) - .call_method0(intern!(py, "numeric")) - .ok() - .map(|ob| ob.extract::()) - .unwrap_or_else(|| Ok(param.clone())) - } else { - Ok(param.clone()) - } - }; - - let phase_eq = match [ - normalize_param(&self.global_phase)?, - normalize_param(&other.global_phase)?, - ] { + match [&self.global_phase, &other.global_phase] { [Param::Float(self_phase), Param::Float(other_phase)] => { - Ok(phase_is_close(self_phase, other_phase)) + if !phase_is_close(*self_phase, *other_phase) { + return Ok(false); + } + } + [Param::Float(self_phase), Param::ParameterExpression(other_phase_param)] => { + let other_phase = if let Ok(other_phase) = + other_phase_param.call_method0(py, intern!(py, "numeric")) + { + other_phase.extract::(py)? + } else { + Param::ParameterExpression(other_phase_param.clone_ref(py)) + }; + if let Param::Float(other_phase) = other_phase { + if !phase_is_close(*self_phase, other_phase) { + return Ok(false); + } + } else if !self.global_phase.eq(&other.global_phase, py)? { + return Ok(false); + } + } + [Param::ParameterExpression(self_phase_param), Param::ParameterExpression(other_phase_param)] => + { + let self_phase = if let Ok(self_phase) = + self_phase_param.call_method0(py, intern!(py, "numeric")) + { + self_phase.extract::(py)? + } else { + Param::ParameterExpression(self_phase_param.clone_ref(py)) + }; + let other_phase = if let Ok(other_phase) = + other_phase_param.call_method0(py, intern!(py, "numeric")) + { + other_phase.extract::(py)? + } else { + Param::ParameterExpression(other_phase_param.clone_ref(py)) + }; + match [self_phase, other_phase] { + [Param::Float(self_phase), Param::Float(other_phase)] => { + if !phase_is_close(self_phase, other_phase) { + return Ok(false); + } + } + _ => { + if !self.global_phase.eq(&other.global_phase, py)? { + return Ok(false); + } + } + } + } + [Param::ParameterExpression(self_phase_param), Param::Float(other_phase)] => { + let self_phase = if let Ok(self_phase) = + self_phase_param.call_method0(py, intern!(py, "numeric")) + { + self_phase.extract::(py)? + } else { + Param::ParameterExpression(self_phase_param.clone_ref(py)) + }; + if let Param::Float(self_phase) = self_phase { + if !phase_is_close(self_phase, *other_phase) { + return Ok(false); + } + } else if !self.global_phase.eq(&other.global_phase, py)? { + return Ok(false); + } + } + _ => { + if !self.global_phase.eq(&other.global_phase, py)? { + return Ok(false); + } } - _ => self.global_phase.eq(py, &other.global_phase), - }?; - if !phase_eq { - return Ok(false); } if self.calibrations.len() != other.calibrations.len() { return Ok(false); @@ -2470,64 +2674,64 @@ def _format(operand): if inst1.op.name() != inst2.op.name() { return Ok(false); } - let check_args = || -> bool { - let node1_qargs = self.qargs_cache.intern(inst1.qubits); - let node2_qargs = other.qargs_cache.intern(inst2.qubits); - let node1_cargs = self.cargs_cache.intern(inst1.clbits); - let node2_cargs = other.cargs_cache.intern(inst2.clbits); - if SEMANTIC_EQ_SYMMETRIC.contains(&inst1.op.name()) { - let node1_qargs = - node1_qargs.iter().copied().collect::>(); - let node2_qargs = - node2_qargs.iter().copied().collect::>(); - let node1_cargs = - node1_cargs.iter().copied().collect::>(); - let node2_cargs = - node2_cargs.iter().copied().collect::>(); - if node1_qargs != node2_qargs || node1_cargs != node2_cargs { - return false; + let node1_qargs = self.qargs_cache.intern(inst1.qubits); + let node2_qargs = other.qargs_cache.intern(inst2.qubits); + let node1_cargs = self.cargs_cache.intern(inst1.clbits); + let node2_cargs = other.cargs_cache.intern(inst2.clbits); + match [inst1.op.view(), inst2.op.view()] { + [OperationRef::Standard(op1), OperationRef::Standard(_op2)] => { + if !inst1.py_op_eq(py, inst2)? { + return Ok(false); } - } else if node1_qargs != node2_qargs || node1_cargs != node2_cargs { - return false; - } - true - }; - let check_conditions = || -> PyResult { - if let Some(cond1) = inst1 - .extra_attrs - .as_ref() - .and_then(|attrs| attrs.condition.as_ref()) - { - if let Some(cond2) = inst2 + if SEMANTIC_EQ_SYMMETRIC.contains(&op1.name()) { + let node1_qargs = + node1_qargs.iter().copied().collect::>(); + let node2_qargs = + node2_qargs.iter().copied().collect::>(); + let node1_cargs = + node1_cargs.iter().copied().collect::>(); + let node2_cargs = + node2_cargs.iter().copied().collect::>(); + if node1_qargs != node2_qargs || node1_cargs != node2_cargs { + return Ok(false); + } + } else if node1_qargs != node2_qargs || node1_cargs != node2_cargs { + return Ok(false); + } + let conditions_eq = if let Some(cond1) = inst1 .extra_attrs .as_ref() .and_then(|attrs| attrs.condition.as_ref()) { - legacy_condition_eq - .call1((cond1, cond2, &self_bit_indices, &other_bit_indices))? - .extract::() + if let Some(cond2) = inst2 + .extra_attrs + .as_ref() + .and_then(|attrs| attrs.condition.as_ref()) + { + legacy_condition_eq + .call1(( + cond1, + cond2, + &self_bit_indices, + &other_bit_indices, + ))? + .extract::()? + } else { + false + } } else { - Ok(false) - } - } else { - Ok(inst2 - .extra_attrs - .as_ref() - .and_then(|attrs| attrs.condition.as_ref()) - .is_none()) - } - }; - - match [inst1.op.view(), inst2.op.view()] { - [OperationRef::Standard(_op1), OperationRef::Standard(_op2)] => { - Ok(inst1.py_op_eq(py, inst2)? - && check_args() - && check_conditions()? - && inst1 - .params_view() - .iter() - .zip(inst2.params_view().iter()) - .all(|(a, b)| a.is_close(py, b, 1e-10).unwrap())) + inst2 + .extra_attrs + .as_ref() + .and_then(|attrs| attrs.condition.as_ref()) + .is_none() + }; + let params_eq = inst1 + .params_view() + .iter() + .zip(inst2.params_view().iter()) + .all(|(a, b)| a.is_close(b, py, 1e-10).unwrap()); + Ok(conditions_eq && params_eq) } [OperationRef::Instruction(op1), OperationRef::Instruction(op2)] => { if op1.control_flow() && op2.control_flow() { @@ -2535,43 +2739,211 @@ def _format(operand): let n2 = other.unpack_into(py, NodeIndex::new(0), n2)?; let name = op1.name(); if name == "if_else" || name == "while_loop" { - condition_op_check + if name != op2.name() { + return Ok(false); + } + return condition_op_check .call1((n1, n2, &self_bit_indices, &other_bit_indices))? - .extract() + .extract(); } else if name == "switch_case" { - switch_case_op_check + let res = switch_case_op_check .call1((n1, n2, &self_bit_indices, &other_bit_indices))? - .extract() + .extract(); + return res; } else if name == "for_loop" { - for_loop_op_check + return for_loop_op_check .call1((n1, n2, &self_bit_indices, &other_bit_indices))? - .extract() + .extract(); } else { - Err(PyRuntimeError::new_err(format!( + return Err(PyRuntimeError::new_err(format!( "unhandled control-flow operation: {}", name - ))) + ))); } - } else { - Ok(inst1.py_op_eq(py, inst2)? - && check_args() - && check_conditions()?) } + if !inst1.py_op_eq(py, inst2)? { + return Ok(false); + } + if SEMANTIC_EQ_SYMMETRIC.contains(&op1.name()) { + let node1_qargs = + node1_qargs.iter().copied().collect::>(); + let node2_qargs = + node2_qargs.iter().copied().collect::>(); + let node1_cargs = + node1_cargs.iter().copied().collect::>(); + let node2_cargs = + node2_cargs.iter().copied().collect::>(); + if node1_qargs != node2_qargs || node1_cargs != node2_cargs { + return Ok(false); + } + } else if node1_qargs != node2_qargs || node1_cargs != node2_cargs { + return Ok(false); + } + + let conditions_eq = if let Some(cond1) = inst1 + .extra_attrs + .as_ref() + .and_then(|attrs| attrs.condition.as_ref()) + { + if let Some(cond2) = inst2 + .extra_attrs + .as_ref() + .and_then(|attrs| attrs.condition.as_ref()) + { + legacy_condition_eq + .call1(( + cond1, + cond2, + &self_bit_indices, + &other_bit_indices, + ))? + .extract::()? + } else { + false + } + } else { + inst2 + .extra_attrs + .as_ref() + .and_then(|attrs| attrs.condition.as_ref()) + .is_none() + }; + Ok(conditions_eq) } - [OperationRef::Gate(_op1), OperationRef::Gate(_op2)] => { - Ok(inst1.py_op_eq(py, inst2)? && check_args() && check_conditions()?) + [OperationRef::Gate(op1), OperationRef::Gate(_op2)] => { + if !inst1.py_op_eq(py, inst2)? { + return Ok(false); + } + if SEMANTIC_EQ_SYMMETRIC.contains(&op1.name()) { + let node1_qargs = + node1_qargs.iter().copied().collect::>(); + let node2_qargs = + node2_qargs.iter().copied().collect::>(); + let node1_cargs = + node1_cargs.iter().copied().collect::>(); + let node2_cargs = + node2_cargs.iter().copied().collect::>(); + if node1_qargs != node2_qargs || node1_cargs != node2_cargs { + return Ok(false); + } + } else if node1_qargs != node2_qargs || node1_cargs != node2_cargs { + return Ok(false); + } + + let conditions_eq = if let Some(cond1) = inst1 + .extra_attrs + .as_ref() + .and_then(|attrs| attrs.condition.as_ref()) + { + if let Some(cond2) = inst2 + .extra_attrs + .as_ref() + .and_then(|attrs| attrs.condition.as_ref()) + { + legacy_condition_eq + .call1(( + cond1, + cond2, + &self_bit_indices, + &other_bit_indices, + ))? + .extract::()? + } else { + false + } + } else { + inst2 + .extra_attrs + .as_ref() + .and_then(|attrs| attrs.condition.as_ref()) + .is_none() + }; + Ok(conditions_eq) } [OperationRef::Operation(_op1), OperationRef::Operation(_op2)] => { - Ok(inst1.py_op_eq(py, inst2)? && check_args()) + if !inst1.py_op_eq(py, inst2)? { + return Ok(false); + } + Ok(node1_qargs == node2_qargs || node1_cargs == node2_cargs) } // Handle the case we end up with a pygate for a standardgate // this typically only happens if it's a ControlledGate in python // and we have mutable state set. [OperationRef::Standard(_op1), OperationRef::Gate(_op2)] => { - Ok(inst1.py_op_eq(py, inst2)? && check_args() && check_conditions()?) + if !inst1.py_op_eq(py, inst2)? { + return Ok(false); + } + if node1_qargs != node2_qargs || node1_cargs != node2_cargs { + return Ok(false); + } + + let conditions_eq = if let Some(cond1) = inst1 + .extra_attrs + .as_ref() + .and_then(|attrs| attrs.condition.as_ref()) + { + if let Some(cond2) = inst2 + .extra_attrs + .as_ref() + .and_then(|attrs| attrs.condition.as_ref()) + { + legacy_condition_eq + .call1(( + cond1, + cond2, + &self_bit_indices, + &other_bit_indices, + ))? + .extract::()? + } else { + false + } + } else { + inst2 + .extra_attrs + .as_ref() + .and_then(|attrs| attrs.condition.as_ref()) + .is_none() + }; + Ok(conditions_eq) } [OperationRef::Gate(_op1), OperationRef::Standard(_op2)] => { - Ok(inst1.py_op_eq(py, inst2)? && check_args() && check_conditions()?) + if !inst1.py_op_eq(py, inst2)? { + return Ok(false); + } + if node1_qargs != node2_qargs || node1_cargs != node2_cargs { + return Ok(false); + } + + let conditions_eq = if let Some(cond1) = inst1 + .extra_attrs + .as_ref() + .and_then(|attrs| attrs.condition.as_ref()) + { + if let Some(cond2) = inst2 + .extra_attrs + .as_ref() + .and_then(|attrs| attrs.condition.as_ref()) + { + legacy_condition_eq + .call1(( + cond1, + cond2, + &self_bit_indices, + &other_bit_indices, + ))? + .extract::()? + } else { + false + } + } else { + inst2 + .extra_attrs + .as_ref() + .and_then(|attrs| attrs.condition.as_ref()) + .is_none() + }; + Ok(conditions_eq) } _ => Ok(false), } @@ -2837,7 +3209,7 @@ def _format(operand): params: (!py_op.params.is_empty()).then(|| Box::new(py_op.params)), extra_attrs: py_op.extra_attrs, #[cfg(feature = "cache_pygates")] - py_op: op.unbind().into(), + py_op: RefCell::new(Some(op.unbind())), }); let new_node = self @@ -2849,9 +3221,9 @@ def _format(operand): ), })?; - self.increment_op(op_name.as_str()); + self.increment_op(op_name); for name in block_op_names { - self.decrement_op(name.as_str()); + self.decrement_op(name); } self.get_node(py, new_node) @@ -3003,8 +3375,7 @@ def _format(operand): }, None => { let raw_wires = input_dag.get_wires(py); - let binding = raw_wires?; - let wires = binding.bind(py); + let wires = raw_wires.bind(py); build_wire_map(wires)? } }; @@ -3207,10 +3578,6 @@ def _format(operand): unit: None, })); } - #[cfg(feature = "cache_pygates")] - { - new_inst.py_op.take(); - } } in_dag.push_back(py, new_inst)?; } @@ -3245,7 +3612,7 @@ def _format(operand): &var_map, )? }; - self.global_phase = add_global_phase(py, &self.global_phase, &input_dag.global_phase)?; + self.global_phase = self.global_phase.add(py, &input_dag.global_phase); let wire_map_dict = PyDict::new_bound(py); for (source, target) in clbit_wire_map.iter() { @@ -3321,7 +3688,7 @@ def _format(operand): .into(); #[cfg(feature = "cache_pygates")] { - new_inst.py_op = new_op.unbind().into(); + *new_inst.py_op.borrow_mut() = Some(new_op.unbind()); } } } @@ -3352,10 +3719,6 @@ def _format(operand): })) } } - #[cfg(feature = "cache_pygates")] - { - new_inst.py_op.take(); - } match new_inst.op.view() { OperationRef::Instruction(py_inst) => { py_inst @@ -3368,7 +3731,12 @@ def _format(operand): OperationRef::Operation(py_op) => { py_op.operation.setattr(py, "condition", new_condition)?; } - OperationRef::Standard(_) => {} + OperationRef::Standard(_) => { + #[cfg(feature = "cache_pygates")] + { + *new_inst.py_op.borrow_mut() = None + } + } } } } @@ -3410,10 +3778,11 @@ def _format(operand): &mut self, node: &Bound, op: &Bound, - inplace: bool, + // Unused in Rust space because `DAGOpNode` is no longer the data-at-rest format. + #[allow(unused_variables)] inplace: bool, propagate_condition: bool, ) -> PyResult> { - let mut node: PyRefMut = match node.downcast() { + let node: PyRefMut = match node.downcast() { Ok(node) => node.borrow_mut(), Err(_) => return Err(DAGCircuitError::new_err("Only DAGOpNodes can be replaced.")), }; @@ -3537,18 +3906,6 @@ def _format(operand): ))); } - if inplace { - node.instruction.operation = new_op.operation.clone(); - node.instruction.params = new_op.params.clone(); - node.instruction.extra_attrs = extra_attrs.clone(); - #[cfg(feature = "cache_pygates")] - { - node.instruction.py_op = py_op_cache - .as_ref() - .map(|ob| OnceCell::from(ob.clone_ref(py))) - .unwrap_or_default(); - } - } // Clone op data, as it will be moved into the PackedInstruction let new_weight = NodeType::Operation(PackedInstruction { op: new_op.operation.clone(), @@ -3557,22 +3914,19 @@ def _format(operand): params: (!new_op.params.is_empty()).then(|| new_op.params.into()), extra_attrs, #[cfg(feature = "cache_pygates")] - py_op: py_op_cache.map(OnceCell::from).unwrap_or_default(), + py_op: RefCell::new(py_op_cache), }); + let node_index = node.as_ref().node.unwrap(); if let Some(weight) = self.dag.node_weight_mut(node_index) { *weight = new_weight; } // Update self.op_names - self.decrement_op(old_packed.op.name()); - self.increment_op(new_op.operation.name()); + self.decrement_op(old_packed.op.name().to_string()); + self.increment_op(new_op.operation.name().to_string()); - if inplace { - Ok(node.into_py(py)) - } else { - self.get_node(py, node_index) - } + self.get_node(py, node_index) } /// Decompose the circuit into sets of qubits with no gates connecting them. @@ -3596,7 +3950,7 @@ def _format(operand): remove_idle_qubits: bool, vars_mode: &str, ) -> PyResult> { - let connected_components = rustworkx_core::connectivity::connected_components(&self.dag); + let connected_components = core_connected_components(&self.dag); let dags = PyList::empty_bound(py); for comp_nodes in connected_components.iter() { @@ -3613,34 +3967,34 @@ def _format(operand): match self.dag.node_weight(*node) { Some(w) => match w { NodeType::ClbitIn(b) => { - let clbit_in = new_dag.clbit_io_map[b.0 as usize][0]; - node_map.insert(*node, clbit_in); + let clbit_in = new_dag.clbit_input_map.get(b).unwrap(); + node_map.insert(*node, *clbit_in); } NodeType::ClbitOut(b) => { - let clbit_out = new_dag.clbit_io_map[b.0 as usize][1]; - node_map.insert(*node, clbit_out); + let clbit_out = new_dag.clbit_output_map.get(b).unwrap(); + node_map.insert(*node, *clbit_out); } NodeType::QubitIn(q) => { - let qbit_in = new_dag.qubit_io_map[q.0 as usize][0]; - node_map.insert(*node, qbit_in); + let qbit_in = new_dag.qubit_input_map.get(q).unwrap(); + node_map.insert(*node, *qbit_in); non_classical = true; } NodeType::QubitOut(q) => { - let qbit_out = new_dag.qubit_io_map[q.0 as usize][1]; - node_map.insert(*node, qbit_out); + let qbit_out = new_dag.qubit_output_map.get(q).unwrap(); + node_map.insert(*node, *qbit_out); non_classical = true; } NodeType::VarIn(v) => { - let var_in = new_dag.var_input_map.get(py, v).unwrap(); + let var_in = new_dag.var_input_map.get(v).unwrap(); node_map.insert(*node, var_in); } NodeType::VarOut(v) => { - let var_out = new_dag.var_output_map.get(py, v).unwrap(); + let var_out = new_dag.var_output_map.get(v).unwrap(); node_map.insert(*node, var_out); } NodeType::Operation(pi) => { let new_node = new_dag.dag.add_node(NodeType::Operation(pi.clone())); - new_dag.increment_op(pi.op.name()); + new_dag.increment_op(pi.op.name().to_string()); node_map.insert(*node, new_node); non_classical = true; } @@ -3651,52 +4005,35 @@ def _format(operand): if !non_classical { continue; } - let node_filter = |node: NodeIndex| -> bool { node_map.contains_key(&node) }; - let filtered = NodeFiltered(&self.dag, node_filter); + // Handling the edges in the new dag + for node in comp_nodes { + // Since the nodes comprise an SCC, it's enough to just look at the (e.g.) outgoing edges + let outgoing_edges = self.dag.edges_directed(*node, Direction::Outgoing); - // Remove the edges added by copy_empty_like (as idle wires) to avoid duplication - new_dag.dag.clear_edges(); - for edge in filtered.edge_references() { - let new_source = node_map[&edge.source()]; - let new_target = node_map[&edge.target()]; - new_dag - .dag - .add_edge(new_source, new_target, edge.weight().clone()); - } - // Add back any edges for idle wires - for (qubit, [in_node, out_node]) in new_dag - .qubit_io_map - .iter() - .enumerate() - .map(|(idx, indices)| (Qubit(idx as u32), indices)) - { - if new_dag.dag.edges(*in_node).next().is_none() { - new_dag - .dag - .add_edge(*in_node, *out_node, Wire::Qubit(qubit)); - } - } - for (clbit, [in_node, out_node]) in new_dag - .clbit_io_map - .iter() - .enumerate() - .map(|(idx, indices)| (Clbit(idx as u32), indices)) - { - if new_dag.dag.edges(*in_node).next().is_none() { - new_dag - .dag - .add_edge(*in_node, *out_node, Wire::Clbit(clbit)); + // Remove the edges added by copy_empty_like (as idle wires) to avoid duplication + if let Some(NodeType::QubitIn(_)) | Some(NodeType::ClbitIn(_)) = + self.dag.node_weight(*node) + { + let edges: Vec = new_dag.dag.edges(*node).map(|e| e.id()).collect(); + for edge in edges { + new_dag.dag.remove_edge(edge); + } } - } - for (var, in_node) in new_dag.var_input_map.iter(py) { - if new_dag.dag.edges(in_node).next().is_none() { - let out_node = new_dag.var_output_map.get(py, &var).unwrap(); + + for e in outgoing_edges { + let (source, target) = (e.source(), e.target()); + let edge_weight = e.weight(); + let (source_new, target_new) = ( + node_map.get(&source).unwrap(), + node_map.get(&target).unwrap(), + ); new_dag .dag - .add_edge(in_node, out_node, Wire::Var(var.clone_ref(py))); + .add_edge(*source_new, *target_new, edge_weight.clone()); } } + if remove_idle_qubits { let idle_wires: Vec> = new_dag .idle_wires(py, None)? @@ -3893,28 +4230,17 @@ def _format(operand): include_directives: bool, ) -> PyResult>> { let mut nodes = Vec::new(); - let filter_is_nonstandard = if let Some(op) = op { - op.getattr(intern!(py, "_standard_gate")).ok().is_none() - } else { - true - }; for (node, weight) in self.dag.node_references() { if let NodeType::Operation(packed) = &weight { if !include_directives && packed.op.directive() { continue; } if let Some(op_type) = op { - // This middle catch is to avoid Python-space operation creation for most uses of - // `op`; we're usually just looking for control-flow ops, and standard gates - // aren't control-flow ops. - if !(filter_is_nonstandard && packed.op.try_standard_gate().is_some()) - && packed.op.py_op_is_instance(op_type)? - { - nodes.push(self.unpack_into(py, node, weight)?); + if !packed.op.py_op_is_instance(op_type)? { + continue; } - } else { - nodes.push(self.unpack_into(py, node, weight)?); } + nodes.push(self.unpack_into(py, node, weight)?); } } Ok(nodes) @@ -4120,8 +4446,8 @@ def _format(operand): fn classical_predecessors(&self, py: Python, node: &DAGNode) -> PyResult> { let edges = self.dag.edges_directed(node.node.unwrap(), Incoming); let filtered = edges.filter_map(|e| match e.weight() { - Wire::Qubit(_) => None, - _ => Some(e.source()), + Wire::Clbit(_) => Some(e.source()), + _ => None, }); let predecessors: PyResult> = filtered.unique().map(|i| self.get_node(py, i)).collect(); @@ -4530,8 +4856,6 @@ def _format(operand): self.qubits.find(wire).map(Wire::Qubit) } else if wire.is_instance(imports::CLBIT.get_bound(py))? { self.clbits.find(wire).map(Wire::Clbit) - } else if self.var_input_map.contains_key(py, &wire.clone().unbind()) { - Some(Wire::Var(wire.clone().unbind())) } else { None } @@ -4543,7 +4867,7 @@ def _format(operand): })?; let nodes = self - .nodes_on_wire(py, &wire, only_ops) + .nodes_on_wire(&wire, only_ops) .into_iter() .map(|n| self.get_node(py, n)) .collect::>>()?; @@ -4565,7 +4889,7 @@ def _format(operand): if !recurse || !CONTROL_FLOW_OP_NAMES .iter() - .any(|x| self.op_names.contains_key(*x)) + .any(|x| self.op_names.contains_key(&x.to_string())) { Ok(self.op_names.to_object(py)) } else { @@ -4652,19 +4976,16 @@ def _format(operand): qubit )) })?; - let output_node_index = self - .qubit_io_map - .get(output_qubit.0 as usize) - .map(|x| x[1]) - .ok_or_else(|| { - DAGCircuitError::new_err(format!( - "The given qubit {:?} is not present in qubit_output_map", - qubit - )) - })?; + let output_node_index = self.qubit_output_map.get(&output_qubit).ok_or_else(|| { + DAGCircuitError::new_err(format!( + "The given qubit {:?} is not present in qubit_output_map", + qubit + )) + })?; let mut qubits_in_cone: HashSet<&Qubit> = HashSet::from([&output_qubit]); - let mut queue: VecDeque = self.quantum_predecessors(output_node_index).collect(); + let mut queue: VecDeque = + self.quantum_predecessors(*output_node_index).collect(); // The processed_non_directive_nodes stores the set of processed non-directive nodes. // This is an optimization to avoid considering the same non-directive node multiple @@ -4723,15 +5044,15 @@ def _format(operand): } /// Return a dictionary of circuit properties. - fn properties(&self, py: Python) -> PyResult> { + fn properties(&self, py: Python) -> PyResult> { Ok(HashMap::from_iter([ - ("size", self.size(py, false)?.into_py(py)), - ("depth", self.depth(py, false)?.into_py(py)), - ("width", self.width().into_py(py)), - ("qubits", self.num_qubits().into_py(py)), - ("bits", self.num_clbits().into_py(py)), - ("factors", self.num_tensor_factors().into_py(py)), - ("operations", self.count_ops(py, true)?), + ("size", self.size(py, false)?), + ("depth", self.depth(py, false)?), + ("width", self.width()), + ("qubits", self.num_qubits()), + ("bits", self.num_clbits()), + ("factors", self.num_tensor_factors()), + // ("operations", self.count_ops(true)?), ])) } @@ -4778,10 +5099,6 @@ def _format(operand): Ok(PyString::new_bound(py, std::str::from_utf8(&buffer)?)) } - /// Add an input variable to the circuit. - /// - /// Args: - /// var: the variable to add. fn add_input_var(&mut self, py: Python, var: &Bound) -> PyResult<()> { if !self.vars_by_type[DAGVarType::Capture as usize] .bind(py) @@ -4794,10 +5111,6 @@ def _format(operand): self.add_var(py, var, DAGVarType::Input) } - /// Add a captured variable to the circuit. - /// - /// Args: - /// var: the variable to add. fn add_captured_var(&mut self, py: Python, var: &Bound) -> PyResult<()> { if !self.vars_by_type[DAGVarType::Input as usize] .bind(py) @@ -4810,27 +5123,20 @@ def _format(operand): self.add_var(py, var, DAGVarType::Capture) } - /// Add a declared local variable to the circuit. - /// - /// Args: - /// var: the variable to add. - fn add_declared_var(&mut self, py: Python, var: &Bound) -> PyResult<()> { - self.add_var(py, var, DAGVarType::Declare) + fn add_declared_var(&mut self, var: &Bound) -> PyResult<()> { + self.add_var(var.py(), var, DAGVarType::Declare) } - /// Total number of classical variables tracked by the circuit. #[getter] fn num_vars(&self) -> usize { self.vars_info.len() } - /// Number of input classical variables tracked by the circuit. #[getter] fn num_input_vars(&self, py: Python) -> usize { self.vars_by_type[DAGVarType::Input as usize].bind(py).len() } - /// Number of captured classical variables tracked by the circuit. #[getter] fn num_captured_vars(&self, py: Python) -> usize { self.vars_by_type[DAGVarType::Capture as usize] @@ -4838,7 +5144,6 @@ def _format(operand): .len() } - /// Number of declared local classical variables tracked by the circuit. #[getter] fn num_declared_vars(&self, py: Python) -> usize { self.vars_by_type[DAGVarType::Declare as usize] @@ -4846,10 +5151,6 @@ def _format(operand): .len() } - /// Is this realtime variable in the DAG? - /// - /// Args: - /// var: the variable or name to check. fn has_var(&self, var: &Bound) -> PyResult { match var.extract::() { Ok(name) => Ok(self.vars_info.contains_key(&name)), @@ -4864,7 +5165,6 @@ def _format(operand): } } - /// Iterable over the input classical variables tracked by the circuit. fn iter_input_vars(&self, py: Python) -> PyResult> { Ok(self.vars_by_type[DAGVarType::Input as usize] .bind(py) @@ -4874,7 +5174,6 @@ def _format(operand): .unbind()) } - /// Iterable over the captured classical variables tracked by the circuit. fn iter_captured_vars(&self, py: Python) -> PyResult> { Ok(self.vars_by_type[DAGVarType::Capture as usize] .bind(py) @@ -4884,7 +5183,6 @@ def _format(operand): .unbind()) } - /// Iterable over the declared classical variables tracked by the circuit. fn iter_declared_vars(&self, py: Python) -> PyResult> { Ok(self.vars_by_type[DAGVarType::Declare as usize] .bind(py) @@ -4894,7 +5192,6 @@ def _format(operand): .unbind()) } - /// Iterable over all the classical variables tracked by the circuit. fn iter_vars(&self, py: Python) -> PyResult> { let out_set = PySet::empty_bound(py)?; for var_type_set in &self.vars_by_type { @@ -5003,7 +5300,7 @@ def _format(operand): old_index: usize, ) -> PyResult<()> { if let NodeType::Operation(inst) = self.pack_into(py, node)? { - self.increment_op(inst.op.name()); + self.increment_op(inst.op.name().to_string()); let new_index = self.dag.add_node(NodeType::Operation(inst)); let old_index: NodeIndex = NodeIndex::new(old_index); let (parent_index, edge_index, weight) = self @@ -5064,9 +5361,7 @@ impl DAGCircuit { match node { NodeType::Operation(inst) => Ok(inst.op.num_qubits() == 1 && inst.op.num_clbits() == 0 - && !inst.is_parameterized() - && (inst.op.try_standard_gate().is_some() - || inst.op.matrix(inst.params_view()).is_some()) + && inst.op.matrix(inst.params_view()).is_some() && inst.condition().is_none()), _ => Ok(false), } @@ -5107,27 +5402,23 @@ impl DAGCircuit { rustworkx_core::dag_algo::collect_bicolor_runs(&self.dag, filter_fn, color_fn).unwrap() } - fn increment_op(&mut self, op: &str) { - match self.op_names.get_mut(op) { - Some(count) => { - *count += 1; - } - None => { - self.op_names.insert(op.to_string(), 1); - } - } + fn increment_op(&mut self, op: String) { + self.op_names + .entry(op) + .and_modify(|count| *count += 1) + .or_insert(1); } - fn decrement_op(&mut self, op: &str) { - match self.op_names.get_mut(op) { - Some(count) => { - if *count > 1 { - *count -= 1; + fn decrement_op(&mut self, op: String) { + match self.op_names.entry(op) { + Entry::Occupied(mut o) => { + if *o.get() > 1usize { + *o.get_mut() -= 1; } else { - self.op_names.swap_remove(op); + o.swap_remove(); } } - None => panic!("Cannot decrement something not added!"), + _ => panic!("Cannot decrement something not added!"), } } @@ -5177,7 +5468,7 @@ impl DAGCircuit { } }; - self.increment_op(op_name); + self.increment_op(op_name.to_string()); let qubits_id = instr.qubits; let new_node = self.dag.add_node(NodeType::Operation(instr)); @@ -5188,16 +5479,16 @@ impl DAGCircuit { .qargs_cache .intern(qubits_id) .iter() - .map(|q| self.qubit_io_map.get(q.0 as usize).map(|x| x[1]).unwrap()) + .map(|q| self.qubit_output_map.get(q).copied().unwrap()) .chain( all_cbits .iter() - .map(|c| self.clbit_io_map.get(c.0 as usize).map(|x| x[1]).unwrap()), + .map(|c| self.clbit_output_map.get(c).copied().unwrap()), ) .chain( vars.iter() .flatten() - .map(|v| self.var_output_map.get(py, v).unwrap()), + .map(|v| self.var_output_map.get(v).unwrap()), ) .collect(); @@ -5243,7 +5534,7 @@ impl DAGCircuit { } }; - self.increment_op(op_name); + self.increment_op(op_name.to_string()); let qubits_id = inst.qubits; let new_node = self.dag.add_node(NodeType::Operation(inst)); @@ -5254,12 +5545,16 @@ impl DAGCircuit { .qargs_cache .intern(qubits_id) .iter() - .map(|q| self.qubit_io_map[q.0 as usize][0]) - .chain(all_cbits.iter().map(|c| self.clbit_io_map[c.0 as usize][0])) + .map(|q| self.qubit_input_map.get(q).copied().unwrap()) + .chain( + all_cbits + .iter() + .map(|c| self.clbit_input_map.get(c).copied().unwrap()), + ) .collect(); if let Some(vars) = vars { for var in vars { - input_nodes.push(self.var_input_map.get(py, &var).unwrap()); + input_nodes.push(self.var_input_map.get(&var).unwrap()); } } @@ -5340,19 +5635,13 @@ impl DAGCircuit { ) } - fn is_wire_idle(&self, py: Python, wire: &Wire) -> PyResult { + fn is_wire_idle(&self, wire: &Wire) -> PyResult { let (input_node, output_node) = match wire { - Wire::Qubit(qubit) => ( - self.qubit_io_map[qubit.0 as usize][0], - self.qubit_io_map[qubit.0 as usize][1], - ), - Wire::Clbit(clbit) => ( - self.clbit_io_map[clbit.0 as usize][0], - self.clbit_io_map[clbit.0 as usize][1], - ), + Wire::Qubit(qubit) => (self.qubit_input_map[qubit], self.qubit_output_map[qubit]), + Wire::Clbit(clbit) => (self.clbit_input_map[clbit], self.clbit_output_map[clbit]), Wire::Var(var) => ( - self.var_input_map.get(py, var).unwrap(), - self.var_output_map.get(py, var).unwrap(), + self.var_input_map.get(var).unwrap(), + self.var_output_map.get(var).unwrap(), ), }; @@ -5371,13 +5660,21 @@ impl DAGCircuit { } fn may_have_additional_wires(&self, py: Python, instr: &PackedInstruction) -> bool { - if instr.condition().is_some() { + let has_condition = match instr.condition() { + None => false, + Some(condition) => !condition.bind(py).is_none(), + }; + + if has_condition { return true; } let OperationRef::Instruction(inst) = instr.op.view() else { return false; }; - inst.control_flow() + inst.instruction + .bind(py) + .is_instance(imports::CONTROL_FLOW_OP.get_bound(py)) + .unwrap() || inst .instruction .bind(py) @@ -5410,6 +5707,7 @@ impl DAGCircuit { Ok((clbits, vars)) }; + // let mut bits = Vec::new(); let mut clbits = Vec::new(); let mut vars = Vec::new(); if let Some(condition) = condition { @@ -5438,7 +5736,7 @@ impl DAGCircuit { if let OperationRef::Instruction(inst) = op { let op = inst.instruction.bind(py); - if inst.control_flow() { + if op.is_instance(imports::CONTROL_FLOW_OP.get_bound(py))? { for var in op.call_method0("iter_captured_vars")?.iter()? { vars.push(var?.unbind()) } @@ -5489,39 +5787,47 @@ impl DAGCircuit { /// /// Raises: /// DAGCircuitError: if trying to add duplicate wire - fn add_wire(&mut self, py: Python, wire: Wire) -> PyResult<()> { + fn add_wire(&mut self, wire: Wire) -> PyResult<()> { let (in_node, out_node) = match wire { Wire::Qubit(qubit) => { - if (qubit.0 as usize) >= self.qubit_io_map.len() { - let input_node = self.dag.add_node(NodeType::QubitIn(qubit)); - let output_node = self.dag.add_node(NodeType::QubitOut(qubit)); - self.qubit_io_map.push([input_node, output_node]); - Ok((input_node, output_node)) - } else { - Err(DAGCircuitError::new_err("qubit wire already exists!")) + match ( + self.qubit_input_map.entry(qubit), + self.qubit_output_map.entry(qubit), + ) { + (indexmap::map::Entry::Vacant(input), indexmap::map::Entry::Vacant(output)) => { + Ok(( + *input.insert(self.dag.add_node(NodeType::QubitIn(qubit))), + *output.insert(self.dag.add_node(NodeType::QubitOut(qubit))), + )) + } + (_, _) => Err(DAGCircuitError::new_err("qubit wire already exists!")), } } Wire::Clbit(clbit) => { - if (clbit.0 as usize) >= self.clbit_io_map.len() { - let input_node = self.dag.add_node(NodeType::ClbitIn(clbit)); - let output_node = self.dag.add_node(NodeType::ClbitOut(clbit)); - self.clbit_io_map.push([input_node, output_node]); - Ok((input_node, output_node)) - } else { - Err(DAGCircuitError::new_err("classical wire already exists!")) + match ( + self.clbit_input_map.entry(clbit), + self.clbit_output_map.entry(clbit), + ) { + (indexmap::map::Entry::Vacant(input), indexmap::map::Entry::Vacant(output)) => { + Ok(( + *input.insert(self.dag.add_node(NodeType::ClbitIn(clbit))), + *output.insert(self.dag.add_node(NodeType::ClbitOut(clbit))), + )) + } + (_, _) => Err(DAGCircuitError::new_err("classical wire already exists!")), } } Wire::Var(ref var) => { - if self.var_input_map.contains_key(py, var) - || self.var_output_map.contains_key(py, var) - { + if self.var_input_map.contains_key(var) || self.var_output_map.contains_key(var) { return Err(DAGCircuitError::new_err("var wire already exists!")); } - let in_node = self.dag.add_node(NodeType::VarIn(var.clone_ref(py))); - let out_node = self.dag.add_node(NodeType::VarOut(var.clone_ref(py))); - self.var_input_map.insert(py, var.clone_ref(py), in_node); - self.var_output_map.insert(py, var.clone_ref(py), out_node); - Ok((in_node, out_node)) + Python::with_gil(|py| { + let in_node = self.dag.add_node(NodeType::VarIn(var.clone_ref(py))); + let out_node = self.dag.add_node(NodeType::VarOut(var.clone_ref(py))); + self.var_input_map.insert(var.clone_ref(py), in_node); + self.var_output_map.insert(var.clone_ref(py), out_node); + Ok((in_node, out_node)) + }) } }?; @@ -5532,12 +5838,12 @@ impl DAGCircuit { /// Get the nodes on the given wire. /// /// Note: result is empty if the wire is not in the DAG. - fn nodes_on_wire(&self, py: Python, wire: &Wire, only_ops: bool) -> Vec { + fn nodes_on_wire(&self, wire: &Wire, only_ops: bool) -> Vec { let mut nodes = Vec::new(); let mut current_node = match wire { - Wire::Qubit(qubit) => self.qubit_io_map.get(qubit.0 as usize).map(|x| x[0]), - Wire::Clbit(clbit) => self.clbit_io_map.get(clbit.0 as usize).map(|x| x[0]), - Wire::Var(var) => self.var_input_map.get(py, var), + Wire::Qubit(qubit) => self.qubit_input_map.get(qubit).copied(), + Wire::Clbit(clbit) => self.clbit_input_map.get(clbit).copied(), + Wire::Var(var) => self.var_input_map.get(var), }; while let Some(node) = current_node { @@ -5562,17 +5868,23 @@ impl DAGCircuit { nodes } - fn remove_idle_wire(&mut self, py: Python, wire: Wire) -> PyResult<()> { - let [in_node, out_node] = match wire { - Wire::Qubit(qubit) => self.qubit_io_map[qubit.0 as usize], - Wire::Clbit(clbit) => self.clbit_io_map[clbit.0 as usize], - Wire::Var(var) => [ - self.var_input_map.remove(py, &var).unwrap(), - self.var_output_map.remove(py, &var).unwrap(), - ], + fn remove_idle_wire(&mut self, wire: Wire) -> PyResult<()> { + let (in_node, out_node) = match wire { + Wire::Qubit(qubit) => ( + self.qubit_input_map.shift_remove(&qubit), + self.qubit_output_map.shift_remove(&qubit), + ), + Wire::Clbit(clbit) => ( + self.clbit_input_map.shift_remove(&clbit), + self.clbit_output_map.shift_remove(&clbit), + ), + Wire::Var(var) => ( + self.var_input_map.remove(&var), + self.var_output_map.remove(&var), + ), }; - self.dag.remove_node(in_node); - self.dag.remove_node(out_node); + self.dag.remove_node(in_node.unwrap()); + self.dag.remove_node(out_node.unwrap()); Ok(()) } @@ -5588,7 +5900,7 @@ impl DAGCircuit { }, )?, )?; - self.add_wire(py, Wire::Qubit(qubit))?; + self.add_wire(Wire::Qubit(qubit))?; Ok(qubit) } @@ -5604,11 +5916,11 @@ impl DAGCircuit { }, )?, )?; - self.add_wire(py, Wire::Clbit(clbit))?; + self.add_wire(Wire::Clbit(clbit))?; Ok(clbit) } - pub fn get_node(&self, py: Python, node: NodeIndex) -> PyResult> { + pub(crate) fn get_node(&self, py: Python, node: NodeIndex) -> PyResult> { self.unpack_into(py, node, self.dag.node_weight(node).unwrap()) } @@ -5638,7 +5950,7 @@ impl DAGCircuit { match self.dag.remove_node(index) { Some(NodeType::Operation(packed)) => { - let op_name = packed.op.name(); + let op_name = packed.op.name().to_string(); self.decrement_op(op_name); } _ => panic!("Must be called with valid operation node!"), @@ -5822,8 +6134,8 @@ impl DAGCircuit { /// Returns an iterator over a list layers of the `DAGCircuit``. pub fn multigraph_layers(&self, py: Python) -> impl Iterator> + '_ { - let mut first_layer: Vec<_> = self.qubit_io_map.iter().map(|x| x[0]).collect(); - first_layer.extend(self.clbit_io_map.iter().map(|x| x[0])); + let mut first_layer: Vec<_> = self.qubit_input_map.values().copied().collect(); + first_layer.extend(self.clbit_input_map.values().copied()); first_layer.extend(self.var_input_map.values(py)); // A DAG is by definition acyclical, therefore unwrapping the layer should never fail. layers(&self.dag, first_layer).map(|layer| match layer { @@ -5861,10 +6173,12 @@ impl DAGCircuit { node.index() ))); } + self.global_phase.add(py, &other.global_phase); // Add wire from pred to succ if no ops on mapped wire on ``other`` for (in_dag_wire, self_wire) in qubit_map.iter() { - let [input_node, out_node] = other.qubit_io_map[in_dag_wire.0 as usize]; + let input_node = other.qubit_input_map[in_dag_wire]; + let out_node = other.qubit_output_map[in_dag_wire]; if other.dag.find_edge(input_node, out_node).is_some() { let pred = self .dag @@ -5893,7 +6207,8 @@ impl DAGCircuit { } } for (in_dag_wire, self_wire) in clbit_map.iter() { - let [input_node, out_node] = other.clbit_io_map[in_dag_wire.0 as usize]; + let input_node = other.clbit_input_map[in_dag_wire]; + let out_node = other.clbit_output_map[in_dag_wire]; if other.dag.find_edge(input_node, out_node).is_some() { let pred = self .dag @@ -5967,7 +6282,7 @@ impl DAGCircuit { .collect(); new_inst.qubits = Interner::intern(&mut self.qargs_cache, new_qubit_indices)?; new_inst.clbits = Interner::intern(&mut self.cargs_cache, new_clbit_indices)?; - self.increment_op(new_inst.op.name()); + self.increment_op(new_inst.op.name().to_string()); } let new_index = self.dag.add_node(new_node); out_map.insert(old_index, new_index); @@ -5977,7 +6292,7 @@ impl DAGCircuit { if out_map.is_empty() { match self.dag.remove_node(node) { Some(NodeType::Operation(packed)) => { - let op_name = packed.op.name(); + let op_name = packed.op.name().to_string(); self.decrement_op(op_name); } _ => unreachable!("Must be called with valid operation node!"), @@ -6008,16 +6323,16 @@ impl DAGCircuit { for (source, _target, weight) in edges { let wire_input_id = match weight { Wire::Qubit(qubit) => other - .qubit_io_map - .get(reverse_qubit_map[&qubit].0 as usize) - .map(|x| x[0]), + .qubit_input_map + .get(&reverse_qubit_map[&qubit]) + .copied(), Wire::Clbit(clbit) => other - .clbit_io_map - .get(reverse_clbit_map[&clbit].0 as usize) - .map(|x| x[0]), + .clbit_input_map + .get(&reverse_clbit_map[&clbit]) + .copied(), Wire::Var(ref var) => { let index = &reverse_var_map.get_item(var)?.unwrap().unbind(); - other.var_input_map.get(py, index) + other.var_input_map.get(index) } }; let old_index = @@ -6044,16 +6359,16 @@ impl DAGCircuit { for (_source, target, weight) in edges { let wire_output_id = match weight { Wire::Qubit(qubit) => other - .qubit_io_map - .get(reverse_qubit_map[&qubit].0 as usize) - .map(|x| x[1]), + .qubit_output_map + .get(&reverse_qubit_map[&qubit]) + .copied(), Wire::Clbit(clbit) => other - .clbit_io_map - .get(reverse_clbit_map[&clbit].0 as usize) - .map(|x| x[1]), + .clbit_output_map + .get(&reverse_clbit_map[&clbit]) + .copied(), Wire::Var(ref var) => { let index = &reverse_var_map.get_item(var)?.unwrap().unbind(); - other.var_output_map.get(py, index) + other.var_output_map.get(index) } }; let old_index = @@ -6074,7 +6389,7 @@ impl DAGCircuit { } // Remove node if let NodeType::Operation(inst) = &self.dag[node] { - self.decrement_op(inst.op.name().to_string().as_str()); + self.decrement_op(inst.op.name().to_string()); } self.dag.remove_node(node); Ok(out_map) @@ -6104,10 +6419,8 @@ impl DAGCircuit { let out_index = self.dag.add_node(out_node); self.dag .add_edge(in_index, out_index, Wire::Var(var.clone().unbind())); - self.var_input_map - .insert(py, var.clone().unbind(), in_index); - self.var_output_map - .insert(py, var.clone().unbind(), out_index); + self.var_input_map.insert(var.clone().unbind(), in_index); + self.var_output_map.insert(var.clone().unbind(), out_index); self.vars_by_type[type_ as usize] .bind(py) .add(var.clone().unbind())?; @@ -6122,75 +6435,6 @@ impl DAGCircuit { ); Ok(()) } - - fn check_op_addition(&self, py: Python, inst: &PackedInstruction) -> PyResult<()> { - if let Some(condition) = inst.condition() { - self._check_condition(py, inst.op.name(), condition.bind(py))?; - } - - for b in self.qargs_cache.intern(inst.qubits) { - if self.qubit_io_map.len() - 1 < b.0 as usize { - return Err(DAGCircuitError::new_err(format!( - "qubit {} not found in output map", - self.qubits.get(*b).unwrap() - ))); - } - } - - for b in self.cargs_cache.intern(inst.clbits) { - if !self.clbit_io_map.len() - 1 < b.0 as usize { - return Err(DAGCircuitError::new_err(format!( - "clbit {} not found in output map", - self.clbits.get(*b).unwrap() - ))); - } - } - - if self.may_have_additional_wires(py, inst) { - let (clbits, vars) = self.additional_wires(py, inst.op.view(), inst.condition())?; - for b in clbits { - if !self.clbit_io_map.len() - 1 < b.0 as usize { - return Err(DAGCircuitError::new_err(format!( - "clbit {} not found in output map", - self.clbits.get(b).unwrap() - ))); - } - } - for v in vars { - if !self.var_output_map.contains_key(py, &v) { - return Err(DAGCircuitError::new_err(format!( - "var {} not found in output map", - v - ))); - } - } - } - Ok(()) - } -} - -/// Add to global phase. Global phase can only be Float or ParameterExpression so this -/// does not handle the full possibility of parameter values. -fn add_global_phase(py: Python, phase: &Param, other: &Param) -> PyResult { - Ok(match [phase, other] { - [Param::Float(a), Param::Float(b)] => Param::Float(a + b), - [Param::Float(a), Param::ParameterExpression(b)] => Param::ParameterExpression( - b.clone_ref(py) - .call_method1(py, intern!(py, "__radd__"), (*a,))?, - ), - [Param::ParameterExpression(a), Param::Float(b)] => Param::ParameterExpression( - a.clone_ref(py) - .call_method1(py, intern!(py, "__add__"), (*b,))?, - ), - [Param::ParameterExpression(a), Param::ParameterExpression(b)] => { - Param::ParameterExpression(a.clone_ref(py).call_method1( - py, - intern!(py, "__add__"), - (b,), - )?) - } - _ => panic!("Invalid global phase"), - }) } type SortKeyType<'a> = (&'a [Qubit], &'a [Clbit]); diff --git a/crates/circuit/src/dag_node.rs b/crates/circuit/src/dag_node.rs index ccae7a8c5d82..e435cd9b9343 100644 --- a/crates/circuit/src/dag_node.rs +++ b/crates/circuit/src/dag_node.rs @@ -11,7 +11,7 @@ // that they have been altered from the originals. #[cfg(feature = "cache_pygates")] -use std::cell::OnceCell; +use std::cell::RefCell; use std::hash::Hasher; use crate::circuit_instruction::{CircuitInstruction, OperationFromPython}; @@ -24,7 +24,6 @@ use approx::relative_eq; use rustworkx_core::petgraph::stable_graph::NodeIndex; use numpy::IntoPyArray; -use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::{PyString, PyTuple}; use pyo3::{intern, IntoPy, PyObject, PyResult, ToPyObject}; @@ -49,32 +48,24 @@ impl DAGNode { impl DAGNode { #[new] #[pyo3(signature=(nid=-1))] - fn py_new(nid: isize) -> PyResult { - Ok(DAGNode { + fn py_new(nid: isize) -> Self { + DAGNode { node: match nid { -1 => None, - nid => { - let index: usize = match nid.try_into() { - Ok(index) => index, - Err(_) => { - return Err(PyValueError::new_err( - "Invalid node index, must be -1 or a non-negative integer", - )) - } - }; - Some(NodeIndex::new(index)) - } + nid => Some(NodeIndex::new(nid.try_into().unwrap())), }, - }) + } } - #[getter(_node_id)] - fn get_py_node_id(&self) -> isize { + #[allow(non_snake_case)] + #[getter] + fn get__node_id(&self) -> isize { self.py_nid() } - #[setter(_node_id)] - fn set_py_node_id(&mut self, nid: isize) { + #[allow(non_snake_case)] + #[setter] + fn set__node_id(&mut self, nid: isize) { self.node = match nid { -1 => None, nid => Some(NodeIndex::new(nid.try_into().unwrap())), @@ -136,7 +127,7 @@ impl DAGOpNode { params: py_op.params, extra_attrs: py_op.extra_attrs, #[cfg(feature = "cache_pygates")] - py_op: op.unbind().into(), + py_op: RefCell::new(Some(op.unbind())), }; Py::new( @@ -160,13 +151,6 @@ impl DAGOpNode { } fn __eq__(slf: PyRef, py: Python, other: &Bound) -> PyResult { - // This check is more restrictive by design as it's intended to replace - // object identitity for set/dict membership and not be a semantic equivalence - // check. We have an implementation of that as part of `DAGCircuit.__eq__` and - // this method is specifically to ensure nodes are the same. This means things - // like parameter equality are stricter to reject things like - // Param::Float(0.1) == Param::ParameterExpression(0.1) (if the expression was - // a python parameter equivalent to a bound value). let Ok(other) = other.downcast::() else { return Ok(false); }; @@ -185,33 +169,20 @@ impl DAGOpNode { return Ok(false); } let params_eq = if slf.instruction.operation.try_standard_gate().is_some() { - let mut params_eq = true; - for (a, b) in slf - .instruction + slf.instruction .params .iter() .zip(borrowed_other.instruction.params.iter()) - { - let res = match [a, b] { + .all(|(a, b)| match [a, b] { [Param::Float(float_a), Param::Float(float_b)] => { relative_eq!(float_a, float_b, max_relative = 1e-10) } [Param::ParameterExpression(param_a), Param::ParameterExpression(param_b)] => { - param_a.bind(py).eq(param_b)? + param_a.bind(py).eq(param_b).unwrap() } - [Param::Obj(param_a), Param::Obj(param_b)] => param_a.bind(py).eq(param_b)?, _ => false, - }; - if !res { - params_eq = false; - break; - } - } - params_eq + }) } else { - // We've already evaluated the parameters are equal here via the Python space equality - // check so if we're not comparing standard gates and we've reached this point we know - // the parameters are already equal. true }; @@ -235,7 +206,7 @@ impl DAGOpNode { mut instruction: CircuitInstruction, deepcopy: bool, ) -> PyResult { - let sort_key = instruction.qubits.bind(py).str().unwrap().into(); + let sort_key = py.None(); if deepcopy { instruction.operation = instruction.operation.py_deepcopy(py, None)?; #[cfg(feature = "cache_pygates")] @@ -499,7 +470,7 @@ impl DAGInNode { ( DAGInNode { wire, - sort_key: intern!(py, "[]").clone().into(), + sort_key: py.None(), }, DAGNode { node: Some(node) }, ) @@ -513,7 +484,7 @@ impl DAGInNode { Ok(( DAGInNode { wire, - sort_key: intern!(py, "[]").clone().into(), + sort_key: py.None(), }, DAGNode { node: None }, )) @@ -572,7 +543,7 @@ impl DAGOutNode { ( DAGOutNode { wire, - sort_key: intern!(py, "[]").clone().into(), + sort_key: py.None(), }, DAGNode { node: Some(node) }, ) @@ -586,7 +557,7 @@ impl DAGOutNode { Ok(( DAGOutNode { wire, - sort_key: intern!(py, "[]").clone().into(), + sort_key: py.None(), }, DAGNode { node: None }, )) diff --git a/crates/circuit/src/interner.rs b/crates/circuit/src/interner.rs index e19f56e87a7d..7f24c19fc8d6 100644 --- a/crates/circuit/src/interner.rs +++ b/crates/circuit/src/interner.rs @@ -17,7 +17,7 @@ use hashbrown::HashMap; use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; -#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)] pub struct Index(u32); impl IntoPy for Index { diff --git a/crates/circuit/src/operations.rs b/crates/circuit/src/operations.rs index ccf41d7eefb2..5fd6dfe6b367 100644 --- a/crates/circuit/src/operations.rs +++ b/crates/circuit/src/operations.rs @@ -37,7 +37,31 @@ pub enum Param { } impl Param { - pub fn eq(&self, py: Python, other: &Param) -> PyResult { + pub fn add(&self, py: Python, other: &Param) -> Param { + match [self, other] { + [Self::Float(a), Self::Float(b)] => Param::Float(a + b), + [Self::Float(a), Self::ParameterExpression(b)] => Param::ParameterExpression( + b.clone_ref(py) + .call_method1(py, intern!(py, "__radd__"), (*a,)) + .expect("Parameter expression addition failed"), + ), + [Self::ParameterExpression(a), Self::Float(b)] => Param::ParameterExpression( + a.clone_ref(py) + .call_method1(py, intern!(py, "__add__"), (*b,)) + .expect("Parameter expression addition failed"), + ), + [Self::ParameterExpression(a), Self::ParameterExpression(b)] => { + Param::ParameterExpression( + a.clone_ref(py) + .call_method1(py, intern!(py, "__add__"), (b,)) + .expect("Parameter expression addition failed"), + ) + } + _ => unreachable!(), + } + } + + pub fn eq(&self, other: &Param, py: Python) -> PyResult { match [self, other] { [Self::Float(a), Self::Float(b)] => Ok(a == b), [Self::Float(a), Self::ParameterExpression(b)] => b.bind(py).eq(a), @@ -51,10 +75,10 @@ impl Param { } } - pub fn is_close(&self, py: Python, other: &Param, max_relative: f64) -> PyResult { + pub fn is_close(&self, other: &Param, py: Python, max_relative: f64) -> PyResult { match [self, other] { [Self::Float(a), Self::Float(b)] => Ok(relative_eq!(a, b, max_relative = max_relative)), - _ => self.eq(py, other), + _ => self.eq(other, py), } } } diff --git a/qiskit/converters/dag_to_circuit.py b/qiskit/converters/dag_to_circuit.py index 47adee456380..d65d1da031f3 100644 --- a/qiskit/converters/dag_to_circuit.py +++ b/qiskit/converters/dag_to_circuit.py @@ -66,7 +66,7 @@ def dag_to_circuit(dag, copy_operations=True): ) for var in dag.iter_declared_vars(): circuit.add_uninitialized_var(var) - circuit.metadata = dag.metadata + circuit.metadata = dict(dag.metadata) circuit.calibrations = dag.calibrations for node in dag.topological_op_nodes(): diff --git a/qiskit/transpiler/passes/routing/sabre_swap.py b/qiskit/transpiler/passes/routing/sabre_swap.py index 9edd1ceee445..ea87359bc06f 100644 --- a/qiskit/transpiler/passes/routing/sabre_swap.py +++ b/qiskit/transpiler/passes/routing/sabre_swap.py @@ -388,7 +388,7 @@ def recurse(dest_dag, source_dag, result, root_logical_map, layout): the virtual qubit in the root source DAG that it is bound to.""" swap_map, node_order, node_block_results = result for node_id in node_order: - node = source_dag.node(node_id) + node = source_dag._get_node(node_id) if node_id in swap_map: apply_swaps(dest_dag, swap_map[node_id], layout) if not node.is_control_flow(): diff --git a/qiskit/transpiler/passes/routing/star_prerouting.py b/qiskit/transpiler/passes/routing/star_prerouting.py index 53bc971a268b..ac17ae84272e 100644 --- a/qiskit/transpiler/passes/routing/star_prerouting.py +++ b/qiskit/transpiler/passes/routing/star_prerouting.py @@ -330,7 +330,7 @@ def star_preroute(self, dag, blocks, processing_order): } def tie_breaker_key(node): - return processing_order_index_map.get(node, node.sort_key) + return processing_order_index_map.get(node, node.sort_key or "") rust_processing_order = _extract_nodes(dag.topological_op_nodes(key=tie_breaker_key), dag) diff --git a/qiskit/visualization/circuit/_utils.py b/qiskit/visualization/circuit/_utils.py index d933d38b5c4a..128d66ac6581 100644 --- a/qiskit/visualization/circuit/_utils.py +++ b/qiskit/visualization/circuit/_utils.py @@ -520,9 +520,6 @@ def _any_crossover(qubits, node, nodes): ) -_GLOBAL_NID = 0 - - class _LayerSpooler(list): """Manipulate list of layer dicts for _get_layered_instructions.""" diff --git a/test/python/dagcircuit/test_dagcircuit.py b/test/python/dagcircuit/test_dagcircuit.py index ef8050961066..73dbecda5b68 100644 --- a/test/python/dagcircuit/test_dagcircuit.py +++ b/test/python/dagcircuit/test_dagcircuit.py @@ -2695,7 +2695,8 @@ def test_substituting_node_preserves_args_condition(self, inplace): self.assertEqual(replacement_node.qargs, (qr[1], qr[0])) self.assertEqual(replacement_node.cargs, ()) self.assertEqual(replacement_node.op.condition, (cr, 1)) - self.assertEqual(replacement_node is node_to_be_replaced, inplace) + + self.assertNotEqual(replacement_node, node_to_be_replaced) @data(True, False) def test_substituting_node_preserves_parents_children(self, inplace): @@ -2720,7 +2721,7 @@ def test_substituting_node_preserves_parents_children(self, inplace): self.assertEqual(set(dag.successors(replacement_node)), successors) self.assertEqual(dag.ancestors(replacement_node), ancestors) self.assertEqual(dag.descendants(replacement_node), descendants) - self.assertEqual(replacement_node is node_to_be_replaced, inplace) + self.assertNotEqual(replacement_node, node_to_be_replaced) @data(True, False) def test_refuses_to_overwrite_condition(self, inplace):