diff --git a/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.metadata.json b/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.metadata.json
index 11e08015d1..75a0f251b4 100644
--- a/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.metadata.json
+++ b/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.metadata.json
@@ -9,7 +9,7 @@
         }
     ],
     "dateOfPublication": "2024-01-18T00:00:00+00:00",
-    "dateOfLastModification": "2024-01-18T00:00:00+00:00",
+    "dateOfLastModification": "2024-03-08T00:00:00+00:00",
     "categories": [
         "Quantum Machine Learning",
         "Optimization"
diff --git a/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py b/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py
index 76b86f7847..8dfe1e66de 100644
--- a/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py
+++ b/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py
@@ -134,7 +134,7 @@ def loss_fn(params, data, targets):
 # -  Update the parameters via ``optax.apply_updates``
 #
 
-def update_step(params, opt_state, data, targets):
+def update_step(opt, params, opt_state, data, targets):
     loss_val, grads = jax.value_and_grad(loss_fn)(params, data, targets)
     updates, opt_state = opt.update(grads, opt_state)
     params = optax.apply_updates(params, updates)
@@ -143,7 +143,7 @@ def update_step(params, opt_state, data, targets):
 loss_history = []
 
 for i in range(100):
-    params, opt_state, loss_val = update_step(params, opt_state, data, targets)
+    params, opt_state, loss_val = update_step(opt, params, opt_state, data, targets)
 
     if i % 5 == 0:
         print(f"Step: {i} Loss: {loss_val}")
@@ -162,6 +162,9 @@ def update_step(params, opt_state, data, targets):
 # transfer between Python and our JIT compiled cost function with each update step.
 #
 
+# Define the optimizer we want to work with
+opt = optax.adam(learning_rate=0.3)
+
 @jax.jit
 def update_step_jit(i, args):
     params, opt_state, data, targets, print_training = args
@@ -180,9 +183,8 @@ def print_fn():
 
 @jax.jit
 def optimization_jit(params, data, targets, print_training=False):
-    opt = optax.adam(learning_rate=0.3)
-    opt_state = opt.init(params)
 
+    opt_state = opt.init(params)
     args = (params, opt_state, data, targets, print_training)
     (params, opt_state, _, _, _) = jax.lax.fori_loop(0, 100, update_step_jit, args)
 
@@ -213,7 +215,7 @@ def optimization(params, data, targets):
     opt_state = opt.init(params)
 
     for i in range(100):
-        params, opt_state, loss_val = update_step(params, opt_state, data, targets)
+        params, opt_state, loss_val = update_step(opt, params, opt_state, data, targets)
 
     return params