Skip to content

Commit

Permalink
Update tutorial_qubit_rotation.py
Browse files Browse the repository at this point in the history
  • Loading branch information
KetpuntoG committed Sep 13, 2023
1 parent 29dc909 commit c767cea
Showing 1 changed file with 30 additions and 41 deletions.
71 changes: 30 additions & 41 deletions demonstrations/tutorial_qubit_rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,32 +82,16 @@
#
# Let's see how we can easily implement and optimize this circuit using PennyLane.
#
# Importing PennyLane and JAX
# -----------------------------
# Importing PennyLane and Jax
# ----------------------------
#
# The first thing we need to do is import PennyLane, as well as JAX.
# The first thing we need to do is import PennyLane and Jax, as well as the wrapped version
# of NumPy provided by PennyLane.

import pennylane as qml
from jax import numpy as np

import jax
import optax # Import optax for optimization

##############################################################################
# .. important::
#
# When constructing a hybrid quantum/classical computational model with PennyLane,
# it is important to **always import NumPy from PennyLane**, not the standard NumPy!
#
# By importing the wrapped version of NumPy provided by PennyLane, you can combine
# the power of JAX with PennyLane:
#
# * continue to use the classical JAX functions and arrays you know and love
# * combine quantum functions (evaluated on quantum hardware/simulators) and
# classical functions (provided by JAX)
# * allow PennyLane to automatically calculate gradients of both classical and
# quantum functions


##############################################################################
# Creating a device
# -----------------
#
Expand Down Expand Up @@ -169,13 +153,14 @@
# First, we need to define the quantum function that will be evaluated in the QNode:


@qml.qnode(dev1, interface="jax")
def circuit(params):
qml.RX(params[0], wires=0)
qml.RY(params[1], wires=0)
return qml.expval(qml.PauliZ(0))


##############################################################################
################################################################################
# This is a simple circuit, matching the one described above.
# Notice that the function ``circuit()`` is constructed as if it were any
# other Python function; it accepts a positional argument ``params``, which may
Expand Down Expand Up @@ -220,15 +205,15 @@ def circuit(params):
return qml.expval(qml.PauliZ(0))


##############################################################################
################################################################################
# Thus, our ``circuit()`` quantum function is now a :class:`~.pennylane.QNode`, which will run on
# device ``dev1`` every time it is evaluated.
#
# To evaluate, we simply call the function with some appropriate numerical inputs:

print(circuit([0.54, 0.12]))

##############################################################################
################################################################################
# Calculating quantum gradients
# -----------------------------
#
Expand All @@ -245,13 +230,12 @@ def circuit(params):
# partial derivatives) of ``circuit``. The gradient can be evaluated in the same
# way as the original function:

dcircuit = qml.grad(circuit, argnum=0)
dcircuit = jax.grad(circuit, argnums=0)

##############################################################################
################################################################################
# The function :func:`~.pennylane.grad` itself **returns a function**, representing
# the derivative of the QNode with respect to the argument specified in ``argnum``.
# In this case, the function ``circuit`` takes one argument (``params``), so we
# specify ``argnum=0``. Because the argument has two elements, the returned gradient
# In this case, the argument has two elements, so the returned gradient
# is two-dimensional. We can then evaluate this gradient function at any point in the parameter space.

print(dcircuit([0.54, 0.12]))
Expand All @@ -266,21 +250,21 @@ def circuit(params):


@qml.qnode(dev1, interface="jax")
def circuit2(phi1, phi2):
qml.RX(phi1, wires=0)
qml.RY(phi2, wires=0)
def circuit2(params):
qml.RX(params[0], wires=0)
qml.RY(params[1], wires=0)
return qml.expval(qml.PauliZ(0))


################################################################################
# When we calculate the gradient for such a function, the usage of ``argnum``
# will be slightly different. In this case, ``argnum=0`` will return the gradient
# with respect to only the first parameter (``phi1``), and ``argnum=1`` will give
# the gradient for ``phi2``. To get the gradient with respect to both parameters,
# we can use ``argnum=[0,1]``:
# When we calculate the gradient for such a function, the usage of ``argnums``
# will be slightly different. In this case, ``argnums=0`` will return the gradient
# with respect to the first parameter (``params[0]``), and ``argnums=1`` will give
# the gradient for ``params[1]``. To get the gradient with respect to both parameters,
# we can use ``argnums=(0, 1)``:

dcircuit = qml.grad(circuit2, argnum=[0, 1])
print(dcircuit(0.54, 0.12))
dcircuit = jax.grad(circuit2, argnums=0)
print(dcircuit([0.54, 0.12]))

################################################################################
# Keyword arguments may also be used in your custom quantum function. PennyLane
Expand Down Expand Up @@ -343,16 +327,21 @@ def cost(x):
# :class:`~.pennylane.GradientDescentOptimizer` class:

# initialise the optimizer
opt = qml.GradientDescentOptimizer(stepsize=0.4)
opt = optax.adam(0.1)

# set the number of steps
steps = 100
# set the initial parameter values
params = init_params
opt_state = opt.init(init_params)

for i in range(steps):
# compute the gradient of the cost
grads = dcircuit(params)

# update the circuit parameters
params = opt.step(cost, params)
updates, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, updates)

if (i + 1) % 5 == 0:
print("Cost after step {:5d}: {: .7f}".format(i + 1, cost(params)))
Expand Down

0 comments on commit c767cea

Please sign in to comment.