Skip to content

Commit

Permalink
[Refactor] Rename "DAG" to "CircuitSeq" (#61)
Browse files Browse the repository at this point in the history
* [Refactor] Rename "DAG" to "CircuitSeq"

* fix compilation error
  • Loading branch information
xumingkuan authored Nov 10, 2022
1 parent 8a29479 commit 2e13eb7
Show file tree
Hide file tree
Showing 50 changed files with 1,901 additions and 1,863 deletions.
32 changes: 16 additions & 16 deletions CODE_STRUCTURE.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,44 +21,44 @@ In `src/quartz/context/rule_parser.h`
* `class RuleParser`: define the rules to write 3-qubit gates in each of the gate sets

In `src/quartz/parser/qasm_parser.h`
* `class QASMParser`: parse an input QASM file to Quartz's DAG representation
* `class QASMParser`: parse an input QASM file to Quartz's CircuitSeq representation

In `src/dag/dag.h`
* `class DAG`: a circuit represented in a DAG with all gates stored in an `std::vector` (which is a sequence representation)
In `src/quartz/circuitseq/circuitseq.h`
* `class CircuitSeq`: a circuit sequence with all gates stored in an `std::vector`

In `src/dag/dagnode.h`
* `class DAGNode`: a node in DAG corresponds to a wire in the circuit
In `src/quartz/circuitseq/circuitwire.h`
* `class CircuitWire`: a wire in the circuit sequence

In `src/dag/daghyperedge.h`
* `class DAGHyperEdge`: a hyperedge in DAG corresponds to a gate (or parameter expression) in the circuit
In `src/quartz/circuitseq/circuitgate.h`
* `class CircuitGate`: a gate (or parameter expression) in the circuit sequence

In `src/dataset/dataset.h`
In `src/quartz/dataset/dataset.h`
* `class Dataset`: a collection of circuits grouped by fingerprints

In `src/dataset/equivalence_set.h`
In `src/quartz/dataset/equivalence_set.h`
* `class EquivalenceClass`: an ECC
* `class EquivalenceSet`: an ECC set

In `src/generator/generator.h`
In `src/quartz/generator/generator.h`
* `class Generator`: the circuit generator
* `Generator::generate`: generate circuits for an unverified ECC set (then use `src/python/verify_equivalences.py` to get the ECC set)

In `src/math/matrix.h`
In `src/quartz/math/matrix.h`
* `class Matrix`: a complex square matrix

In `src/math/vector.h`
In `src/quartz/math/vector.h`
* `class Vector`: a complex vector

In `src/verifier/verifier.h`
In `src/quartz/verifier/verifier.h`
* `Verifier::redundant`: check if the circuit generated is redundant, i.e., having some slices not in the representative set

In `src/tasograph/tasograph.h`
In `src/quartz/tasograph/tasograph.h`
* `class Graph`: the circuit to be optimized
* `Graph::optimize`: use the search algorithm to optimize the circuit
* `Graph::context_shift`: shift the context of the circuit, e.g., changing the gate set

In `src/tasograph/substitution.h`
In `src/quartz/tasograph/substitution.h`
* `class GraphXfer`: a circuit transformation

In `src/test/gen_ecc_set.cpp`
In `src/quartz/test/gen_ecc_set.cpp`
* `gen_ecc_set`: a function to generate ECC sets with given gate set and hyperparameters
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,26 +70,26 @@ After that, you need a `QASMParser` object to parse the input `qasm` file. You c
QASMParser qasm_parser(&src_ctx);
```
Now you can use the `QASMParser` object to load the circuit from the `qasm` file to a `DAG` object, as below:
Now you can use the `QASMParser` object to load the circuit from the `qasm` file to a `CircuitSeq` object, as below:
``` cpp
DAG *dag = nullptr;
if (!qasm_parser.load_qasm(input_fn, dag)) {
CircuitSeq *seq = nullptr;
if (!qasm_parser.load_qasm(input_fn, seq)) {
std::cout << "Parser failed" << std::endl;
}
```

After you have the circuit loaded into the `DAG` object, you can construct a `Graph` object from it. The `Graph` object is the final circuit representation used in our optimizer. You can construct it as below:
After you have the circuit loaded into the `CircuitSeq` object, you can construct a `Graph` object from it. The `Graph` object is the final circuit representation used in our optimizer. You can construct it as below:

``` cpp
Graph graph(&src_ctx, dag);
Graph graph(&src_ctx, seq);
```
#### Context shift
If the input gate set is different from your target gate set, you should consider using the `context_shift` APIs to shift the context constructed with the gate sets to a context constructed with the target gate set.
To shift the context, you should create three `Contxt` objects, one for input, one for target, and one for their union as below:
To shift the context, you should create three `Context` objects, one for input, one for target, and one for their union as below:
``` cpp
Context src_ctx({GateType::h, GateType::ccz, GateType::x, GateType::cx,
Expand Down
20 changes: 10 additions & 10 deletions python/quartz/_cython/CCore.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -69,22 +69,22 @@ cdef extern from "context/context.h" namespace "quartz":

ctypedef Context* Context_ptr

cdef extern from "dag/dag.h" namespace "quartz":
cdef cppclass DAG:
DAG(int, int) except +
cdef extern from "circuitseq/circuitseq.h" namespace "quartz":
cdef cppclass CircuitSeq:
CircuitSeq(int, int) except +
int get_num_qubits() const
int get_num_input_parameters() const
int get_num_total_parameters() const
int get_num_internal_parameters() const
int get_num_gates() const

ctypedef DAG* DAG_ptr
ctypedef CircuitSeq* CircuitSeq_ptr

cdef extern from "tasograph/substitution.h" namespace "quartz":
cdef cppclass GraphXfer:
GraphXfer(Context_ptr, const DAG_ptr, const DAG_ptr) except +
GraphXfer(Context_ptr, const CircuitSeq_ptr, const CircuitSeq_ptr) except +
@staticmethod
GraphXfer* create_GraphXfer(Context_ptr,const DAG_ptr ,const DAG_ptr, bool no_increase)
GraphXfer* create_GraphXfer(Context_ptr,const CircuitSeq_ptr ,const CircuitSeq_ptr, bool no_increase)
int num_src_op()
int num_dst_op()
string src_str()
Expand All @@ -106,7 +106,7 @@ cdef extern from "tasograph/tasograph.h" namespace "quartz":
cdef extern from "tasograph/tasograph.h" namespace "quartz":
cdef cppclass Graph:
Graph(Context *) except +
Graph(Context *, const DAG *) except +
Graph(Context *, const CircuitSeq *) except +
bool xfer_appliable(GraphXfer *, Op) except +
shared_ptr[Graph] apply_xfer(GraphXfer *, Op, bool) except +
pair[shared_ptr[Graph], vector[int]] apply_xfer_and_track_node(GraphXfer *, Op, bool) except +
Expand Down Expand Up @@ -137,11 +137,11 @@ cdef extern from "dataset/equivalence_set.h" namespace "quartz":
EquivalenceSet() except +
int num_equivalence_classes() const
bool load_json(Context *, const string)
vector[vector[DAG_ptr]] get_all_equivalence_sets() except +
vector[vector[CircuitSeq_ptr]] get_all_equivalence_sets() except +


cdef extern from "parser/qasm_parser.h" namespace "quartz":
cdef cppclass QASMParser:
QASMParser(Context *)
bool load_qasm(const string &, DAG *&) except +
bool load_qasm_str(const string &, DAG *&) except +
bool load_qasm(const string &, CircuitSeq *&) except +
bool load_qasm_str(const string &, CircuitSeq *&) except +
10 changes: 5 additions & 5 deletions python/quartz/_cython/core.pyx
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# distutils: language = c++

from CCore cimport (
DAG,
CircuitSeq,
CircuitSeq_ptr,
Context,
DAG_ptr,
Edge,
EquivalenceSet,
Gate,
Expand Down Expand Up @@ -147,18 +147,18 @@ cdef class PyGate:


cdef class PyDAG:
cdef DAG_ptr dag
cdef CircuitSeq_ptr dag

def __cinit__(self, *, int num_qubits=-1, int num_input_params=-1):
if num_qubits >= 0 and num_input_params >= 0:
self.dag = new DAG(num_qubits, num_input_params)
self.dag = new CircuitSeq(num_qubits, num_input_params)
else:
self.dag = NULL

def __dealloc__(self):
pass

cdef set_this(self, DAG_ptr dag_):
cdef set_this(self, CircuitSeq_ptr dag_):
self.dag = dag_
return self

Expand Down
2 changes: 1 addition & 1 deletion src/benchmark/nam_middle_circuits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ void benchmark_nam(const std::string &circuit_name) {
auto xfer_pair = GraphXfer::ccz_cx_rz_xfer(&union_ctx);
// Load qasm file
QASMParser qasm_parser(&src_ctx);
DAG *dag = nullptr;
CircuitSeq *dag = nullptr;
if (!qasm_parser.load_qasm(circuit_path, dag)) {
std::cout << "Parser failed" << std::endl;
return;
Expand Down
2 changes: 1 addition & 1 deletion src/benchmark/nam_small_circuits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ void benchmark_nam(const std::string &circuit_name) {
auto xfer_pair = GraphXfer::ccz_cx_rz_xfer(&union_ctx);
// Load qasm file
QASMParser qasm_parser(&src_ctx);
DAG *dag = nullptr;
CircuitSeq *dag = nullptr;
if (!qasm_parser.load_qasm(circuit_path, dag)) {
std::cout << "Parser failed" << std::endl;
return;
Expand Down
18 changes: 9 additions & 9 deletions src/python/verifier/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def phase_shift(vec, lam):


def phase_shift_by_id(vec, dag, phase_shift_id, all_parameters):
# Warning: If DAG::hash() is modified, this function should be modified correspondingly.
# Warning: If CircuitSeq::hash() is modified, this function should be modified correspondingly.
dag_meta = dag[0]
num_total_params = dag_meta[meta_index_num_total_parameters]
if (
Expand Down Expand Up @@ -680,13 +680,13 @@ def find_equivalences(
num_hashtags = len(data)
num_dags = sum(len(dags) for dags in data.values())
num_potential_equivalences = num_dags - num_hashtags
# first process hashtags with only 1 DAG
# first process hashtags with only 1 CircuitSeq
for hashtag, dags in data.items():
if len(dags) == 1:
output_dict[hashtag + "_0"] = [dags[0]]
num_different_dags_with_same_hash[hashtag] = 1
print(
f"Processed {len(output_dict)} hash values that had only 1 DAG, now processing the remaining {len(data) - len(output_dict)} ones with 2 or more DAGs..."
f"Processed {len(output_dict)} hash values that had only 1 circuit sequence, now processing the remaining {len(data) - len(output_dict)} ones with 2 or more circuit sequences..."
)
# now process hashtags with >1 DAGs
with mp.Pool() as pool:
Expand Down Expand Up @@ -727,10 +727,10 @@ def find_equivalences(

# A map from other hashtags to corresponding phase shifts.
other_hashtags = defaultdict(dict)
# |other_hashtags[other_hash][None]| indicates that if it's possible that a DAG with |other_hash|
# is equivalent with a DAG with |hashtag| without phase shifts.
# |other_hashtags[other_hash][None]| indicates that if it's possible that a CircuitSeq with |other_hash|
# is equivalent with a CircuitSeq with |hashtag| without phase shifts.
# |other_hashtags[other_hash][phase_shift_id]| is a list of DAGs with |hashtag| that can be equivalent
# to a DAG with |other_hash| under phase shift |phase_shift_id|.
# to a CircuitSeq with |other_hash| under phase shift |phase_shift_id|.
assert len(dags) > 0
for dag in dags:
dag_meta = dag[0]
Expand All @@ -742,7 +742,7 @@ def find_equivalences(
# phase shift id is item[1]
assert isinstance(item, list)
assert len(item) == 2
# We need the exact parameter in |dag|, so we cannot use the representative DAG |dags[0]|.
# We need the exact parameter in |dag|, so we cannot use the representative CircuitSeq |dags[0]|.
other_hashtags[item[0]][item[1]] = other_hashtags[item[0]].get(
item[1], []
) + [dag]
Expand Down Expand Up @@ -798,7 +798,7 @@ def find_equivalences(
# Pruning: we only need to try each input parameter once.
input_param_tried = False
for dag in dag_list:
# Warning: If DAG::hash() is modified,
# Warning: If CircuitSeq::hash() is modified,
# the expression |is_fixed_for_all_dags| should be modified correspondingly.
is_fixed_for_all_dags = (
0
Expand All @@ -819,7 +819,7 @@ def find_equivalences(
input_param_tried = True
equivalent_called_2 += 1
possible_num_equivalences_under_phase_shift += 1
# |phase_shift_id[0]| is the DAG generating this phase shift id.
# |phase_shift_id[0]| is the CircuitSeq generating this phase shift id.
if equivalent(
dag,
other_dag,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#include "daghyperedge.h"
#include "dagnode.h"
#include "circuitgate.h"
#include "circuitwire.h"

namespace quartz {
int DAGHyperEdge::get_min_qubit_index() const {
int CircuitGate::get_min_qubit_index() const {
int result = -1;
for (auto &input_node : input_nodes) {
for (auto &input_node : input_wires) {
if (input_node->is_qubit() &&
(result == -1 || input_node->index < result)) {
result = input_node->index;
Expand Down
24 changes: 24 additions & 0 deletions src/quartz/circuitseq/circuitgate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#pragma once

#include "../gate/gate.h"
#include "../utils/utils.h"

#include <vector>

namespace quartz {

class CircuitWire;

/**
* A gate in the circuit.
* Stores the gate type, input and output information.
*/
class CircuitGate {
public:
int get_min_qubit_index() const;
std::vector<CircuitWire *> input_wires; // Include parameters!
std::vector<CircuitWire *> output_wires;

Gate *gate;
};
} // namespace quartz
Loading

0 comments on commit 2e13eb7

Please sign in to comment.