diff --git a/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.metadata.json b/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.metadata.json index 11e08015d1..75a0f251b4 100644 --- a/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.metadata.json +++ b/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.metadata.json @@ -9,7 +9,7 @@ } ], "dateOfPublication": "2024-01-18T00:00:00+00:00", - "dateOfLastModification": "2024-01-18T00:00:00+00:00", + "dateOfLastModification": "2024-03-08T00:00:00+00:00", "categories": [ "Quantum Machine Learning", "Optimization" diff --git a/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py b/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py index 76b86f7847..8dfe1e66de 100644 --- a/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py +++ b/demonstrations/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py @@ -134,7 +134,7 @@ def loss_fn(params, data, targets): # - Update the parameters via ``optax.apply_updates`` # -def update_step(params, opt_state, data, targets): +def update_step(opt, params, opt_state, data, targets): loss_val, grads = jax.value_and_grad(loss_fn)(params, data, targets) updates, opt_state = opt.update(grads, opt_state) params = optax.apply_updates(params, updates) @@ -143,7 +143,7 @@ def update_step(params, opt_state, data, targets): loss_history = [] for i in range(100): - params, opt_state, loss_val = update_step(params, opt_state, data, targets) + params, opt_state, loss_val = update_step(opt, params, opt_state, data, targets) if i % 5 == 0: print(f"Step: {i} Loss: {loss_val}") @@ -162,6 +162,9 @@ def update_step(params, opt_state, data, targets): # transfer between Python and our JIT compiled cost function with each update step. # +# Define the optimizer we want to work with +opt = optax.adam(learning_rate=0.3) + @jax.jit def update_step_jit(i, args): params, opt_state, data, targets, print_training = args @@ -180,9 +183,8 @@ def print_fn(): @jax.jit def optimization_jit(params, data, targets, print_training=False): - opt = optax.adam(learning_rate=0.3) - opt_state = opt.init(params) + opt_state = opt.init(params) args = (params, opt_state, data, targets, print_training) (params, opt_state, _, _, _) = jax.lax.fori_loop(0, 100, update_step_jit, args) @@ -213,7 +215,7 @@ def optimization(params, data, targets): opt_state = opt.init(params) for i in range(100): - params, opt_state, loss_val = update_step(params, opt_state, data, targets) + params, opt_state, loss_val = update_step(opt, params, opt_state, data, targets) return params