Skip to content

Commit

Permalink
Update split_non_commuting transform to accept any type of measuremen…
Browse files Browse the repository at this point in the history
…t (Issue PennyLaneAI#4932) (PennyLaneAI#4972)

**Context:**

Update the `split_non_commuting` transform to accept measurements which
accept both wires and observables i.e. `probs`, `sample` and `counts`

**Description of the Change:**

1. Modified the `split_non_commuting` transform in
[pennylane/transforms/split_non_commuting.py](https://github.com/PennyLaneAI/pennylane/blob/master/pennylane/transforms/split_non_commuting.py)
by adding a condition for `probs`, `sample` and `counts` and handling
`obs.wires` as a
`qml.PauliZ(wires[0])@qml.PauliZ(wires[1])@[email protected](wires[len(wires)-1])`
observable.
2. Updated the creation of new tapes `split_non_commuting` to handle
measurements initialized using wires
3. Added an entry under Improvements/Community contributions in
`changelog-dev.md`
4. Added unit, integration and automatic differentiation tests for all
interfaces in `tests/transforms/split_non_commuting.py`
5. Modified existing integration and automatic differentiation tests to
explicitly apply the transform on QNodes since it was previously applied
under the hood

**Benefits:**

Allows to split tapes into commuting groups for measurements initialized
using both wires and observables.

**Related GitHub Issues:**
PennyLaneAI#4932

---------

Co-authored-by: Mudit Pandey <[email protected]>
Co-authored-by: Romain Moyard <[email protected]>
  • Loading branch information
3 people committed Jan 3, 2024
1 parent 3fd3c0a commit 7242786
Show file tree
Hide file tree
Showing 3 changed files with 540 additions and 100 deletions.
7 changes: 6 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@

* Update `tests/ops/functions/conftest.py` to ensure all operator types are tested for validity.
[(#4978)](https://github.com/PennyLaneAI/pennylane/pull/4978)

<h4>Community contributions 🥳</h4>

* The transform ``split_non_commuting`` now accepts measurements of type `probs`, `sample` and `counts` which accept both wires and observables. [(#4972)](https://github.com/PennyLaneAI/pennylane/pull/4972)

<h3>Breaking changes 💔</h3>

Expand All @@ -21,4 +25,5 @@

This release contains contributions from (in alphabetical order):

Matthew Silverman
Abhishek Abhishek,
Matthew Silverman.
84 changes: 60 additions & 24 deletions pennylane/transforms/split_non_commuting.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from functools import reduce

import pennylane as qml
from pennylane.measurements import ProbabilityMP, SampleMP

from pennylane.transforms import transform

Expand All @@ -34,13 +33,13 @@ def split_non_commuting(tape: qml.tape.QuantumTape) -> (Sequence[qml.tape.Quantu
non-commuting observables to measure.
Returns:
qnode (QNode) or tuple[List[QuantumTape], function]: The transformed circuit as described in :func:`qml.transform <pennylane.transform>`.
qnode (QNode) or tuple[List[QuantumTape], function]: The transformed circuit as described in
:func:`qml.transform <pennylane.transform>`.
**Example**
This transform allows us to transform a QNode that measures
non-commuting observables to *multiple* circuit executions
with qubit-wise commuting groups:
This transform allows us to transform a QNode that measures non-commuting observables to
*multiple* circuit executions with qubit-wise commuting groups:
.. code-block:: python3
Expand All @@ -52,7 +51,8 @@ def circuit(x):
qml.RX(x,wires=0)
return [qml.expval(qml.PauliX(0)), qml.expval(qml.PauliZ(0))]
Instead of decorating the QNode, we can also create a new function that yields the same result in the following way:
Instead of decorating the QNode, we can also create a new function that yields the same result
in the following way:
.. code-block:: python3
Expand All @@ -70,8 +70,10 @@ def circuit(x):
\
0: ──RX(0.50)─┤ <Z>
Note that while internally multiple QNodes are created, the end result has the same ordering as the user provides in the return statement.
Here is a more involved example where we can see the different ordering at the execution level but restoring the original ordering in the output:
Note that while internally multiple QNodes are created, the end result has the same ordering as
the user provides in the return statement.
Here is a more involved example where we can see the different ordering at the execution level
but restoring the original ordering in the output:
.. code-block:: python3
Expand All @@ -95,8 +97,10 @@ def circuit0(x):
0: ──RY(0.79)──RX(0.79)─┤ <Z> ╭<Z@Z>
1: ─────────────────────┤ ╰<Z@Z>
Yet, executing it returns the original ordering of the expectation values. The outputs correspond to
:math:`(\langle \sigma_x^0 \rangle, \langle \sigma_z^0 \rangle, \langle \sigma_y^1 \rangle, \langle \sigma_z^0\sigma_z^1 \rangle)`.
Yet, executing it returns the original ordering of the expectation values. The outputs
correspond to
:math:`(\langle \sigma_x^0 \rangle, \langle \sigma_z^0 \rangle, \langle \sigma_y^1 \rangle,
\langle \sigma_z^0\sigma_z^1 \rangle)`.
>>> circuit0([np.pi/4, np.pi/4])
[0.7071067811865475, 0.49999999999999994, 0.0, 0.49999999999999994]
Expand All @@ -105,7 +109,8 @@ def circuit0(x):
.. details::
:title: Usage Details
Internally, this function works with tapes. We can create a tape with non-commuting observables:
Internally, this function works with tapes. We can create a tape with non-commuting
observables:
.. code-block:: python3
Expand All @@ -119,7 +124,8 @@ def circuit0(x):
>>> [t.observables for t in tapes]
[[expval(PauliZ(wires=[0]))], [expval(PauliY(wires=[0]))]]
The processing function becomes important when creating the commuting groups as the order of the inputs has been modified:
The processing function becomes important when creating the commuting groups as the order
of the inputs has been modified:
.. code-block:: python3
Expand All @@ -133,33 +139,63 @@ def circuit0(x):
tapes, processing_fn = qml.transforms.split_non_commuting(tape)
In this example, the groupings are ``group_coeffs = [[0,2], [1,3]]`` and ``processing_fn`` makes sure that the final output is of the same shape and ordering:
In this example, the groupings are ``group_coeffs = [[0,2], [1,3]]`` and ``processing_fn``
makes sure that the final output is of the same shape and ordering:
>>> processing_fn([t.measurements for t in tapes])
(expval(PauliZ(wires=[0]) @ PauliZ(wires=[1])),
expval(PauliX(wires=[0]) @ PauliX(wires=[1])),
expval(PauliZ(wires=[0])),
expval(PauliX(wires=[0])))
"""
Measurements that accept both observables and ``wires`` so that e.g. ``qml.counts``,
``qml.probs`` and ``qml.sample`` can also be used. When initialized using only ``wires``,
these measurements are interpreted as measuring with respect to the observable
``qml.PauliZ(wires[0])@qml.PauliZ(wires[1])@[email protected](wires[len(wires)-1])``
.. code-block:: python3
# TODO: allow for samples and probs
if any(isinstance(m, (SampleMP, ProbabilityMP)) for m in tape.measurements):
raise NotImplementedError(
"When non-commuting observables are used, only `qml.expval` and `qml.var` are supported."
)
measurements = [
qml.expval(qml.PauliX(0)),
qml.probs(wires=[1]),
qml.probs(wires=[0, 1])
]
tape = qml.tape.QuantumTape(measurements=measurements)
tapes, processing_fn = qml.transforms.split_non_commuting(tape)
This results in two tapes, each with commuting measurements:
>>> [t.measurements for t in tapes]
[[expval(PauliX(wires=[0])), probs(wires=[1])], [probs(wires=[0, 1])]]
"""

obs_list = tape.observables
# Construct a list of observables to group based on the measurements in the tape
obs_list = []
for obs in tape.observables:
# observable provided for a measurement
if isinstance(obs, qml.operation.Observable):
obs_list.append(obs)
# measurements using wires instead of observables
else:
# create the PauliZ tensor product observable when only wires are provided for the
# measurements
# TODO: Revisit when qml.prod is compatible with qml.pauli.group_observables
pauliz_obs = qml.PauliZ(obs.wires[0])
for wire in obs.wires[1:]:
pauliz_obs = pauliz_obs @ qml.PauliZ(wire)

obs_list.append(pauliz_obs)

# If there is more than one group of commuting observables, split tapes
groups, group_coeffs = qml.pauli.group_observables(obs_list, range(len(obs_list)))
if len(groups) > 1:
_, group_coeffs = qml.pauli.group_observables(obs_list, range(len(obs_list)))
if len(group_coeffs) > 1:
# make one tape per commuting group
tapes = []
for group, indices in zip(groups, group_coeffs):
for indices in group_coeffs:
new_tape = tape.__class__(
tape.operations,
(tape.measurements[i].__class__(obs=o) for o, i in zip(group, indices)),
(tape.measurements[i] for i in indices),
)

tapes.append(new_tape)
Expand Down
Loading

0 comments on commit 7242786

Please sign in to comment.