diff --git a/demonstrations/tutorial_qubit_rotation.py b/demonstrations/tutorial_qubit_rotation.py index dab43df330..172bcaa475 100644 --- a/demonstrations/tutorial_qubit_rotation.py +++ b/demonstrations/tutorial_qubit_rotation.py @@ -82,32 +82,16 @@ # # Let's see how we can easily implement and optimize this circuit using PennyLane. # -# Importing PennyLane and JAX -# ----------------------------- +# Importing PennyLane and Jax +# ---------------------------- # -# The first thing we need to do is import PennyLane, as well as JAX. +# The first thing we need to do is import PennyLane and Jax, as well as the wrapped version +# of NumPy provided by PennyLane. import pennylane as qml -from jax import numpy as np - +import jax +import optax # Import optax for optimization -############################################################################## -# .. important:: -# -# When constructing a hybrid quantum/classical computational model with PennyLane, -# it is important to **always import NumPy from PennyLane**, not the standard NumPy! -# -# By importing the wrapped version of NumPy provided by PennyLane, you can combine -# the power of JAX with PennyLane: -# -# * continue to use the classical JAX functions and arrays you know and love -# * combine quantum functions (evaluated on quantum hardware/simulators) and -# classical functions (provided by JAX) -# * allow PennyLane to automatically calculate gradients of both classical and -# quantum functions - - -############################################################################## # Creating a device # ----------------- # @@ -169,13 +153,14 @@ # First, we need to define the quantum function that will be evaluated in the QNode: +@qml.qnode(dev1, interface="jax") def circuit(params): qml.RX(params[0], wires=0) qml.RY(params[1], wires=0) return qml.expval(qml.PauliZ(0)) -############################################################################## +################################################################################ # This is a simple circuit, matching the one described above. # Notice that the function ``circuit()`` is constructed as if it were any # other Python function; it accepts a positional argument ``params``, which may @@ -220,7 +205,7 @@ def circuit(params): return qml.expval(qml.PauliZ(0)) -############################################################################## +################################################################################ # Thus, our ``circuit()`` quantum function is now a :class:`~.pennylane.QNode`, which will run on # device ``dev1`` every time it is evaluated. # @@ -228,7 +213,7 @@ def circuit(params): print(circuit([0.54, 0.12])) -############################################################################## +################################################################################ # Calculating quantum gradients # ----------------------------- # @@ -245,13 +230,12 @@ def circuit(params): # partial derivatives) of ``circuit``. The gradient can be evaluated in the same # way as the original function: -dcircuit = qml.grad(circuit, argnum=0) +dcircuit = jax.grad(circuit, argnums=0) -############################################################################## +################################################################################ # The function :func:`~.pennylane.grad` itself **returns a function**, representing # the derivative of the QNode with respect to the argument specified in ``argnum``. -# In this case, the function ``circuit`` takes one argument (``params``), so we -# specify ``argnum=0``. Because the argument has two elements, the returned gradient +# In this case, the argument has two elements, so the returned gradient # is two-dimensional. We can then evaluate this gradient function at any point in the parameter space. print(dcircuit([0.54, 0.12])) @@ -266,21 +250,21 @@ def circuit(params): @qml.qnode(dev1, interface="jax") -def circuit2(phi1, phi2): - qml.RX(phi1, wires=0) - qml.RY(phi2, wires=0) +def circuit2(params): + qml.RX(params[0], wires=0) + qml.RY(params[1], wires=0) return qml.expval(qml.PauliZ(0)) ################################################################################ -# When we calculate the gradient for such a function, the usage of ``argnum`` -# will be slightly different. In this case, ``argnum=0`` will return the gradient -# with respect to only the first parameter (``phi1``), and ``argnum=1`` will give -# the gradient for ``phi2``. To get the gradient with respect to both parameters, -# we can use ``argnum=[0,1]``: +# When we calculate the gradient for such a function, the usage of ``argnums`` +# will be slightly different. In this case, ``argnums=0`` will return the gradient +# with respect to the first parameter (``params[0]``), and ``argnums=1`` will give +# the gradient for ``params[1]``. To get the gradient with respect to both parameters, +# we can use ``argnums=(0, 1)``: -dcircuit = qml.grad(circuit2, argnum=[0, 1]) -print(dcircuit(0.54, 0.12)) +dcircuit = jax.grad(circuit2, argnums=0) +print(dcircuit([0.54, 0.12])) ################################################################################ # Keyword arguments may also be used in your custom quantum function. PennyLane @@ -343,16 +327,21 @@ def cost(x): # :class:`~.pennylane.GradientDescentOptimizer` class: # initialise the optimizer -opt = qml.GradientDescentOptimizer(stepsize=0.4) +opt = optax.adam(0.1) # set the number of steps steps = 100 # set the initial parameter values params = init_params +opt_state = opt.init(init_params) for i in range(steps): + # compute the gradient of the cost + grads = dcircuit(params) + # update the circuit parameters - params = opt.step(cost, params) + updates, opt_state = opt.update(grads, opt_state) + params = optax.apply_updates(params, updates) if (i + 1) % 5 == 0: print("Cost after step {:5d}: {: .7f}".format(i + 1, cost(params)))