Update tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py (#1052)
Solving issue #1045 

As indicated in the issue, the optimizer used inside the `update_step_jit`
function is not the one defined in the `optimization_jit` function. I tried to
pass `opt` into the jitted function, but that caused problems, so I simply
defined the optimizer outside instead.

In the case where jit is not applied, I was able to pass the optimizer
directly, so I updated those functions to take it as an argument.
Tagging @josh146 in case there is a better solution here.
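
For reference, here is a minimal, self-contained sketch of the pattern this change adopts, with a toy loss standing in for the tutorial's quantum model (the data, shapes, and the name `jitted_update_step` below are illustrative only): the optimizer is created once at module scope, the jitted update step closes over it, and the non-jitted update step receives it as an explicit argument. Presumably the "problems" when passing `opt` through jit arise because an optax optimizer is a tuple of functions rather than a pytree of arrays.

import jax
import jax.numpy as jnp
import optax

# Created once at module scope, so the jitted update step and the code that
# initializes the optimizer state refer to the same optimizer instance.
opt = optax.adam(learning_rate=0.3)

def loss_fn(params, data, targets):
    # Toy stand-in loss; the tutorial's loss evaluates a quantum circuit here.
    return jnp.mean((params * data - targets) ** 2)

def update_step(opt, params, opt_state, data, targets):
    # Non-jitted variant: the optimizer can simply be passed in as an argument.
    loss_val, grads = jax.value_and_grad(loss_fn)(params, data, targets)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_val

@jax.jit
def jitted_update_step(params, opt_state, data, targets):
    # Jitted variant: `opt` is captured from module scope instead of being
    # passed in, since jit cannot trace the optimizer as an ordinary argument.
    loss_val, grads = jax.value_and_grad(loss_fn)(params, data, targets)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_val

params = jnp.array(0.5)
data = jnp.array([1.0, 2.0, 3.0])
targets = jnp.array([2.0, 4.0, 6.0])
opt_state = opt.init(params)

for i in range(100):
    params, opt_state, loss_val = jitted_update_step(params, opt_state, data, targets)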

---------

Co-authored-by: Josh Izaac <[email protected]>
KetpuntoG and josh146 authored Mar 11, 2024
1 parent 33b9587 commit 380b289
Showing 2 changed files with 8 additions and 6 deletions.
Demo metadata (JSON):

@@ -9,7 +9,7 @@
        }
    ],
    "dateOfPublication": "2024-01-18T00:00:00+00:00",
-   "dateOfLastModification": "2024-01-18T00:00:00+00:00",
+   "dateOfLastModification": "2024-03-08T00:00:00+00:00",
    "categories": [
        "Quantum Machine Learning",
        "Optimization"
tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py:

@@ -134,7 +134,7 @@ def loss_fn(params, data, targets):
 # - Update the parameters via ``optax.apply_updates``
 #

-def update_step(params, opt_state, data, targets):
+def update_step(opt, params, opt_state, data, targets):
     loss_val, grads = jax.value_and_grad(loss_fn)(params, data, targets)
     updates, opt_state = opt.update(grads, opt_state)
     params = optax.apply_updates(params, updates)

@@ -143,7 +143,7 @@ def update_step(params, opt_state, data, targets):
 loss_history = []

 for i in range(100):
-    params, opt_state, loss_val = update_step(params, opt_state, data, targets)
+    params, opt_state, loss_val = update_step(opt, params, opt_state, data, targets)

     if i % 5 == 0:
         print(f"Step: {i} Loss: {loss_val}")

@@ -162,6 +162,9 @@ def update_step(params, opt_state, data, targets):
 # transfer between Python and our JIT compiled cost function with each update step.
 #

+# Define the optimizer we want to work with
+opt = optax.adam(learning_rate=0.3)
+
 @jax.jit
 def update_step_jit(i, args):
     params, opt_state, data, targets, print_training = args

@@ -180,9 +183,8 @@ def print_fn():

 @jax.jit
 def optimization_jit(params, data, targets, print_training=False):
-    opt = optax.adam(learning_rate=0.3)
-    opt_state = opt.init(params)

+    opt_state = opt.init(params)
     args = (params, opt_state, data, targets, print_training)
     (params, opt_state, _, _, _) = jax.lax.fori_loop(0, 100, update_step_jit, args)

@@ -213,7 +215,7 @@ def optimization(params, data, targets):
     opt_state = opt.init(params)

     for i in range(100):
-        params, opt_state, loss_val = update_step(params, opt_state, data, targets)
+        params, opt_state, loss_val = update_step(opt, params, opt_state, data, targets)

     return params
