[WIP] Fix TrotterProduct differentiability with parameter-shift bug #6432

Open
wants to merge 9 commits into
base: master
from


@andrijapau (Contributor) commented Oct 22, 2024

Context:

Prior to this fix, differentiating TrotterProduct with diff_method="parameter-shift" returned all zeros, which was inconsistent with the results obtained with diff_method="backprop".
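To make the symptom concrete, here is a pure-Python toy model of the two-term parameter-shift rule on a one-parameter circuit whose expectation value is cos(theta) (the value of <X> after RZ(theta) acts on |+>). It is not PennyLane code: ToyRZ, expval, param_shift and broken_param_shift are hypothetical names. If the shifted parameter values are computed but never reach the operator that is actually evaluated, both evaluations coincide and the gradient is exactly zero, which matches the all-zeros behavior reported here.

```python
import math


class ToyRZ:
    """Hypothetical stand-in for an operator that stores its parameter in .data."""

    def __init__(self, theta):
        self.data = (theta,)


def expval(op):
    # <X> after RZ(theta) acts on |+>, which equals cos(theta)
    return math.cos(op.data[0])


def param_shift(op, shift=math.pi / 2):
    """Two-term parameter-shift rule: evaluate fresh copies with shifted data."""
    theta = op.data[0]
    return (expval(ToyRZ(theta + shift)) - expval(ToyRZ(theta - shift))) / 2


def broken_param_shift(op, shift=math.pi / 2):
    """Same rule, but the shifted copies never replace the original operator."""
    theta = op.data[0]
    _plus, _minus = ToyRZ(theta + shift), ToyRZ(theta - shift)  # built, then ignored
    # Both evaluations see the unshifted parameter, so their difference is 0.
    return (expval(op) - expval(op)) / 2
```

Here param_shift(ToyRZ(0.3)) agrees with the analytic derivative -sin(0.3), while broken_param_shift(ToyRZ(0.3)) returns 0.0.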

Description of the Change:

Set TrotterProduct.grad_method = None and Exp.grad_method = None. With this change, parameter-shift gives the following:

```python
import pennylane as qml

# Device assumed; the original snippet does not show how `dev` was defined.
dev = qml.device("default.qubit")

@qml.qnode(dev, diff_method="parameter-shift")
def circ(time, coeffs):
    h = qml.dot(coeffs, [qml.PauliX(0), qml.PauliY(0), qml.PauliZ(1)])
    qml.TrotterProduct(h, time, n=10, order=4)
    return qml.expval(qml.Hadamard(0))
```

```pycon
>>> time = qml.numpy.array(4.2)
>>> coeffs = qml.numpy.array([2.0, 2.0, 2.0])
>>> qml.jacobian(circ)(time, coeffs)
(array(3.27988133), array([3.31863668, 3.56911412, 0.        ]))
```

Benefits: Gradient results are now consistent with backprop.

Possible Drawbacks: None

Related GitHub Issues: Fixes #6333

[sc-74923]

@andrijapau marked this pull request as ready for review October 22, 2024 16:46
Contributor

Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

@Jaybsoni (Contributor)

Out of curiosity, why did we have to set grad_method manually? Isn't the inherited grad_method attribute None? In which cases should we specify grad_method explicitly?

@andrijapau (Contributor, Author)

Hey @Jaybsoni,

Good question. I tagged you on the relevant thread on Slack if you want to read the full conversation regarding issues #6333 and #6331. 😅

TL;DR: It seems that parameter-shift tries to update the operator's parameters and fails to do so, which would explain the all-zero gradients. Setting grad_method=None forces the operator to be decomposed before its data is differentiated. This does suggest some areas of improvement for Operator.grad_method.
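A NumPy sketch of what "decompose first, then differentiate" buys, under simplifying assumptions: a single-qubit Hamiltonian H = c_x X + c_z Z, a hand-rolled first-order product formula (n repetitions of RX(2 c_x t / n) followed by RZ(2 c_z t / n)) instead of qml.TrotterProduct, and <Z> measured on |0>. Each primitive rotation angle is shifted by ±π/2, and the results are chained back to the coefficients through d(angle)/d(c_j) = 2t/n; central finite differences serve as the reference. All function names are hypothetical, and the sketch is not meant to mirror PennyLane's internals.

```python
import numpy as np

# Single-qubit Pauli matrices and identity.
X = np.array([[0, 1], [1, 0]], dtype=complex)
Z = np.array([[1, 0], [0, -1]], dtype=complex)
I2 = np.eye(2, dtype=complex)


def rot(pauli, theta):
    """Rotation exp(-i * theta / 2 * P) for a Pauli matrix P (RX for X, RZ for Z)."""
    return np.cos(theta / 2) * I2 - 1j * np.sin(theta / 2) * pauli


def trotter_gates(coeffs, time, n):
    """Hypothetical first-order product formula for H = c_x X + c_z Z.

    Returns (pauli, angle, coeff_index) triples: n repetitions of
    RX(2 c_x t / n) followed by RZ(2 c_z t / n).
    """
    paulis = [X, Z]
    return [(paulis[j], 2 * coeffs[j] * time / n, j) for _ in range(n) for j in range(2)]


def expval(paulis, angles):
    """<Z> after applying the rotations, in order, to |0>."""
    state = np.array([1, 0], dtype=complex)
    for pauli, theta in zip(paulis, angles):
        state = rot(pauli, theta) @ state
    return float(np.real(np.conj(state) @ (Z @ state)))


def grad_param_shift(coeffs, time, n):
    """Decompose first, shift each primitive angle, then chain-rule to the coeffs."""
    gates = trotter_gates(coeffs, time, n)
    paulis = [g[0] for g in gates]
    angles = np.array([g[1] for g in gates])
    grad = np.zeros(len(coeffs))
    for k, (_, _, j) in enumerate(gates):
        shift = np.zeros_like(angles)
        shift[k] = np.pi / 2
        # Two-term shift rule, exact for rotations generated by P / 2.
        d_angle = (expval(paulis, angles + shift) - expval(paulis, angles - shift)) / 2
        grad[j] += d_angle * 2 * time / n  # d(angle_k)/d(c_j) = 2 t / n
    return grad


def grad_finite_diff(coeffs, time, n, eps=1e-6):
    """Central finite differences of the full circuit, used as a reference."""

    def f(c):
        gates = trotter_gates(c, time, n)
        return expval([g[0] for g in gates], [g[1] for g in gates])

    grad = np.zeros(len(coeffs))
    for j in range(len(coeffs)):
        step = np.zeros(len(coeffs))
        step[j] = eps
        grad[j] = (f(coeffs + step) - f(coeffs - step)) / (2 * eps)
    return grad
```

Calling grad_param_shift(np.array([0.7, -0.4]), 1.3, 5) and grad_finite_diff with the same arguments should give matching arrays, up to finite-difference error. The point of the design is that the shift rule is only ever applied to primitive rotations, whose parameters are plain angles, so no composite operator has to have its data rewritten in place.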

@andrijapau self-assigned this Oct 23, 2024
```
@@ -57,7 +57,9 @@
(qml.s_prod(1.1, qml.RX(1.1, 0)), {}),
(qml.prod(qml.PauliX(0), qml.PauliY(1), qml.PauliZ(0)), {}),
(qml.ctrl(qml.RX(1.1, 0), 1), {}),
(qml.exp(qml.PauliX(0), 1.1), {}),
(qml.exp(qml.PauliX(0), 1.1), {"skip_differentiation": True}),
(qml.exp(qml.PauliX(0), 2.9j), {}),
```
Contributor


I'd be ok with commenting this line out till we figure out what to do with the generator of Exp.

Suggested change
```diff
-(qml.exp(qml.PauliX(0), 2.9j), {}),
+#(qml.exp(qml.PauliX(0), 2.9j), {}),
```

Contributor Author


I can xfail it and create an issue to track it?

@andrijapau requested a review from lillian542 October 31, 2024 17:09
@andrijapau changed the title from "Fix TrotterProduct differentiability with parameter-shift" to "Fix TrotterProduct differentiability with parameter-shift bug" Nov 7, 2024
@andrijapau requested a review from albi3ro November 8, 2024 16:42
Comment on lines +465 to +478
```python
coeffs, _ = hamiltonian.terms()

# FIXME: setting private attribute `_coeffs` as work around
@qml.qnode(dev, diff_method="backprop")
def circ_bp(coeffs, time):
    hamiltonian._coeffs = coeffs
    qml.TrotterProduct(hamiltonian, time, n, order)
    return qml.probs()

@qml.qnode(dev, diff_method="parameter-shift")
def circ_ps(coeffs, time):
    hamiltonian._coeffs = coeffs
    qml.TrotterProduct(hamiltonian, time, n, order)
    return qml.probs()
```
Contributor


Does it fail if you construct the hamiltonian in the function directly? I think that is a more realistic test.

Suggested change
```diff
-coeffs, _ = hamiltonian.terms()
-# FIXME: setting private attribute `_coeffs` as work around
-@qml.qnode(dev, diff_method="backprop")
-def circ_bp(coeffs, time):
-    hamiltonian._coeffs = coeffs
-    qml.TrotterProduct(hamiltonian, time, n, order)
-    return qml.probs()
-@qml.qnode(dev, diff_method="parameter-shift")
-def circ_ps(coeffs, time):
-    hamiltonian._coeffs = coeffs
-    qml.TrotterProduct(hamiltonian, time, n, order)
-    return qml.probs()
+coeffs, ops = hamiltonian.terms()
+# FIXME: setting private attribute `_coeffs` as work around
+@qml.qnode(dev, diff_method="backprop")
+def circ_bp(coeffs, time):
+    with qml.queuing.QueuingManager.stop_recording():
+        hamiltonian = qml.dot(coeffs, ops)
+    qml.TrotterProduct(hamiltonian, time, n, order)
+    return qml.probs()
+@qml.qnode(dev, diff_method="parameter-shift")
+def circ_ps(coeffs, time):
+    with qml.queuing.QueuingManager.stop_recording():
+        hamiltonian = qml.dot(coeffs, ops)
+    qml.TrotterProduct(hamiltonian, time, n, order)
+    return qml.probs()
```

```python
@pytest.mark.parametrize("hamiltonian", test_hamiltonians)
def test_standard_validity(self, hamiltonian):
    """Test standard validity criteria using assert_valid."""
    time, n, order = (4.2, 10, 4)
    op = qml.TrotterProduct(hamiltonian, time, n=n, order=order)
    qml.ops.functions.assert_valid(op)
    qml.ops.functions.assert_valid(op, skip_differentiation=True)
```
Contributor


Why does the assert_valid test for differentiation fail? I would have hoped this bug fix would let that test pass.

Contributor Author


@Jaybsoni, I'm slowly revisiting this PR, but the failure was due to Exp also having differentiability issues, lol.

@andrijapau changed the title from "Fix TrotterProduct differentiability with parameter-shift bug" to "[WIP] Fix TrotterProduct differentiability with parameter-shift bug" Dec 12, 2024
Development

Successfully merging this pull request may close these issues.

[BUG] Differentiating Trotter with parameter-shift returns all 0s
3 participants