diff --git a/demonstrations/tutorial_QGAN.py b/demonstrations/tutorial_QGAN.py index 9700c95640..a592242ec8 100644 --- a/demonstrations/tutorial_QGAN.py +++ b/demonstrations/tutorial_QGAN.py @@ -213,7 +213,7 @@ def gen_cost(gen_weights): cost = lambda: disc_cost(disc_weights) for step in range(50): - opt.minimize(cost, disc_weights) + opt.minimize(cost, [disc_weights]) if step % 5 == 0: cost_val = cost().numpy() print("Step {}: cost = {}".format(step, cost_val)) @@ -243,7 +243,7 @@ def gen_cost(gen_weights): cost = lambda: gen_cost(gen_weights) for step in range(50): - opt.minimize(cost, gen_weights) + opt.minimize(cost, [gen_weights]) if step % 5 == 0: cost_val = cost().numpy() print("Step {}: cost = {}".format(step, cost_val))