Fix TrotterProduct differentiability with parameter-shift #6432
Labels: bug
Base branch: master
Hello. You may have forgotten to update the changelog!
Out of curiosity, why did we have to set grad_method manually? Isn't the inherited grad_method already None? In which cases should we specify grad_method explicitly?
Hey @Jaybsoni, good question. I tagged you on the relevant Slack thread if you want to read the full conversation about issues #6333 and #6331. 😅 TLDR: It seems that
Co-authored-by: Christina Lee <[email protected]>
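On the grad_method question above: as far as we understand PennyLane's `Operation`, `grad_method` is a property whose inherited default is not `None` for operations that have parameters (it falls back to an analytic or finite-difference method), so a subclass has to opt out explicitly. The classes below are hypothetical stand-ins, not PennyLane code, and only model the finite-difference (`"F"`) fallback. They show the Python mechanics of a subclass class attribute shadowing an inherited property, plus a made-up `gradient_strategy` helper that treats `None` as "decompose and differentiate the pieces".

```python
class Operation:
    """Hypothetical stand-in for a parametrized operation base class."""

    def __init__(self, *params):
        self.data = params

    @property
    def grad_method(self):
        # Assumed default: no parameters -> None, otherwise finite differences ("F")
        return None if len(self.data) == 0 else "F"


class TrotterLike(Operation):
    # A plain class attribute on the subclass shadows the inherited property
    grad_method = None


def gradient_strategy(op):
    """Hypothetical helper: None means 'decompose and differentiate the pieces'."""
    if op.grad_method is None:
        return "decompose"
    return op.grad_method


print(gradient_strategy(Operation(0.5)))    # F
print(gradient_strategy(TrotterLike(0.5)))  # decompose
```

Setting the attribute on the class, rather than overriding the property, is the lightest-weight way to change the behaviour for every instance of the subclass.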
tests/ops/functions/conftest.py
@@ -57,7 +57,9 @@
(qml.s_prod(1.1, qml.RX(1.1, 0)), {}),
(qml.prod(qml.PauliX(0), qml.PauliY(1), qml.PauliZ(0)), {}),
(qml.ctrl(qml.RX(1.1, 0), 1), {}),
(qml.exp(qml.PauliX(0), 1.1), {}),
(qml.exp(qml.PauliX(0), 1.1), {"skip_differentiation": True}),
(qml.exp(qml.PauliX(0), 2.9j), {}),
I'd be OK with commenting this line out until we figure out what to do with the generator of Exp.
Suggested change (comment out the 2.9j line):
# (qml.exp(qml.PauliX(0), 2.9j), {}),
Could I xfail it instead and create an issue to track it?
coeffs, _ = hamiltonian.terms()

# FIXME: setting private attribute `_coeffs` as a workaround
@qml.qnode(dev, diff_method="backprop")
def circ_bp(coeffs, time):
    hamiltonian._coeffs = coeffs
    qml.TrotterProduct(hamiltonian, time, n, order)
    return qml.probs()

@qml.qnode(dev, diff_method="parameter-shift")
def circ_ps(coeffs, time):
    hamiltonian._coeffs = coeffs
    qml.TrotterProduct(hamiltonian, time, n, order)
    return qml.probs()
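The two QNodes above compare parameter-shift gradients against backprop. For intuition about why the two should agree once the dependence on the coefficients is visible to the gradient transform, here is a dependency-free sketch using a toy one-term Hamiltonian H = c * X, for which exp(-i c t X) equals RX(2 c t) and <Z> on |0> is cos(2 c t). The two-term shift rule is applied to the gate angle and the chain rule maps it back to c; the result is checked against the analytic derivative and a central finite difference. All function names here are illustrative, not PennyLane API.

```python
import math


def expval_z_after_rx(angle):
    # <Z> for RX(angle)|0> is cos(angle)
    return math.cos(angle)


def param_shift(f, angle, shift=math.pi / 2):
    # Two-term shift rule, exact for gates exp(-i * angle * P / 2) with P a Pauli word
    return (f(angle + shift) - f(angle - shift)) / (2 * math.sin(shift))


def grad_wrt_coeff(c, t):
    # Toy H = c * X: exp(-i c t X) == RX(2 c t), so the gate angle is 2 c t
    angle = 2 * c * t
    dangle_dc = 2 * t  # chain rule through the gate angle
    return param_shift(expval_z_after_rx, angle) * dangle_dc


def central_diff(f, x, h=1e-6):
    # Numerical reference derivative
    return (f(x + h) - f(x - h)) / (2 * h)


c, t = 0.7, 0.3
ps = grad_wrt_coeff(c, t)
fd = central_diff(lambda cc: expval_z_after_rx(2 * cc * t), c)
analytic = -2 * t * math.sin(2 * c * t)
print(ps, fd, analytic)
```

A gradient transform that loses the link between c and the gate angle would miss this term entirely, which is one plausible way to end up with the all-zero gradients described in this PR.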
Does it fail if you construct the Hamiltonian directly inside the function? I think that would be a more realistic test.
Suggested change (replacing the code above):
coeffs, ops = hamiltonian.terms()

@qml.qnode(dev, diff_method="backprop")
def circ_bp(coeffs, time):
    with qml.queuing.QueuingManager.stop_recording():
        hamiltonian = qml.dot(coeffs, ops)
    qml.TrotterProduct(hamiltonian, time, n, order)
    return qml.probs()

@qml.qnode(dev, diff_method="parameter-shift")
def circ_ps(coeffs, time):
    with qml.queuing.QueuingManager.stop_recording():
        hamiltonian = qml.dot(coeffs, ops)
    qml.TrotterProduct(hamiltonian, time, n, order)
    return qml.probs()
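The suggested version builds the Hamiltonian inside the QNode under qml.queuing.QueuingManager.stop_recording(), presumably so that the intermediate qml.dot(...) result is not recorded on the tape as a standalone operation; only the TrotterProduct should be. The ToyQueue below is a hypothetical stand-in, not PennyLane's queuing implementation; it only sketches the pattern of temporarily pausing recording with a context manager.

```python
from contextlib import contextmanager


class ToyQueue:
    """Hypothetical recorder standing in for a tape's operation queue."""

    def __init__(self):
        self.ops = []
        self._recording = True

    def record(self, op):
        if self._recording:
            self.ops.append(op)
        return op

    @contextmanager
    def stop_recording(self):
        previous = self._recording
        self._recording = False
        try:
            yield
        finally:
            self._recording = previous  # restore even if the body raises


def circuit(queue, coeffs, pause=True):
    if pause:
        with queue.stop_recording():
            hamiltonian = queue.record(("Hamiltonian", tuple(coeffs)))  # built, not recorded
    else:
        hamiltonian = queue.record(("Hamiltonian", tuple(coeffs)))  # leaks onto the "tape"
    queue.record(("TrotterProduct", hamiltonian))
    return [name for name, _ in queue.ops]


print(circuit(ToyQueue(), [0.5, 1.2]))               # ['TrotterProduct']
print(circuit(ToyQueue(), [0.5, 1.2], pause=False))  # ['Hamiltonian', 'TrotterProduct']
```

Restoring the previous flag in `finally` (rather than unconditionally re-enabling recording) keeps nested pauses well-behaved.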
@pytest.mark.parametrize("hamiltonian", test_hamiltonians)
def test_standard_validity(self, hamiltonian):
    """Test standard validity criteria using assert_valid."""
    time, n, order = (4.2, 10, 4)
    op = qml.TrotterProduct(hamiltonian, time, n=n, order=order)
    qml.ops.functions.assert_valid(op)
    qml.ops.functions.assert_valid(op, skip_differentiation=True)
Why does the assert_valid differentiation check fail? I would hope this bug fix would make that test pass.
Context: Prior to this fix, differentiating TrotterProduct with diff_method="parameter-shift" returned zeros, which was inconsistent with the results obtained with diff_method="backprop".

Description of the Change: Set TrotterProduct.grad_method=None and Exp.grad_method=None. This results in,

Benefits: Gradient results are now consistent with backprop.

Possible Drawbacks: None

Related GitHub Issues: Fixes #6333
[sc-74923]
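A worked sketch of why decomposing can recover correct parameter-shift gradients, under stated assumptions (first-order product formula, Pauli-word terms $O_j$ with $O_j^2 = I$, and $H = \sum_j c_j O_j$); this is a reading of the change, not a description of PennyLane's exact decomposition. Every coefficient and the time enter only through gate angles $\phi_{j,k}$, so a shift rule on each angle plus the chain rule gives $\partial f/\partial c_j$ and $\partial f/\partial t$. Higher orders use Suzuki's recursion, which changes the per-factor time fractions but not this chain-rule structure.

```latex
e^{iHt} \approx \prod_{k=1}^{n} \prod_{j} e^{\, i \phi_{j,k} O_j},
\qquad \phi_{j,k} = \frac{c_j t}{n}

\frac{\partial f}{\partial c_j} = \frac{t}{n} \sum_{k=1}^{n} \frac{\partial f}{\partial \phi_{j,k}},
\qquad
\frac{\partial f}{\partial t} = \sum_{j} \frac{c_j}{n} \sum_{k=1}^{n} \frac{\partial f}{\partial \phi_{j,k}}

\frac{\partial f}{\partial \phi_{j,k}}
  = f\big(\phi_{j,k} + \tfrac{\pi}{4}\big) - f\big(\phi_{j,k} - \tfrac{\pi}{4}\big)
```

In the last line only the $(j,k)$-th angle is shifted; the $\pm\pi/4$ shift with unit prefactor follows from writing $e^{i\phi O_j}$ as a rotation $e^{-i\theta O_j/2}$ with $\theta = -2\phi$ and applying the standard $\pm\pi/2$ two-term rule in $\theta$.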