diff --git a/_static/authors/josh_izaac.txt b/_static/authors/josh_izaac.txt index 5fa7412b87..3691993bea 100644 --- a/_static/authors/josh_izaac.txt +++ b/_static/authors/josh_izaac.txt @@ -1,4 +1,4 @@ -.. bio:: Josh Izaac - :photo: ../_static/authors/josh_izaac.png - +.. bio:: Josh Izaac + :photo: ../_static/authors/josh_izaac.png + Josh is a theoretical physicist, software tinkerer, and occasional baker. At Xanadu, he contributes to the development and growth of Xanadu’s open-source quantum software products. \ No newline at end of file diff --git a/_static/authors/maria_schuld.txt b/_static/authors/maria_schuld.txt index 31e1429cb5..1306ffd804 100644 --- a/_static/authors/maria_schuld.txt +++ b/_static/authors/maria_schuld.txt @@ -1,4 +1,4 @@ -.. bio:: Maria Schuld - :photo: ../_static/authors/maria_schuld.jpg - +.. bio:: Maria Schuld + :photo: ../_static/authors/maria_schuld.jpg + Maria leads Xanadu's quantum machine learning team and is a seasoned PennyLane developer. \ No newline at end of file diff --git a/_static/demonstration_assets/How_to_optimize_QML_model_using_JAX_and_JAXopt/socialthumbnail_large_How_to_optimize_QML_model_using_JAX_and_JAXopt_2024-01-16.png b/_static/demonstration_assets/How_to_optimize_QML_model_using_JAX_and_JAXopt/socialthumbnail_large_How_to_optimize_QML_model_using_JAX_and_JAXopt_2024-01-16.png new file mode 100644 index 0000000000..54f3038d81 Binary files /dev/null and b/_static/demonstration_assets/How_to_optimize_QML_model_using_JAX_and_JAXopt/socialthumbnail_large_How_to_optimize_QML_model_using_JAX_and_JAXopt_2024-01-16.png differ diff --git a/_static/demonstration_assets/How_to_optimize_QML_model_using_JAX_and_Optax/socialsthumbnail_large_How_to_optimize_QML_model_using_JAX_and_Optax_2024-01-16.png b/_static/demonstration_assets/How_to_optimize_QML_model_using_JAX_and_Optax/socialsthumbnail_large_How_to_optimize_QML_model_using_JAX_and_Optax_2024-01-16.png new file mode 100644 index 0000000000..7f03b48985 Binary files /dev/null and b/_static/demonstration_assets/How_to_optimize_QML_model_using_JAX_and_Optax/socialsthumbnail_large_How_to_optimize_QML_model_using_JAX_and_Optax_2024-01-16.png differ diff --git a/_static/demonstration_assets/regular_demo_thumbnails/thumbnail_How_to_optimize_QML_model_using_JAX_and_JAXopt_2024-01-16.png b/_static/demonstration_assets/regular_demo_thumbnails/thumbnail_How_to_optimize_QML_model_using_JAX_and_JAXopt_2024-01-16.png new file mode 100644 index 0000000000..67cdd6d4a6 Binary files /dev/null and b/_static/demonstration_assets/regular_demo_thumbnails/thumbnail_How_to_optimize_QML_model_using_JAX_and_JAXopt_2024-01-16.png differ diff --git a/_static/demonstration_assets/regular_demo_thumbnails/thumbnail_How_to_optimize_QML_model_using_JAX_and_Optax_2024-01-16.png b/_static/demonstration_assets/regular_demo_thumbnails/thumbnail_How_to_optimize_QML_model_using_JAX_and_Optax_2024-01-16.png new file mode 100644 index 0000000000..6411299736 Binary files /dev/null and b/_static/demonstration_assets/regular_demo_thumbnails/thumbnail_How_to_optimize_QML_model_using_JAX_and_Optax_2024-01-16.png differ diff --git a/_static/large_demo_thumbnails/thumbnail_large_How_to_optimize_QML_model_using_JAX_and_JAXopt_2024-01-16.png b/_static/large_demo_thumbnails/thumbnail_large_How_to_optimize_QML_model_using_JAX_and_JAXopt_2024-01-16.png new file mode 100644 index 0000000000..46c96e81d6 Binary files /dev/null and b/_static/large_demo_thumbnails/thumbnail_large_How_to_optimize_QML_model_using_JAX_and_JAXopt_2024-01-16.png differ diff 
--git a/_static/large_demo_thumbnails/thumbnail_large_How_to_optimize_QML_model_using_JAX_and_Optax_2024-01-16.png b/_static/large_demo_thumbnails/thumbnail_large_How_to_optimize_QML_model_using_JAX_and_Optax_2024-01-16.png new file mode 100644 index 0000000000..b064fcd6dc Binary files /dev/null and b/_static/large_demo_thumbnails/thumbnail_large_How_to_optimize_QML_model_using_JAX_and_Optax_2024-01-16.png differ diff --git a/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_JAXopt.metadata.json b/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_JAXopt.metadata.json new file mode 100644 index 0000000000..6227544071 --- /dev/null +++ b/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_JAXopt.metadata.json @@ -0,0 +1,51 @@ +{ + "title": "How to optimize a QML model using JAX and JAXopt", + "authors": [ + { + "id": "josh_izaac" + }, + { + "id": "maria_schuld" + } + ], + "dateOfPublication": "2024-01-18T00:00:00+00:00", + "dateOfLastModification": "2024-01-18T00:00:00+00:00", + "categories": [ + "Quantum Machine Learning", + "Optimization" + ], + "tags": ["how to"], + "previewImages": [ + { + "type": "thumbnail", + "uri": "/_static/demonstration_assets/regular_demo_thumbnails/thumbnail_How_to_optimize_QML_model_using_JAX_and_JAXopt_2024-01-16.png" + }, + { + "type": "large_thumbnail", + "uri": "/_static/large_demo_thumbnails/thumbnail_large_How_to_optimize_QML_model_using_JAX_and_JAXopt_2024-01-16.png" + } + ], + "seoDescription": "Learn how to train a quantum machine learning model using PennyLane, JAX, and JAXopt.", + "doi": "", + "canonicalURL": "/qml/demos/tutorial_How_to_optimize_QML_model_using_JAX_and_JAXopt", + "references": [], + "basedOnPapers": [], + "referencedByPapers": [], + "relatedContent": [ + { + "type": "demonstration", + "id": "tutorial_How_to_optimize_QML_model_using_JAX_and_Optax", + "weight": 1.0 + }, + { + "type": "demonstration", + "id": "tutorial_jax_transformations", + "weight": 1.0 + }, + { + "type": "demonstration", + "id": "tutorial_variational_classifier", + "weight": 1.0 + } + ] +} \ No newline at end of file diff --git a/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_JAXopt.py b/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_JAXopt.py new file mode 100644 index 0000000000..b68a2098a8 --- /dev/null +++ b/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_JAXopt.py @@ -0,0 +1,225 @@ +r"""How to optimize a QML model using JAX and JAXopt +===================================================================== + +Once you have set up a quantum machine learning model, data to train with, and +cost function to minimize as an objective, the next step is to **perform the optimization**. That is, +setting up a classical optimization loop to find a minimal value of your cost function. + +In this example, we’ll show you how to use `JAX `__, an +autodifferentiable machine learning framework, and `JAXopt `__, a suite +of JAX-compatible gradient-based optimizers, to optimize a PennyLane quantum machine learning model. + +.. 
figure:: ../_static/demonstration_assets/How_to_optimize_QML_model_using_JAX_and_JAXopt/socialthumbnail_large_How_to_optimize_QML_model_using_JAX_and_JAXopt_2024-01-16.png + :align: center + :width: 50% + +""" + +###################################################################### +# Set up your model, data, and cost +# --------------------------------- +# + +###################################################################### +# Here, we will create a simple QML model for our optimization. In particular: +# +# - We will embed our data through a series of rotation gates. +# - We will then have an ansatz of trainable rotation gates with parameters ``weights``; it is these +# values we will train to minimize our cost function. +# - We will train the QML model on ``data``, a ``(5, 4)`` array, and optimize the model to match +# target predictions given by ``target``. +# + +import pennylane as qml +import jax +from jax import numpy as jnp +import jaxopt + +jax.config.update("jax_platform_name", "cpu") + +n_wires = 5 +data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3 +targets = jnp.array([-0.2, 0.4, 0.35, 0.2]) + +dev = qml.device("default.qubit", wires=n_wires) + +@qml.qnode(dev) +def circuit(data, weights): + """Quantum circuit ansatz""" + + # data embedding + for i in range(n_wires): + # data[i] will be of shape (4,); we are + # taking advantage of operation vectorization here + qml.RY(data[i], wires=i) + + # trainable ansatz + for i in range(n_wires): + qml.RX(weights[i, 0], wires=i) + qml.RY(weights[i, 1], wires=i) + qml.RX(weights[i, 2], wires=i) + qml.CNOT(wires=[i, (i + 1) % n_wires]) + + # we use a sum of local Z's as an observable since a + # local Z would only be affected by params on that qubit. + return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)])) + +def my_model(data, weights, bias): + return circuit(data, weights) + bias + +###################################################################### +# We will define a simple cost function that computes the overlap between model output and target +# data, and `just-in-time (JIT) compile `__ it: +# + +@jax.jit +def loss_fn(params, data, targets): + predictions = my_model(data, params["weights"], params["bias"]) + loss = jnp.sum((targets - predictions) ** 2 / len(data)) + return loss + +###################################################################### +# Note that the model above is just an example for demonstration – there are important considerations +# that must be taken into account when performing QML research, including methods for data embedding, +# circuit architecture, and cost function, in order to build models that may have use. This is still +# an active area of research; see our `demonstrations `__ for +# details. +# + +###################################################################### +# Initialize your parameters +# -------------------------- +# + +###################################################################### +# Now, we can generate our trainable parameters ``weights`` and ``bias`` that will be used to train +# our QML model. +# + +weights = jnp.ones([n_wires, 3]) +bias = jnp.array(0.) 
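+
+# As an illustrative aside (not part of the original demo), we can inspect the
+# circuit our ansatz builds with PennyLane's text drawer. ``qml.draw`` is a
+# standard PennyLane utility; we pass a single column of ``data`` so that each
+# wire receives a scalar rotation angle and the diagram stays compact.
+print(qml.draw(circuit)(data[:, 0], weights))
+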
+params = {"weights": weights, "bias": bias} + +###################################################################### +# Plugging the trainable parameters, data, and target labels into our cost function, we can see the +# current loss as well as the parameter gradients: +# + +print(loss_fn(params, data, targets)) + +print(jax.grad(loss_fn)(params, data, targets)) + +###################################################################### +# Create the optimizer +# -------------------- +# + +###################################################################### +# We can now use JAXopt to create a gradient descent optimizer, and train our circuit. +# +# To do so, we first create a function that returns the loss value *and* the gradient value during +# training; this allows us to track and print out the loss during training within JAXopt’s internal +# optimization loop. +# + +def loss_and_grad(params, data, targets, print_training, i): + loss_val, grad_val = jax.value_and_grad(loss_fn)(params, data, targets) + + def print_fn(): + jax.debug.print("Step: {i} Loss: {loss_val}", i=i, loss_val=loss_val) + + # if print_training=True, print the loss every 5 steps + jax.lax.cond((jnp.mod(i, 5) == 0) & print_training, print_fn, lambda: None) + + return loss_val, grad_val + +###################################################################### +# Note that we use a couple of JAX specific functions here: +# +# - ``jax.lax.cond`` instead of a Python ``if`` statement +# - ``jax.debug.print`` instead of a Python ``print`` function +# +# These JAX compatible functions are needed because JAXopt will automatically JIT compile the +# optimizer update step. +# + +opt = jaxopt.GradientDescent(loss_and_grad, stepsize=0.3, value_and_grad=True) +opt_state = opt.init_state(params) + +for i in range(100): + params, opt_state = opt.update(params, opt_state, data, targets, True, i) + + +###################################################################### +# Jitting the optimization loop +# ----------------------------- +# + +###################################################################### +# In the above example, we JIT compiled our cost function ``loss_fn`` (and JAXopt automatically JIT compiled the `loss_and_grad` function behind the scenes). However, we can +# also JIT compile the entire optimization loop; this means that the for-loop around optimization is +# not happening in Python, but is compiled and executed natively. This avoids (potentially costly) data +# transfer between Python and our JIT compiled cost function with each update step. +# + +@jax.jit +def optimization_jit(params, data, targets, print_training=False): + opt = jaxopt.GradientDescent(loss_and_grad, stepsize=0.3, value_and_grad=True) + opt_state = opt.init_state(params) + + def update(i, args): + params, opt_state = opt.update(*args, i) + return (params, opt_state, *args[2:]) + + args = (params, opt_state, data, targets, print_training) + (params, opt_state, _, _, _) = jax.lax.fori_loop(0, 100, update, args) + + return params + +###################################################################### +# Note that -- similar to above -- we use ``jax.lax.fori_loop`` and ``jax.lax.cond``, rather than a standard Python for loop +# and if statement, to allow the control flow to be JIT compatible. 
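+#
+# As a minimal standalone sketch (not part of the training code above), the two
+# primitives behave as follows: ``jax.lax.fori_loop`` threads a carry value
+# through its body function, while ``jax.lax.cond`` selects between two branches
+# in a JIT-compatible way.
+
+def _toy_body(i, carry):
+    # add the loop index to the running total; fori_loop passes ``carry``
+    # from one iteration to the next
+    return carry + i
+
+print(jax.lax.fori_loop(0, 5, _toy_body, 0))  # 0 + 1 + 2 + 3 + 4 = 10
+
+x = jnp.array(-3.0)
+print(jax.lax.cond(x > 0, lambda v: v, lambda v: -v, x))  # JIT-compatible |x|
+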
+# + +params = {"weights": weights, "bias": bias} +optimization_jit(params, data, targets, print_training=True) + +###################################################################### +# Appendix: Timing the two approaches +# ----------------------------------- +# +# We can time the two approaches (JIT compiling just the cost function, vs JIT compiling the entire +# optimization loop) to explore the differences in performance: +# + +from timeit import repeat + +def optimization(params, data, targets): + opt = jaxopt.GradientDescent(loss_and_grad, stepsize=0.3, value_and_grad=True) + opt_state = opt.init_state(params) + + for i in range(100): + params, opt_state = opt.update(params, opt_state, data, targets, False, i) + + return params + +reps = 5 +num = 2 + +times = repeat("optimization(params, data, targets)", globals=globals(), number=num, repeat=reps) +result = min(times) / num + +print(f"Jitting just the cost (best of {reps}): {result} sec per loop") + +times = repeat("optimization_jit(params, data, targets)", globals=globals(), number=num, repeat=reps) +result = min(times) / num + +print(f"Jitting the entire optimization (best of {reps}): {result} sec per loop") + +###################################################################### +# In this example, JIT compiling the entire optimization loop +# is significantly more performant. +# +# About the authors +# ----------------- +# diff --git a/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.metadata.json b/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.metadata.json new file mode 100644 index 0000000000..11e08015d1 --- /dev/null +++ b/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.metadata.json @@ -0,0 +1,51 @@ +{ + "title": "How to optimize a QML model using JAX and Optax", + "authors": [ + { + "id": "josh_izaac" + }, + { + "id": "maria_schuld" + } + ], + "dateOfPublication": "2024-01-18T00:00:00+00:00", + "dateOfLastModification": "2024-01-18T00:00:00+00:00", + "categories": [ + "Quantum Machine Learning", + "Optimization" + ], + "tags": ["how to"], + "previewImages": [ + { + "type": "thumbnail", + "uri": "/_static/demonstration_assets/regular_demo_thumbnails/thumbnail_How_to_optimize_QML_model_using_JAX_and_Optax_2024-01-16.png" + }, + { + "type": "large_thumbnail", + "uri": "/_static/large_demo_thumbnails/thumbnail_large_How_to_optimize_QML_model_using_JAX_and_Optax_2024-01-16.png" + } + ], + "seoDescription": "Learn how to train a quantum machine learning model using PennyLane, JAX, and Optax.", + "doi": "", + "canonicalURL": "/qml/demos/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax", + "references": [], + "basedOnPapers": [], + "referencedByPapers": [], + "relatedContent": [ + { + "type": "demonstration", + "id": "tutorial_How_to_optimize_QML_model_using_JAX_and_JAXopt", + "weight": 1.0 + }, + { + "type": "demonstration", + "id": "tutorial_jax_transformations", + "weight": 1.0 + }, + { + "type": "demonstration", + "id": "tutorial_variational_classifier", + "weight": 1.0 + } + ] +} \ No newline at end of file diff --git a/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py b/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py new file mode 100644 index 0000000000..76b86f7847 --- /dev/null +++ b/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py @@ -0,0 +1,239 @@ +r"""How to optimize a QML model using JAX and Optax +==================================================================== + +Once 
you have set up a quantum machine learning model, data to train with and +cost function to minimize as an objective, the next step is to **perform the optimization**. That is, +setting up a classical optimization loop to find a minimal value of your cost function. + +In this example, we’ll show you how to use `JAX `__, an +autodifferentiable machine learning framework, and `Optax `__, a +suite of JAX-compatible gradient-based optimizers, to optimize a PennyLane quantum machine learning +model. + +.. figure:: ../_static/demonstration_assets/How_to_optimize_QML_model_using_JAX_and_Optax/socialsthumbnail_large_How_to_optimize_QML_model_using_JAX_and_Optax_2024-01-16.png + :align: center + :width: 50% + +""" + +###################################################################### +# Set up your model, data, and cost +# --------------------------------- +# + +###################################################################### +# Here, we will create a simple QML model for our optimization. In particular: +# +# - We will embed our data through a series of rotation gates. +# - We will then have an ansatz of trainable rotation gates with parameters ``weights``; it is these +# values we will train to minimize our cost function. +# - We will train the QML model on ``data``, a ``(5, 4)`` array, and optimize the model to match +# target predictions given by ``target``. +# + +import pennylane as qml +import jax +from jax import numpy as jnp +import optax + +n_wires = 5 +data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3 +targets = jnp.array([-0.2, 0.4, 0.35, 0.2]) + +dev = qml.device("default.qubit", wires=n_wires) + +@qml.qnode(dev) +def circuit(data, weights): + """Quantum circuit ansatz""" + + # data embedding + for i in range(n_wires): + # data[i] will be of shape (4,); we are + # taking advantage of operation vectorization here + qml.RY(data[i], wires=i) + + # trainable ansatz + for i in range(n_wires): + qml.RX(weights[i, 0], wires=i) + qml.RY(weights[i, 1], wires=i) + qml.RX(weights[i, 2], wires=i) + qml.CNOT(wires=[i, (i + 1) % n_wires]) + + # we use a sum of local Z's as an observable since a + # local Z would only be affected by params on that qubit. + return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)])) + +def my_model(data, weights, bias): + return circuit(data, weights) + bias + +###################################################################### +# We will define a simple cost function that computes the overlap between model output and target +# data, and `just-in-time (JIT) compile `__ it: +# + +@jax.jit +def loss_fn(params, data, targets): + predictions = my_model(data, params["weights"], params["bias"]) + loss = jnp.sum((targets - predictions) ** 2 / len(data)) + return loss + +###################################################################### +# Note that the model above is just an example for demonstration – there are important considerations +# that must be taken into account when performing QML research, including methods for data embedding, +# circuit architecture, and cost function, in order to build models that may have use. This is still +# an active area of research; see our `demonstrations `__ for +# details. +# + +###################################################################### +# Initialize your parameters +# -------------------------- +# + +###################################################################### +# Now, we can generate our trainable parameters ``weights`` and ``bias`` that will be used to train +# our QML model. 
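+#
+# As a brief aside (a sketch, not part of the original demo): JAX and Optax both
+# operate on arbitrary "pytrees", which is why bundling the trainable parameters
+# into a plain dictionary below works without any extra wrapping.
+# ``jax.tree_util.tree_map`` is a standard JAX utility that applies a function
+# to every leaf of such a tree:
+
+example_tree = {"weights": jnp.ones([2, 3]), "bias": jnp.array(0.)}
+print(jax.tree_util.tree_map(lambda leaf: leaf * 2, example_tree))
+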
+# + +weights = jnp.ones([n_wires, 3]) +bias = jnp.array(0.) +params = {"weights": weights, "bias": bias} + +###################################################################### +# Plugging the trainable parameters, data, and target labels into our cost function, we can see the +# current loss as well as the parameter gradients: +# + +print(loss_fn(params, data, targets)) + +print(jax.grad(loss_fn)(params, data, targets)) + +###################################################################### +# Create the optimizer +# -------------------- +# + +###################################################################### +# We can now use Optax to create an optimizer, and train our circuit. +# Here, we choose the Adam optimizer, however +# `other available optimizers `__ +# may be used here. +# + +opt = optax.adam(learning_rate=0.3) +opt_state = opt.init(params) + +###################################################################### +# We first define our ``update_step`` function, which needs to do a couple of things: +# +# - Compute the loss function (so we can track training) and the gradients (so we can apply an +# optimization step). We can do this in one execution via the ``jax.value_and_grad`` function. +# +# - Apply the update step of our optimizer via ``opt.update`` +# +# - Update the parameters via ``optax.apply_updates`` +# + +def update_step(params, opt_state, data, targets): + loss_val, grads = jax.value_and_grad(loss_fn)(params, data, targets) + updates, opt_state = opt.update(grads, opt_state) + params = optax.apply_updates(params, updates) + return params, opt_state, loss_val + +loss_history = [] + +for i in range(100): + params, opt_state, loss_val = update_step(params, opt_state, data, targets) + + if i % 5 == 0: + print(f"Step: {i} Loss: {loss_val}") + + loss_history.append(loss_val) + +###################################################################### +# Jitting the optimization loop +# ----------------------------- +# + +###################################################################### +# In the above example, we JIT compiled our cost function ``loss_fn``. However, we can +# also JIT compile the entire optimization loop; this means that the for-loop around optimization is +# not happening in Python, but is compiled and executed natively. This avoids (potentially costly) data +# transfer between Python and our JIT compiled cost function with each update step. +# + +@jax.jit +def update_step_jit(i, args): + params, opt_state, data, targets, print_training = args + + loss_val, grads = jax.value_and_grad(loss_fn)(params, data, targets) + updates, opt_state = opt.update(grads, opt_state) + params = optax.apply_updates(params, updates) + + def print_fn(): + jax.debug.print("Step: {i} Loss: {loss_val}", i=i, loss_val=loss_val) + + # if print_training=True, print the loss every 5 steps + jax.lax.cond((jnp.mod(i, 5) == 0) & print_training, print_fn, lambda: None) + + return (params, opt_state, data, targets, print_training) + +@jax.jit +def optimization_jit(params, data, targets, print_training=False): + opt = optax.adam(learning_rate=0.3) + opt_state = opt.init(params) + + args = (params, opt_state, data, targets, print_training) + (params, opt_state, _, _, _) = jax.lax.fori_loop(0, 100, update_step_jit, args) + + return params + +###################################################################### +# Note that we use ``jax.lax.fori_loop`` and ``jax.lax.cond``, rather than a standard Python for loop +# and if statement, to allow the control flow to be JIT compatible. 
We also +# use ``jax.debug.print`` to allow printing to take place at function run-time, +# rather than compile-time. +# + +params = {"weights": weights, "bias": bias} +optimization_jit(params, data, targets, print_training=True) + +###################################################################### +# Appendix: Timing the two approaches +# ----------------------------------- +# +# We can time the two approaches (JIT compiling just the cost function, vs JIT compiling the entire +# optimization loop) to explore the differences in performance: +# + +from timeit import repeat + +def optimization(params, data, targets): + opt = optax.adam(learning_rate=0.3) + opt_state = opt.init(params) + + for i in range(100): + params, opt_state, loss_val = update_step(params, opt_state, data, targets) + + return params + +reps = 5 +num = 2 + +times = repeat("optimization(params, data, targets)", globals=globals(), number=num, repeat=reps) +result = min(times) / num + +print(f"Jitting just the cost (best of {reps}): {result} sec per loop") + +times = repeat("optimization_jit(params, data, targets)", globals=globals(), number=num, repeat=reps) +result = min(times) / num + +print(f"Jitting the entire optimization (best of {reps}): {result} sec per loop") + + +###################################################################### +# In this example, JIT compiling the entire optimization loop +# is significantly more performant. +# +# About the authors +# -----------------