Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Rename "DAG" to "CircuitSeq" #61

Merged
merged 2 commits into from
Nov 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions CODE_STRUCTURE.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,44 +21,44 @@ In `src/quartz/context/rule_parser.h`
* `class RuleParser`: define the rules to write 3-qubit gates in each of the gate sets

In `src/quartz/parser/qasm_parser.h`
* `class QASMParser`: parse an input QASM file to Quartz's DAG representation
* `class QASMParser`: parse an input QASM file to Quartz's CircuitSeq representation

In `src/dag/dag.h`
* `class DAG`: a circuit represented in a DAG with all gates stored in an `std::vector` (which is a sequence representation)
In `src/quartz/circuitseq/circuitseq.h`
* `class CircuitSeq`: a circuit sequence with all gates stored in an `std::vector`

In `src/dag/dagnode.h`
* `class DAGNode`: a node in DAG corresponds to a wire in the circuit
In `src/quartz/circuitseq/circuitwire.h`
* `class CircuitWire`: a wire in the circuit sequence

In `src/dag/daghyperedge.h`
* `class DAGHyperEdge`: a hyperedge in DAG corresponds to a gate (or parameter expression) in the circuit
In `src/quartz/circuitseq/circuitgate.h`
* `class CircuitGate`: a gate (or parameter expression) in the circuit sequence

In `src/dataset/dataset.h`
In `src/quartz/dataset/dataset.h`
* `class Dataset`: a collection of circuits grouped by fingerprints

In `src/dataset/equivalence_set.h`
In `src/quartz/dataset/equivalence_set.h`
* `class EquivalenceClass`: an ECC
* `class EquivalenceSet`: an ECC set

In `src/generator/generator.h`
In `src/quartz/generator/generator.h`
* `class Generator`: the circuit generator
* `Generator::generate`: generate circuits for an unverified ECC set (then use `src/python/verify_equivalences.py` to get the ECC set)

In `src/math/matrix.h`
In `src/quartz/math/matrix.h`
* `class Matrix`: a complex square matrix

In `src/math/vector.h`
In `src/quartz/math/vector.h`
* `class Vector`: a complex vector

In `src/verifier/verifier.h`
In `src/quartz/verifier/verifier.h`
* `Verifier::redundant`: check if the circuit generated is redundant, i.e., having some slices not in the representative set

In `src/tasograph/tasograph.h`
In `src/quartz/tasograph/tasograph.h`
* `class Graph`: the circuit to be optimized
* `Graph::optimize`: use the search algorithm to optimize the circuit
* `Graph::context_shift`: shift the context of the circuit, e.g., changing the gate set

In `src/tasograph/substitution.h`
In `src/quartz/tasograph/substitution.h`
* `class GraphXfer`: a circuit transformation

