Skip to content

Commit

Permalink
Adding new feature wire_options to draw_mpl function (#6486)
Browse files Browse the repository at this point in the history
**Context:**
On the level of mid-scale algorithms, it might be nice to differentiate
between different "types" of wires, for example by coloring them
differently, or giving them distinct line styles. As an example, in
PennyLaneAI/qml#1185 on preparing matrix product
states, there are auxiliary bond qubits and physical qubits, and
coloring them differently would be a neat thing to do.

**Description of the Change:**
Update the output wire_options that could change the line style and
color for circuit output.
```python
        @qml.qnode(qml.device("default.qubit"))
        def node(x):
            for w in range(5):
                qml.Hadamard(w) 
            return qml.expval(qml.PauliZ(0) @ qml.PauliY(1))

        # Make all wires cyan and bold, 
        # except for wires 2 and 6, which are dashed and another color
        wire_options = {"color": "cyan", 
                        "linewidth": 5, 
                        2: {"linestyle": "--", "color": "red"}, 
                        6: {"linestyle": "--", "color": "orange"}
                    }
        _,ax  = qml.draw_mpl(node, wire_options=wire_options)(0.52)
```

**Benefits:**
When complicated sates and quantum circuits diagram are created, wires
could be marked with different selections.

**Possible Drawbacks:**
N/A

**Related GitHub Issues:**
#6165

---------

Co-authored-by: David Wierichs <[email protected]>
Co-authored-by: ringo-but-quantum <>
Co-authored-by: Astral Cai <[email protected]>
Co-authored-by: Christina Lee <[email protected]>
  • Loading branch information
4 people authored Nov 8, 2024
1 parent 34179a0 commit 73a617a
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 19 deletions.
Binary file added doc/_static/draw_mpl/per_wire_options.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added doc/_static/tape_mpl/per_wire_options.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion doc/code/qml_drawer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,4 @@ Currently Available Styles
+|pls|+|plw|+|skd|+
+-----+-----+-----+
+|sol|+|sod|+|def|+
+-----+-----+-----+
+-----+-----+-----+
5 changes: 5 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

<h3>Improvements 🛠</h3>

* Added support for the `wire_options` dictionary to customize wire line formatting in `qml.draw_mpl` circuit
visualizations, allowing global and per-wire customization with options like `color`, `linestyle`, and `linewidth`.
[(#6486)](https://github.com/PennyLaneAI/pennylane/pull/6486)

<h4>Capturing and representing hybrid programs</h4>

* `jax.vmap` can be captured with `qml.capture.make_plxpr` and is compatible with quantum circuits.
Expand Down Expand Up @@ -54,6 +58,7 @@

This release contains contributions from (in alphabetical order):

Shiwen An
Astral Cai,
Pietropaolo Frisoni,
Andrija Paurevic
30 changes: 26 additions & 4 deletions pennylane/drawer/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,9 @@ def draw_mpl(
fontsize (float or str): fontsize for text. Valid strings are
``{'xx-small', 'x-small', 'small', 'medium', large', 'x-large', 'xx-large'}``.
Default is ``14``.
wire_options (dict): matplotlib formatting options for the wire lines
wire_options (dict): matplotlib formatting options for the wire lines. In addition to
standard options, options per wire can be specified with ``wire_label: options``
pairs, also see examples below.
label_options (dict): matplotlib formatting options for the wire labels
show_wire_labels (bool): Whether or not to show the wire labels.
active_wire_notches (bool): whether or not to add notches indicating active wires.
Expand Down Expand Up @@ -458,7 +460,8 @@ def circuit2(x, y):
**Wires:**
The keywords ``wire_order`` and ``show_all_wires`` control the location of wires from top to bottom.
The keywords ``wire_order`` and ``show_all_wires`` control the location of wires
from top to bottom.
.. code-block:: python
Expand All @@ -470,8 +473,8 @@ def circuit2(x, y):
:width: 60%
:target: javascript:void(0);
If a wire is in ``wire_order``, but not in the ``tape``, it will be omitted by default. Only by selecting
``show_all_wires=True`` will empty wires be displayed.
If a wire is in ``wire_order``, but not in the ``tape``, it will be omitted by default.
Only by selecting ``show_all_wires=True`` will empty wires be displayed.
.. code-block:: python
Expand Down Expand Up @@ -568,6 +571,25 @@ def circuit2(x, y):
:width: 60%
:target: javascript:void(0);
Additionally, ``wire_options`` may contain sub-dictionaries of matplotlib options assigned
to separate wire labels, which will control the line style for the respective individual wires.
.. code-block:: python
wire_options = {
'color': 'teal', # all wires but wire 2 will be teal
'linewidth': 5, # all wires but wire 2 will be bold
2: {'color': 'orange', 'linestyle': '--'}, # wire 2 will be orange and dashed
}
fig, ax = qml.draw_mpl(circuit, wire_options=wire_options)(1.2345,1.2345)
fig.show()
.. figure:: ../../_static/draw_mpl/per_wire_options.png
:align: center
:width: 60%
:target: javascript:void(0);
**Levels:**
The ``level`` keyword argument allows one to select a subset of the transforms to apply on the ``QNode``
Expand Down
24 changes: 19 additions & 5 deletions pennylane/drawer/mpldrawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,11 +294,25 @@ def __init__(self, n_layers, n_wires, c_wires=0, wire_options=None, figsize=None
if wire_options is None:
wire_options = {}

# adding wire lines
self._wire_lines = [
plt.Line2D((-1, self.n_layers), (wire, wire), zorder=1, **wire_options)
for wire in range(self.n_wires)
]
# Separate global options from per wire options
global_options = {k: v for k, v in wire_options.items() if not isinstance(v, dict)}
wire_specific_options = {k: v for k, v in wire_options.items() if isinstance(v, dict)}

# Adding wire lines with individual styles based on wire_options
self._wire_lines = []
for wire in range(self.n_wires):
specific_options = wire_specific_options.get(wire, {})
line_options = {**global_options, **specific_options}

# Create Line2D with the combined options
line = plt.Line2D(
(-1, self.n_layers),
(wire, wire),
zorder=1,
**line_options,
)
self._wire_lines.append(line)

for line in self._wire_lines:
self._ax.add_line(line)

Expand Down
36 changes: 27 additions & 9 deletions pennylane/drawer/tape_mpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,9 @@ def tape_mpl(
fontsize (float or str): fontsize for text. Valid strings are
``{'xx-small', 'x-small', 'small', 'medium', large', 'x-large', 'xx-large'}``.
Default is ``14``.
wire_options (dict): matplotlib formatting options for the wire lines
wire_options (dict): matplotlib formatting options for the wire lines. In addition to
standard options, options per wire can be specified with ``wire_label: options``
pairs, also see examples below.
label_options (dict): matplotlib formatting options for the wire labels
show_wire_labels (bool): Whether or not to show the wire labels.
active_wire_notches (bool): whether or not to add notches indicating active wires.
Expand All @@ -328,7 +330,7 @@ def tape_mpl(
measurements = [qml.expval(qml.Z(0))]
tape = qml.tape.QuantumTape(ops, measurements)
fig, ax = tape_mpl(tape)
fig, ax = qml.drawer.tape_mpl(tape)
fig.show()
.. figure:: ../../_static/tape_mpl/default.png
Expand All @@ -350,7 +352,7 @@ def tape_mpl(
measurements = [qml.expval(qml.Z(0))]
tape2 = qml.tape.QuantumTape(ops, measurements)
fig, ax = tape_mpl(tape2, decimals=2)
fig, ax = qml.drawer.tape_mpl(tape2, decimals=2)
.. figure:: ../../_static/tape_mpl/decimals.png
:align: center
Expand All @@ -363,7 +365,7 @@ def tape_mpl(
.. code-block:: python
fig, ax = tape_mpl(tape, wire_order=[3,2,1,0])
fig, ax = qml.drawer.tape_mpl(tape, wire_order=[3,2,1,0])
.. figure:: ../../_static/tape_mpl/wire_order.png
:align: center
Expand All @@ -375,7 +377,7 @@ def tape_mpl(
.. code-block:: python
fig, ax = tape_mpl(tape, wire_order=["aux"], show_all_wires=True)
fig, ax = qml.drawer.tape_mpl(tape, wire_order=["aux"], show_all_wires=True)
.. figure:: ../../_static/tape_mpl/show_all_wires.png
:align: center
Expand All @@ -389,7 +391,7 @@ def tape_mpl(
.. code-block:: python
fig, ax = tape_mpl(tape)
fig, ax = qml.drawer.tape_mpl(tape)
fig.suptitle("My Circuit", fontsize="xx-large")
options = {'facecolor': "white", 'edgecolor': "#f57e7e", "linewidth": 6, "zorder": -1}
Expand All @@ -413,7 +415,7 @@ def tape_mpl(
.. code-block:: python
fig, ax = tape_mpl(tape, style='sketch')
fig, ax = qml.drawer.tape_mpl(tape, style='sketch')
.. figure:: ../../_static/tape_mpl/sketch_style.png
:align: center
Expand All @@ -437,7 +439,7 @@ def tape_mpl(
plt.rcParams['lines.linewidth'] = 5
plt.rcParams['figure.facecolor'] = 'ghostwhite'
fig, ax = tape_mpl(tape, style="rcParams")
fig, ax = qml.drawer.tape_mpl(tape, style="rcParams")
.. figure:: ../../_static/tape_mpl/rcparams.png
:align: center
Expand All @@ -450,14 +452,30 @@ def tape_mpl(
.. code-block:: python
fig, ax = tape_mpl(tape, wire_options={'color':'teal', 'linewidth': 5},
fig, ax = qml.drawer.tape_mpl(tape, wire_options={'color':'teal', 'linewidth': 5},
label_options={'size': 20})
.. figure:: ../../_static/tape_mpl/wires_labels.png
:align: center
:width: 60%
:target: javascript:void(0);
Additionally, ``wire_options`` may contain sub-dictionaries of matplotlib options assigned
to separate wire labels, which will control the line style for the respective individual wires.
.. code-block:: python
wire_options = {
'color': 'teal', # all wires but wire 2 will be teal
'linewidth': 5, # all wires but wire 2 will be bold
2: {'color': 'orange', 'linestyle': '--'}, # wire 2 will be orange and dashed
}
fig, ax = qml.drawer.tape_mpl(tape, wire_options=wire_options)
.. figure:: ../../_static/tape_mpl/per_wire_options.png
:align: center
:width: 60%
:target: javascript:void(0);
"""

restore_params = {}
Expand Down
69 changes: 69 additions & 0 deletions tests/drawer/test_draw_mpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,75 @@ def test_wire_options(self):
assert w.get_color() == "black"
assert w.get_linewidth() == 4

@qml.qnode(dev)
def f_circ(x):
"""Circuit on ten qubits."""
qml.RX(x, wires=0)
for w in range(10):
qml.Hadamard(w)
return qml.expval(qml.PauliZ(0) @ qml.PauliY(1))

# All wires are orange
wire_options = {"color": "orange"}
_, ax = qml.draw_mpl(f_circ, wire_options=wire_options)(0.52)

for w in ax.lines:
assert w.get_color() == "orange"

# Wires are orange and cyan
wire_options = {0: {"color": "orange"}, 1: {"color": "cyan"}}
_, ax = qml.draw_mpl(f_circ, wire_options=wire_options)(0.52)

assert ax.lines[0].get_color() == "orange"
assert ax.lines[1].get_color() == "cyan"
assert ax.lines[2].get_color() == "black"

# Make all wires cyan and bold,
# except for wires 2 and 6, which are dashed and another color
wire_options = {
"color": "cyan",
"linewidth": 5,
2: {"linestyle": "--", "color": "red"},
6: {"linestyle": "--", "color": "orange", "linewidth": 1},
}
_, ax = qml.draw_mpl(f_circ, wire_options=wire_options)(0.52)

for i, w in enumerate(ax.lines):
if i == 2:
assert w.get_color() == "red"
assert w.get_linestyle() == "--"
assert w.get_linewidth() == 5
elif i == 6:
assert w.get_color() == "orange"
assert w.get_linestyle() == "--"
assert w.get_linewidth() == 1
else:
assert w.get_color() == "cyan"
assert w.get_linestyle() == "-"
assert w.get_linewidth() == 5

wire_options = {
"linewidth": 5,
2: {"linestyle": "--", "color": "red"},
6: {"linestyle": "--", "color": "orange"},
}

_, ax = qml.draw_mpl(f_circ, wire_options=wire_options)(0.52)

for i, w in enumerate(ax.lines):
if i == 2:
assert w.get_color() == "red"
assert w.get_linestyle() == "--"
assert w.get_linewidth() == 5
elif i == 6:
assert w.get_color() == "orange"
assert w.get_linestyle() == "--"
assert w.get_linewidth() == 5
else:
assert w.get_color() == "black"
assert w.get_linestyle() == "-"
assert w.get_linewidth() == 5

plt.close()


Expand Down

0 comments on commit 73a617a

Please sign in to comment.