In `src/test/gen_ecc_set.cpp`
In `src/quartz/test/gen_ecc_set.cpp`
* `gen_ecc_set`: a function to generate ECC sets with given gate set and hyperparameters
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,26 +70,26 @@ After that, you need a `QASMParser` object to parse the input `qasm` file. You c
QASMParser qasm_parser(&src_ctx);
```

Now you can use the `QASMParser` object to load the circuit from the `qasm` file to a `DAG` object, as below:
Now you can use the `QASMParser` object to load the circuit from the `qasm` file to a `CircuitSeq` object, as below:

``` cpp
DAG *dag = nullptr;
if (!qasm_parser.load_qasm(input_fn, dag)) {
CircuitSeq *seq = nullptr;
if (!qasm_parser.load_qasm(input_fn, seq)) {
std::cout << "Parser failed" << std::endl;
}
```

After you have the circuit loaded into the `DAG` object, you can construct a `Graph` object from it. The `Graph` object is the final circuit representation used in our optimizer. You can construct it as below:
After you have the circuit loaded into the `CircuitSeq` object, you can construct a `Graph` object from it. The `Graph` object is the final circuit representation used in our optimizer. You can construct it as below:

``` cpp
Graph graph(&src_ctx, dag);
Graph graph(&src_ctx, seq);
```

#### Context shift

If the input gate set is different from your target gate set, you should consider using the `context_shift` APIs to shift the context constructed with the gate sets to a context constructed with the target gate set.

To shift the context, you should create three `Contxt` objects, one for input, one for target, and one for their union as below:
To shift the context, you should create three `Context` objects, one for input, one for target, and one for their union as below:

``` cpp
Context src_ctx({GateType::h, GateType::ccz, GateType::x, GateType::cx,
Expand Down
20 changes: 10 additions & 10 deletions python/quartz/_cython/CCore.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -69,22 +69,22 @@ cdef extern from "context/context.h" namespace "quartz":

ctypedef Context* Context_ptr

cdef extern from "dag/dag.h" namespace "quartz":
cdef cppclass DAG:
DAG(int, int) except +
cdef extern from "circuitseq/circuitseq.h" namespace "quartz":
cdef cppclass CircuitSeq:
CircuitSeq(int, int) except +
int get_num_qubits() const
int get_num_input_parameters() const
int get_num_total_parameters() const
int get_num_internal_parameters() const
int get_num_gates() const

ctypedef DAG* DAG_ptr
ctypedef CircuitSeq* CircuitSeq_ptr

cdef extern from "tasograph/substitution.h" namespace "quartz":
cdef cppclass GraphXfer:
GraphXfer(Context_ptr, const DAG_ptr, const DAG_ptr) except +
GraphXfer(Context_ptr, const CircuitSeq_ptr, const CircuitSeq_ptr) except +
@staticmethod
GraphXfer* create_GraphXfer(Context_ptr,const DAG_ptr ,const DAG_ptr, bool no_increase)
GraphXfer* create_GraphXfer(Context_ptr,const CircuitSeq_ptr ,const CircuitSeq_ptr, bool no_increase)
int num_src_op()
int num_dst_op()
string src_str()
Expand All @@ -106,7 +106,7 @@ cdef extern from "tasograph/tasograph.h" namespace "quartz":
cdef extern from "tasograph/tasograph.h" namespace "quartz":
cdef cppclass Graph:
Graph(Context *) except +
Graph(Context *, const DAG *) except +
Graph(Context *, const CircuitSeq *) except +
bool xfer_appliable(GraphXfer *, Op) except +
shared_ptr[Graph] apply_xfer(GraphXfer *, Op, bool) except +
pair[shared_ptr[Graph], vector[int]] apply_xfer_and_track_node(GraphXfer *, Op, bool) except +
Expand Down Expand Up @@ -137,11 +137,11 @@ cdef extern from "dataset/equivalence_set.h" namespace "quartz":
EquivalenceSet() except +
int num_equivalence_classes() const
bool load_json(Context *, const string)
vector[vector[DAG_ptr]] get_all_equivalence_sets() except +
vector[vector[CircuitSeq_ptr]] get_all_equivalence_sets() except +


cdef extern from "parser/qasm_parser.h" namespace "quartz":
cdef cppclass QASMParser:
QASMParser(Context *)
bool load_qasm(const string &, DAG *&) except +
bool load_qasm_str(const string &, DAG *&) except +
bool load_qasm(const string &, CircuitSeq *&) except +
bool load_qasm_str(const string &, CircuitSeq *&) except +
10 changes: 5 additions & 5 deletions python/quartz/_cython/core.pyx
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# distutils: language = c++

from CCore cimport (
DAG,
CircuitSeq,
CircuitSeq_ptr,
Context,
DAG_ptr,
Edge,
EquivalenceSet,
Gate,
Expand Down Expand Up @@ -147,18 +147,18 @@ cdef class PyGate:


cdef class PyDAG:
cdef DAG_ptr dag
cdef CircuitSeq_ptr dag

def __cinit__(self, *, int num_qubits=-1, int num_input_params=-1):
if num_qubits >= 0 and num_input_params >= 0:
self.dag = new DAG(num_qubits, num_input_params)
self.dag = new CircuitSeq(num_qubits, num_input_params)
else:
self.dag = NULL

def __dealloc__(self):
pass

cdef set_this(self, DAG_ptr dag_):
cdef set_this(self, CircuitSeq_ptr dag_):
self.dag = dag_
return self

Expand Down
2 changes: 1 addition & 1 deletion src/benchmark/nam_middle_circuits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ void benchmark_nam(const std::string &circuit_name) {
auto xfer_pair = GraphXfer::ccz_cx_rz_xfer(&union_ctx);
// Load qasm file
QASMParser qasm_parser(&src_ctx);
DAG *dag = nullptr;
CircuitSeq *dag = nullptr;
if (!qasm_parser.load_qasm(circuit_path, dag)) {
std::cout << "Parser failed" << std::endl;
return;
Expand Down
2 changes: 1 addition & 1 deletion src/benchmark/nam_small_circuits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ void benchmark_nam(const std::string &circuit_name) {
auto xfer_pair = GraphXfer::ccz_cx_rz_xfer(&union_ctx);
// Load qasm file
QASMParser qasm_parser(&src_ctx);
DAG *dag = nullptr;
CircuitSeq *dag = nullptr;
if (!qasm_parser.load_qasm(circuit_path, dag)) {
std::cout << "Parser failed" << std::endl;
return;
Expand Down
18 changes: 9 additions & 9 deletions src/python/verifier/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def phase_shift(vec, lam):


def phase_shift_by_id(vec, dag, phase_shift_id, all_parameters):
# Warning: If DAG::hash() is modified, this function should be modified correspondingly.
# Warning: If CircuitSeq::hash() is modified, this function should be modified correspondingly.
dag_meta = dag[0]
num_total_params = dag_meta[meta_index_num_total_parameters]
if (
Expand Down Expand Up @@ -680,13 +680,13 @@ def find_equivalences(
num_hashtags = len(data)
num_dags = sum(len(dags) for dags in data.values())
num_potential_equivalences = num_dags - num_hashtags
# first process hashtags with only 1 DAG
# first process hashtags with only 1 CircuitSeq
for hashtag, dags in data.items():
if len(dags) == 1:
output_dict[hashtag + "_0"] = [dags[0]]
num_different_dags_with_same_hash[hashtag] = 1
print(
f"Processed {len(output_dict)} hash values that had only 1 DAG, now processing the remaining {len(data) - len(output_dict)} ones with 2 or more DAGs..."
f"Processed {len(output_dict)} hash values that had only 1 circuit sequence, now processing the remaining {len(data) - len(output_dict)} ones with 2 or more circuit sequences..."
)
# now process hashtags with >1 DAGs
with mp.Pool() as pool:
Expand Down Expand Up @@ -727,10 +727,10 @@ def find_equivalences(

# A map from other hashtags to corresponding phase shifts.
other_hashtags = defaultdict(dict)
# |other_hashtags[other_hash][None]| indicates that if it's possible that a DAG with |other_hash|
# is equivalent with a DAG with |hashtag| without phase shifts.
# |other_hashtags[other_hash][None]| indicates that if it's possible that a CircuitSeq with |other_hash|
# is equivalent with a CircuitSeq with |hashtag| without phase shifts.
# |other_hashtags[other_hash][phase_shift_id]| is a list of DAGs with |hashtag| that can be equivalent
# to a DAG with |other_hash| under phase shift |phase_shift_id|.
# to a CircuitSeq with |other_hash| under phase shift |phase_shift_id|.
assert len(dags) > 0
for dag in dags:
dag_meta = dag[0]
Expand All @@ -742,7 +742,7 @@ def find_equivalences(
# phase shift id is item[1]
assert isinstance(item, list)
assert len(item) == 2
# We need the exact parameter in |dag|, so we cannot use the representative DAG |dags[0]|.
# We need the exact parameter in |dag|, so we cannot use the representative CircuitSeq |dags[0]|.
other_hashtags[item[0]][item[1]] = other_hashtags[item[0]].get(
item[1], []
) + [dag]
Expand Down Expand Up @@ -798,7 +798,7 @@ def find_equivalences(
# Pruning: we only need to try each input parameter once.
input_param_tried = False
for dag in dag_list:
# Warning: If DAG::hash() is modified,
# Warning: If CircuitSeq::hash() is modified,
# the expression |is_fixed_for_all_dags| should be modified correspondingly.
is_fixed_for_all_dags = (
0
Expand All @@ -819,7 +819,7 @@ def find_equivalences(
input_param_tried = True
equivalent_called_2 += 1
possible_num_equivalences_under_phase_shift += 1
# |phase_shift_id[0]| is the DAG generating this phase shift id.
# |phase_shift_id[0]| is the CircuitSeq generating this phase shift id.
if equivalent(
dag,
other_dag,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#include "daghyperedge.h"
#include "dagnode.h"
#include "circuitgate.h"
#include "circuitwire.h"

namespace quartz {
int DAGHyperEdge::get_min_qubit_index() const {
int CircuitGate::get_min_qubit_index() const {
int result = -1;
for (auto &input_node : input_nodes) {
for (auto &input_node : input_wires) {
if (input_node->is_qubit() &&
(result == -1 || input_node->index < result)) {
result = input_node->index;
Expand Down
24 changes: 24 additions & 0 deletions src/quartz/circuitseq/circuitgate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#pragma once

#include "../gate/gate.h"
#include "../utils/utils.h"

#include <vector>

namespace quartz {

class CircuitWire;

/**
* A gate in the circuit.
* Stores the gate type, input and output information.
*/
class CircuitGate {
public:
int get_min_qubit_index() const;
std::vector<CircuitWire *> input_wires; // Include parameters!
std::vector<CircuitWire *> output_wires;

Gate *gate;
};
} // namespace quartz
Loading