diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 4d252aed799..4d73c06c314 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -8,6 +8,10 @@
* Update `tests/ops/functions/conftest.py` to ensure all operator types are tested for validity.
[(#4978)](https://github.com/PennyLaneAI/pennylane/pull/4978)
+
+
Community contributions 🥳
+
+* The transform ``split_non_commuting`` now accepts measurements of type `probs`, `sample` and `counts` which accept both wires and observables. [(#4972)](https://github.com/PennyLaneAI/pennylane/pull/4972)
Breaking changes 💔
@@ -21,4 +25,5 @@
This release contains contributions from (in alphabetical order):
-Matthew Silverman
+Abhishek Abhishek,
+Matthew Silverman.
\ No newline at end of file
diff --git a/pennylane/transforms/split_non_commuting.py b/pennylane/transforms/split_non_commuting.py
index 06719ff8013..3dec7b19c5a 100644
--- a/pennylane/transforms/split_non_commuting.py
+++ b/pennylane/transforms/split_non_commuting.py
@@ -19,7 +19,6 @@
from functools import reduce
import pennylane as qml
-from pennylane.measurements import ProbabilityMP, SampleMP
from pennylane.transforms import transform
@@ -34,13 +33,13 @@ def split_non_commuting(tape: qml.tape.QuantumTape) -> (Sequence[qml.tape.Quantu
non-commuting observables to measure.
Returns:
- qnode (QNode) or tuple[List[QuantumTape], function]: The transformed circuit as described in :func:`qml.transform `.
+ qnode (QNode) or tuple[List[QuantumTape], function]: The transformed circuit as described in
+ :func:`qml.transform `.
**Example**
- This transform allows us to transform a QNode that measures
- non-commuting observables to *multiple* circuit executions
- with qubit-wise commuting groups:
+ This transform allows us to transform a QNode that measures non-commuting observables to
+ *multiple* circuit executions with qubit-wise commuting groups:
.. code-block:: python3
@@ -52,7 +51,8 @@ def circuit(x):
qml.RX(x,wires=0)
return [qml.expval(qml.PauliX(0)), qml.expval(qml.PauliZ(0))]
- Instead of decorating the QNode, we can also create a new function that yields the same result in the following way:
+ Instead of decorating the QNode, we can also create a new function that yields the same result
+ in the following way:
.. code-block:: python3
@@ -70,8 +70,10 @@ def circuit(x):
\
0: ──RX(0.50)─┤
- Note that while internally multiple QNodes are created, the end result has the same ordering as the user provides in the return statement.
- Here is a more involved example where we can see the different ordering at the execution level but restoring the original ordering in the output:
+ Note that while internally multiple QNodes are created, the end result has the same ordering as
+ the user provides in the return statement.
+ Here is a more involved example where we can see the different ordering at the execution level
+ but restoring the original ordering in the output:
.. code-block:: python3
@@ -95,8 +97,10 @@ def circuit0(x):
0: ──RY(0.79)──RX(0.79)─┤ â•
1: ─────────────────────┤ ╰
- Yet, executing it returns the original ordering of the expectation values. The outputs correspond to
- :math:`(\langle \sigma_x^0 \rangle, \langle \sigma_z^0 \rangle, \langle \sigma_y^1 \rangle, \langle \sigma_z^0\sigma_z^1 \rangle)`.
+ Yet, executing it returns the original ordering of the expectation values. The outputs
+ correspond to
+ :math:`(\langle \sigma_x^0 \rangle, \langle \sigma_z^0 \rangle, \langle \sigma_y^1 \rangle,
+ \langle \sigma_z^0\sigma_z^1 \rangle)`.
>>> circuit0([np.pi/4, np.pi/4])
[0.7071067811865475, 0.49999999999999994, 0.0, 0.49999999999999994]
@@ -105,7 +109,8 @@ def circuit0(x):
.. details::
:title: Usage Details
- Internally, this function works with tapes. We can create a tape with non-commuting observables:
+ Internally, this function works with tapes. We can create a tape with non-commuting
+ observables:
.. code-block:: python3
@@ -119,7 +124,8 @@ def circuit0(x):
>>> [t.observables for t in tapes]
[[expval(PauliZ(wires=[0]))], [expval(PauliY(wires=[0]))]]
- The processing function becomes important when creating the commuting groups as the order of the inputs has been modified:
+ The processing function becomes important when creating the commuting groups as the order
+ of the inputs has been modified:
.. code-block:: python3
@@ -133,7 +139,8 @@ def circuit0(x):
tapes, processing_fn = qml.transforms.split_non_commuting(tape)
- In this example, the groupings are ``group_coeffs = [[0,2], [1,3]]`` and ``processing_fn`` makes sure that the final output is of the same shape and ordering:
+ In this example, the groupings are ``group_coeffs = [[0,2], [1,3]]`` and ``processing_fn``
+ makes sure that the final output is of the same shape and ordering:
>>> processing_fn([t.measurements for t in tapes])
(expval(PauliZ(wires=[0]) @ PauliZ(wires=[1])),
@@ -141,25 +148,54 @@ def circuit0(x):
expval(PauliZ(wires=[0])),
expval(PauliX(wires=[0])))
- """
+ Measurements that accept both observables and ``wires`` so that e.g. ``qml.counts``,
+ ``qml.probs`` and ``qml.sample`` can also be used. When initialized using only ``wires``,
+ these measurements are interpreted as measuring with respect to the observable
+ ``qml.PauliZ(wires[0])@qml.PauliZ(wires[1])@...@qml.PauliZ(wires[len(wires)-1])``
+
+ .. code-block:: python3
- # TODO: allow for samples and probs
- if any(isinstance(m, (SampleMP, ProbabilityMP)) for m in tape.measurements):
- raise NotImplementedError(
- "When non-commuting observables are used, only `qml.expval` and `qml.var` are supported."
- )
+ measurements = [
+ qml.expval(qml.PauliX(0)),
+ qml.probs(wires=[1]),
+ qml.probs(wires=[0, 1])
+ ]
+ tape = qml.tape.QuantumTape(measurements=measurements)
+
+ tapes, processing_fn = qml.transforms.split_non_commuting(tape)
+
+ This results in two tapes, each with commuting measurements:
+
+ >>> [t.measurements for t in tapes]
+ [[expval(PauliX(wires=[0])), probs(wires=[1])], [probs(wires=[0, 1])]]
+ """
- obs_list = tape.observables
+ # Construct a list of observables to group based on the measurements in the tape
+ obs_list = []
+ for obs in tape.observables:
+ # observable provided for a measurement
+ if isinstance(obs, qml.operation.Observable):
+ obs_list.append(obs)
+ # measurements using wires instead of observables
+ else:
+ # create the PauliZ tensor product observable when only wires are provided for the
+ # measurements
+ # TODO: Revisit when qml.prod is compatible with qml.pauli.group_observables
+ pauliz_obs = qml.PauliZ(obs.wires[0])
+ for wire in obs.wires[1:]:
+ pauliz_obs = pauliz_obs @ qml.PauliZ(wire)
+
+ obs_list.append(pauliz_obs)
# If there is more than one group of commuting observables, split tapes
- groups, group_coeffs = qml.pauli.group_observables(obs_list, range(len(obs_list)))
- if len(groups) > 1:
+ _, group_coeffs = qml.pauli.group_observables(obs_list, range(len(obs_list)))
+ if len(group_coeffs) > 1:
# make one tape per commuting group
tapes = []
- for group, indices in zip(groups, group_coeffs):
+ for indices in group_coeffs:
new_tape = tape.__class__(
tape.operations,
- (tape.measurements[i].__class__(obs=o) for o, i in zip(group, indices)),
+ (tape.measurements[i] for i in indices),
)
tapes.append(new_tape)
diff --git a/tests/transforms/test_split_non_commuting.py b/tests/transforms/test_split_non_commuting.py
index fb0399e48ba..4b2ae1d7368 100644
--- a/tests/transforms/test_split_non_commuting.py
+++ b/tests/transforms/test_split_non_commuting.py
@@ -13,6 +13,7 @@
# limitations under the License.
""" Tests for the transform ``qml.transform.split_non_commuting()`` """
# pylint: disable=no-self-use, import-outside-toplevel, no-member, import-error
+import itertools
import pytest
import numpy as np
import pennylane as qml
@@ -20,67 +21,122 @@
from pennylane.transforms import split_non_commuting
-### example tape with 3 commuting groups [[0,3],[1,4],[2,5]]
-with qml.queuing.AnnotatedQueue() as q3:
- qml.PauliZ(0)
- qml.Hadamard(0)
- qml.CNOT((0, 1))
- qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))
- qml.expval(qml.PauliX(0) @ qml.PauliX(1))
- qml.expval(qml.PauliY(0) @ qml.PauliY(1))
- qml.expval(qml.PauliZ(0))
- qml.expval(qml.PauliX(0))
- qml.expval(qml.PauliY(0))
-
-non_commuting_tape3 = qml.tape.QuantumScript.from_queue(q3)
-### example tape with 2 -commuting groups [[0,2],[1,3]]
-with qml.queuing.AnnotatedQueue() as q2:
- qml.PauliZ(0)
- qml.Hadamard(0)
- qml.CNOT((0, 1))
- qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))
- qml.expval(qml.PauliX(0) @ qml.PauliX(1))
- qml.expval(qml.PauliZ(0))
- qml.expval(qml.PauliX(0))
-
-non_commuting_tape2 = qml.tape.QuantumScript.from_queue(q2)
-# For testing different observable types
-obs_fn = [qml.expval, qml.var]
+# list of observables with 2 commuting groups [[1, 3], [0, 2, 4]]
+obs_list_2 = [
+ qml.PauliZ(0) @ qml.PauliZ(1),
+ qml.PauliX(0) @ qml.PauliX(1),
+ qml.PauliZ(0),
+ qml.PauliX(0),
+ qml.PauliZ(1),
+]
+
+# list of observables with 3 commuting groups [[0,3], [1,4], [2,5]]
+obs_list_3 = [
+ qml.PauliZ(0) @ qml.PauliZ(1),
+ qml.PauliX(0) @ qml.PauliX(1),
+ qml.PauliY(0) @ qml.PauliY(1),
+ qml.PauliZ(0),
+ qml.PauliX(0),
+ qml.PauliY(0),
+]
+
+# measurements that accept observables as arguments
+obs_meas_fn = [qml.expval, qml.var, qml.probs, qml.counts, qml.sample]
+
+# measurements that accept wires as arguments
+wire_meas_fn = [qml.probs, qml.counts, qml.sample]
class TestUnittestSplitNonCommuting:
"""Unit tests on ``qml.transforms.split_non_commuting()``"""
- def test_commuting_group_no_split(self, mocker):
- """Testing that commuting groups are not split"""
+ @pytest.mark.parametrize("meas_type", obs_meas_fn)
+ def test_commuting_group_no_split(self, mocker, meas_type):
+ """Testing that commuting groups are not split for all supported measurement types"""
with qml.queuing.AnnotatedQueue() as q:
qml.PauliZ(0)
qml.Hadamard(0)
qml.CNOT((0, 1))
- qml.expval(qml.PauliZ(0))
- qml.expval(qml.PauliZ(0))
- qml.expval(qml.PauliX(1))
- qml.expval(qml.PauliZ(2))
- qml.expval(qml.PauliZ(0) @ qml.PauliZ(3))
+ meas_type(op=qml.PauliZ(0))
+ meas_type(op=qml.PauliZ(0))
+ meas_type(op=qml.PauliX(1))
+ meas_type(op=qml.PauliZ(2))
+ meas_type(op=qml.PauliZ(0) @ qml.PauliZ(3))
+ # test transform on tape
tape = qml.tape.QuantumScript.from_queue(q)
split, fn = split_non_commuting(tape)
spy = mocker.spy(qml.math, "concatenate")
+ assert len(split) == 1
assert all(isinstance(t, qml.tape.QuantumScript) for t in split)
assert fn([0.5]) == 0.5
+ # test transform on qscript
qs = qml.tape.QuantumScript(tape.operations, tape.measurements)
split, fn = split_non_commuting(qs)
+
+ assert len(split) == 1
assert all(isinstance(i_qs, qml.tape.QuantumScript) for i_qs in split)
assert fn([0.5]) == 0.5
spy.assert_not_called()
- @pytest.mark.parametrize("tape,expected", [(non_commuting_tape2, 2), (non_commuting_tape3, 3)])
- def test_non_commuting_group_right_number(self, tape, expected):
- """Test that the output is of the correct size"""
+ @pytest.mark.parametrize("meas_type", wire_meas_fn)
+ def test_wire_commuting_group_no_split(self, mocker, meas_type):
+ """Testing that commuting MPs initialized using wires or observables are not split"""
+ with qml.queuing.AnnotatedQueue() as q:
+ qml.PauliZ(0)
+ qml.Hadamard(0)
+ qml.CNOT((0, 1))
+ meas_type(wires=[0])
+ meas_type(wires=[1])
+ meas_type(wires=[0, 1])
+ meas_type(op=qml.PauliZ(0))
+ meas_type(op=qml.PauliZ(0) @ qml.PauliZ(2))
+
+ # test transform on tape
+ tape = qml.tape.QuantumScript.from_queue(q)
+ split, fn = split_non_commuting(tape)
+
+ spy = mocker.spy(qml.math, "concatenate")
+
+ assert len(split) == 1
+ assert all(isinstance(t, qml.tape.QuantumScript) for t in split)
+ assert fn([0.5]) == 0.5
+
+ # test transform on qscript
+ qs = qml.tape.QuantumScript(tape.operations, tape.measurements)
+ split, fn = split_non_commuting(qs)
+
+ assert len(split) == 1
+ assert all(isinstance(i_qs, qml.tape.QuantumScript) for i_qs in split)
+ assert fn([0.5]) == 0.5
+
+ spy.assert_not_called()
+
+ @pytest.mark.parametrize("meas_type", obs_meas_fn)
+ @pytest.mark.parametrize("obs_list, expected", [(obs_list_2, 2), (obs_list_3, 3)])
+ def test_non_commuting_group_right_number(self, meas_type, obs_list, expected):
+ """Test that the no. of tapes after splitting into commuting groups is of the right size"""
+
+ # create a queue with several measurements of same type but with differnent non-commuting
+ # observables
+ with qml.queuing.AnnotatedQueue() as q:
+ qml.PauliZ(0)
+ qml.Hadamard(0)
+ qml.CNOT((0, 1))
+ for obs in obs_list:
+ meas_type(op=obs)
+
+ # if MP type can accept wires, then add two extra measurements using wires and test no.
+ # of tapes after splitting commuting groups
+ if meas_type in wire_meas_fn:
+ meas_type(wires=[0])
+ meas_type(wires=[0, 1])
+
+ tape = qml.tape.QuantumScript.from_queue(q)
split, _ = split_non_commuting(tape)
assert len(split) == expected
@@ -88,30 +144,75 @@ def test_non_commuting_group_right_number(self, tape, expected):
split, _ = split_non_commuting(qs)
assert len(split) == expected
+ @pytest.mark.parametrize("meas_type", obs_meas_fn)
@pytest.mark.parametrize(
- "tape,group_coeffs",
- [(non_commuting_tape2, [[0, 2], [1, 3]]), (non_commuting_tape3, [[0, 3], [1, 4], [2, 5]])],
+ "obs_list, group_coeffs",
+ [(obs_list_2, [[1, 3], [0, 2, 4]]), (obs_list_3, [[0, 3], [1, 4], [2, 5]])],
)
- def test_non_commuting_group_right_reorder(self, tape, group_coeffs):
+ def test_non_commuting_group_right_reorder(self, meas_type, obs_list, group_coeffs):
"""Test that the output is of the correct order"""
- split, fn = split_non_commuting(tape)
- assert all(np.array(fn(group_coeffs)) == np.arange(len(split) * 2))
+ # create a queue with several measurements of same type but with differnent non-commuting
+ # observables
+ with qml.queuing.AnnotatedQueue() as q:
+ qml.PauliZ(0)
+ qml.Hadamard(0)
+ qml.CNOT((0, 1))
+ for obs in obs_list:
+ meas_type(op=obs)
+
+ tape = qml.tape.QuantumScript.from_queue(q)
+ _, fn = split_non_commuting(tape)
+ assert all(np.array(fn(group_coeffs)) == np.arange(len(tape.measurements)))
qs = qml.tape.QuantumScript(tape.operations, tape.measurements)
- split, fn = split_non_commuting(qs)
- assert all(np.array(fn(group_coeffs)) == np.arange(len(split) * 2))
+ _, fn = split_non_commuting(qs)
+ assert all(np.array(fn(group_coeffs)) == np.arange(len(tape.measurements)))
- @pytest.mark.parametrize("meas_type", obs_fn)
+ @pytest.mark.parametrize("meas_type", wire_meas_fn)
+ @pytest.mark.parametrize(
+ "obs_list, group_coeffs",
+ [(obs_list_2, [[1, 3], [0, 2, 4, 5]]), (obs_list_3, [[1, 4], [2, 5], [0, 3, 6]])],
+ )
+ def test_wire_non_commuting_group_right_reorder(self, meas_type, obs_list, group_coeffs):
+ """Test that the output is of the correct order with wire MPs initialized using a
+ combination of wires and observables"""
+ # create a queue with several measurements of same type but with differnent non-commuting
+ # observables
+ with qml.queuing.AnnotatedQueue() as q:
+ qml.PauliZ(0)
+ qml.Hadamard(0)
+ qml.CNOT((0, 1))
+ for obs in obs_list:
+ meas_type(op=obs)
+
+ # initialize measurements using wires
+ meas_type(wires=[0])
+
+ tape = qml.tape.QuantumScript.from_queue(q)
+ _, fn = split_non_commuting(tape)
+ assert all(np.array(fn(group_coeffs)) == np.arange(len(tape.measurements)))
+
+ qs = qml.tape.QuantumScript(tape.operations, tape.measurements)
+ _, fn = split_non_commuting(qs)
+ assert all(np.array(fn(group_coeffs)) == np.arange(len(tape.measurements)))
+
+ @pytest.mark.parametrize("meas_type", obs_meas_fn)
def test_different_measurement_types(self, meas_type):
- """Test that expval, var and sample are correctly reproduced"""
+ """Test that the measurements types of the split tapes are correct"""
with qml.queuing.AnnotatedQueue() as q:
qml.PauliZ(0)
qml.Hadamard(0)
qml.CNOT((0, 1))
- meas_type(qml.PauliZ(0) @ qml.PauliZ(1))
- meas_type(qml.PauliX(0) @ qml.PauliX(1))
- meas_type(qml.PauliZ(0))
- meas_type(qml.PauliX(0))
+ meas_type(op=qml.PauliZ(0) @ qml.PauliZ(1))
+ meas_type(op=qml.PauliX(0) @ qml.PauliX(1))
+ meas_type(op=qml.PauliZ(0))
+ meas_type(op=qml.PauliX(0))
+
+ # if the MP can also accept wires as arguments, add extra measurements to test
+ if meas_type in wire_meas_fn:
+ meas_type(wires=[0])
+ meas_type(wires=[0, 1])
+
tape = qml.tape.QuantumScript.from_queue(q)
the_return_type = tape.measurements[0].return_type
split, _ = split_non_commuting(tape)
@@ -125,45 +226,72 @@ def test_different_measurement_types(self, meas_type):
for meas in new_tape.measurements:
assert meas.return_type == the_return_type
- def test_mixed_measurement_types(self):
- """Test that mixing expval and var works correctly."""
+ @pytest.mark.parametrize("meas_type_1, meas_type_2", itertools.combinations(obs_meas_fn, 2))
+ def test_mixed_measurement_types(self, meas_type_1, meas_type_2):
+ """Test that mixing different combintations of MPs initialized using obs works correctly."""
with qml.queuing.AnnotatedQueue() as q:
qml.Hadamard(0)
qml.Hadamard(1)
- qml.expval(qml.PauliX(0))
- qml.expval(qml.PauliZ(1))
- qml.var(qml.PauliZ(0))
+ meas_type_1(op=qml.PauliX(0))
+ meas_type_1(op=qml.PauliZ(1))
+ meas_type_2(op=qml.PauliZ(0))
tape = qml.tape.QuantumScript.from_queue(q)
split, _ = split_non_commuting(tape)
assert len(split) == 2
+ assert qml.equal(split[0].measurements[0], meas_type_1(op=qml.PauliX(0)))
+ assert qml.equal(split[0].measurements[1], meas_type_1(op=qml.PauliZ(1)))
+ assert qml.equal(split[1].measurements[0], meas_type_2(op=qml.PauliZ(0)))
+
+ @pytest.mark.parametrize("meas_type_1, meas_type_2", itertools.combinations(wire_meas_fn, 2))
+ def test_mixed_wires_measurement_types(self, meas_type_1, meas_type_2):
+ """Test that mixing different combinations of MPs initialized using wires works correctly"""
with qml.queuing.AnnotatedQueue() as q:
qml.Hadamard(0)
qml.Hadamard(1)
- qml.expval(qml.PauliX(0))
- qml.var(qml.PauliZ(0))
- qml.expval(qml.PauliZ(1))
+ meas_type_1(op=qml.PauliX(0))
+ meas_type_1(wires=[1])
+ meas_type_2(wires=[0])
tape = qml.tape.QuantumScript.from_queue(q)
split, _ = split_non_commuting(tape)
assert len(split) == 2
- assert qml.equal(split[0].measurements[0], qml.expval(qml.PauliX(0)))
- assert qml.equal(split[0].measurements[1], qml.expval(qml.PauliZ(1)))
- assert qml.equal(split[1].measurements[0], qml.var(qml.PauliZ(0)))
+ assert qml.equal(split[0].measurements[0], meas_type_1(op=qml.PauliX(0)))
+ assert qml.equal(split[0].measurements[1], meas_type_1(wires=[1]))
+ assert qml.equal(split[1].measurements[0], meas_type_2(wires=[0]))
+
+ @pytest.mark.parametrize(
+ "meas_type_1, meas_type_2", itertools.product(obs_meas_fn, wire_meas_fn)
+ )
+ def test_mixed_wires_obs_measurement_types(self, meas_type_1, meas_type_2):
+ """Test that mixing different combinations of MPs initialized using wires and obs works
+ correctly"""
- def test_raise_not_supported(self):
- """Test that NotImplementedError is raised when probabilities or samples are called"""
with qml.queuing.AnnotatedQueue() as q:
- qml.expval(qml.PauliZ(0))
- qml.probs(wires=0)
+ qml.Hadamard(0)
+ qml.Hadamard(1)
+ meas_type_1(op=qml.PauliX(0))
+ meas_type_2(wires=[1])
+ meas_type_2(wires=[0, 1])
tape = qml.tape.QuantumScript.from_queue(q)
- with pytest.raises(NotImplementedError, match="non-commuting observables are used"):
- split_non_commuting(tape)
+ split, _ = split_non_commuting(tape)
+
+ assert len(split) == 2
+ assert qml.equal(split[0].measurements[0], meas_type_1(op=qml.PauliX(0)))
+ assert qml.equal(split[0].measurements[1], meas_type_2(wires=[1]))
+ assert qml.equal(split[1].measurements[0], meas_type_2(wires=[0, 1]))
+
+
+# measurements that require shots=True
+required_shot_meas_fn = [qml.sample, qml.counts]
+
+# measurements that can optionally have shots=True
+optional_shot_meas_fn = [qml.probs, qml.expval, qml.var]
class TestIntegration:
@@ -201,7 +329,7 @@ def circuit():
assert all(np.isclose(res, np.array([0.0, -1.0, 0.0, 0.0, 1.0, 1 / np.sqrt(2)])))
def test_expval_non_commuting_observables_qnode(self):
- """Test expval with multiple non-commuting operators as a tranform program on the qnode."""
+ """Test expval with multiple non-commuting operators as a transform program on the qnode."""
dev = qml.device("default.qubit", wires=6)
@qml.qnode(dev)
@@ -232,6 +360,41 @@ def circuit():
assert all(np.isclose(res, np.array([0.0, -1.0, 0.0, 0.0, 1.0, 1 / np.sqrt(2)])))
+ def test_expval_probs_non_commuting_observables_qnode(self):
+ """Test expval with multiple non-commuting operators and probs with non-commuting wires as a
+ transform program on the qnode."""
+ dev = qml.device("default.qubit", wires=6)
+
+ @qml.qnode(dev)
+ def circuit():
+ qml.Hadamard(1)
+ qml.Hadamard(0)
+ qml.PauliZ(0)
+ qml.Hadamard(3)
+ qml.Hadamard(5)
+ qml.T(5)
+ return (
+ qml.probs(wires=[0, 1]),
+ qml.probs(wires=[1]),
+ qml.expval(qml.PauliZ(0)),
+ qml.expval(qml.PauliX(1) @ qml.PauliX(4)),
+ qml.expval(qml.PauliX(3)),
+ qml.expval(qml.PauliY(5)),
+ )
+
+ res = split_non_commuting(circuit)()
+
+ assert isinstance(res, tuple)
+ assert len(res) == 6
+ assert all(isinstance(r, np.ndarray) for r in res)
+
+ res_probs = qml.math.concatenate(res[:2], axis=0)
+ res_expval = qml.math.stack(res[2:])
+
+ assert all(np.isclose(res_probs, np.array([0.25, 0.25, 0.25, 0.25, 0.5, 0.5])))
+
+ assert all(np.isclose(res_expval, np.array([0.0, 0.0, 1.0, 1 / np.sqrt(2)])))
+
def test_shot_vector_support(self):
"""Test output is correct when using shot vectors"""
@@ -274,6 +437,66 @@ def circuit():
res, np.array([0.0, -1.0, 0.0, 0.0, 0.0, 1.0, 1 / np.sqrt(2)]), atol=0.05
)
+ def test_shot_vector_support_sample(self):
+ """Test output is correct when using shots and sample and expval measurements"""
+
+ dev = qml.device("default.qubit", wires=2, shots=(10, 20))
+
+ @qml.qnode(dev)
+ def circuit():
+ qml.PauliZ(0)
+ return (qml.sample(wires=[0, 1]), qml.expval(qml.PauliZ(0)))
+
+ res = split_non_commuting(circuit)()
+ assert isinstance(res, tuple)
+ assert len(res) == 2
+ assert all(isinstance(shot_res, tuple) for shot_res in res)
+ assert all(len(shot_res) == 2 for shot_res in res)
+ # pylint:disable=not-an-iterable
+ assert all(all(list(isinstance(r, np.ndarray) for r in shot_res)) for shot_res in res)
+
+ assert all(
+ shot_res[0].shape in [(10, 2), (20, 2)] and shot_res[1].shape == () for shot_res in res
+ )
+
+ # check all the wire samples are as expected
+ sample_res = qml.math.concatenate(
+ [qml.math.concatenate(shot_res[0], axis=0) for shot_res in res], axis=0
+ )
+ assert np.allclose(sample_res, 0.0, atol=0.05)
+
+ expval_res = qml.math.stack([shot_res[1] for shot_res in res])
+ assert np.allclose(expval_res, np.array([1.0, 1.0]), atol=0.05)
+
+ def test_shot_vector_support_counts(self):
+ """Test output is correct when using shots, counts and expval measurements"""
+
+ dev = qml.device("default.qubit", wires=2, shots=(10, 20))
+
+ @qml.qnode(dev)
+ def circuit():
+ qml.PauliZ(0)
+ return (qml.counts(wires=[0, 1]), qml.expval(qml.PauliZ(0)))
+
+ res = split_non_commuting(circuit)()
+ assert isinstance(res, tuple)
+ assert len(res) == 2
+ assert all(isinstance(shot_res, tuple) for shot_res in res)
+ assert all(len(shot_res) == 2 for shot_res in res)
+ # pylint:disable=not-an-iterable
+ assert all(
+ isinstance(shot_res[0], dict) and isinstance(shot_res[1], np.ndarray)
+ for shot_res in res
+ )
+
+ assert all(shot_res[1].shape == () for shot_res in res)
+
+ # check all the wire counts are as expected
+ assert all(shot_res[0]["00"] in [10, 20] for shot_res in res)
+
+ expval_res = qml.math.stack([shot_res[1] for shot_res in res])
+ assert np.allclose(expval_res, np.array([1.0, 1.0]), atol=0.05)
+
# Autodiff tests
exp_res = np.array([0.77015115, -0.47942554, 0.87758256])
@@ -281,6 +504,17 @@ def circuit():
[[-4.20735492e-01, -4.20735492e-01], [-8.77582562e-01, 0.0], [-4.79425539e-01, 0.0]]
)
+exp_res_probs = np.array([0.88132907, 0.05746221, 0.05746221, 0.00374651, 0.0])
+exp_grad_probs = np.array(
+ [
+ [-0.22504026, -0.22504026],
+ [-0.01467251, 0.22504026],
+ [0.22504026, -0.01467251],
+ [0.01467251, 0.01467251],
+ [0.0, 0.0],
+ ]
+)
+
class TestAutodiffSplitNonCommuting:
"""Autodiff tests for all frameworks"""
@@ -310,6 +544,28 @@ def cost(params):
assert all(np.isclose(res, exp_res))
assert all(np.isclose(grad, exp_grad).flatten())
+ @pytest.mark.autograd
+ def test_split_with_autograd_probs(self):
+ """Test resulting after splitting non-commuting tapes with expval and probs measurements
+ are still differentiable with autograd"""
+ dev = qml.device("default.qubit", wires=2)
+
+ @qml.qnode(dev)
+ def circuit(params):
+ qml.RX(params[0], wires=0)
+ qml.RY(params[1], wires=1)
+ return (qml.probs(wires=[0, 1]), qml.expval(qml.PauliX(0) @ qml.PauliX(1)))
+
+ def cost(params):
+ res = split_non_commuting(circuit)(params)
+ return qml.math.concatenate([res[0]] + [qml.math.stack(res[1:])], axis=0)
+
+ params = pnp.array([0.5, 0.5])
+ res = cost(params)
+ grad = qml.jacobian(cost)(params)
+ assert all(np.isclose(res, exp_res_probs))
+ assert all(np.isclose(grad, exp_grad_probs).flatten())
+
@pytest.mark.jax
def test_split_with_jax(self):
"""Test that results after splitting are still differentiable with jax"""
@@ -330,10 +586,35 @@ def circuit(params):
)
params = jnp.array([0.5, 0.5])
- res = circuit(params)
- grad = jax.jacobian(circuit)(params)
- assert all(np.isclose(res, exp_res))
- assert all(np.isclose(grad, exp_grad, atol=1e-5).flatten())
+ res = split_non_commuting(circuit)(params)
+ grad = jax.jacobian(split_non_commuting(circuit))(params)
+ assert all(np.isclose(res, exp_res, atol=0.05))
+ assert all(np.isclose(grad, exp_grad, atol=0.05).flatten())
+
+ @pytest.mark.jax
+ def test_split_with_jax_probs(self):
+ """Test resulting after splitting non-commuting tapes with expval and probs measurements
+ are still differentiable with jax"""
+ import jax
+ import jax.numpy as jnp
+
+ dev = qml.device("default.qubit.jax", wires=2)
+
+ @qml.qnode(dev)
+ def circuit(params):
+ qml.RX(params[0], wires=0)
+ qml.RY(params[1], wires=1)
+ return (qml.probs(wires=[0, 1]), qml.expval(qml.PauliX(0) @ qml.PauliX(1)))
+
+ params = jnp.array([0.5, 0.5])
+ res = split_non_commuting(circuit)(params)
+ res = jnp.concatenate([res[0]] + [jnp.stack(res[1:])], axis=0)
+
+ grad = jax.jacobian(split_non_commuting(circuit))(params)
+ grad = jnp.concatenate([grad[0]] + [jnp.stack(grad[1:])], axis=0)
+
+ assert all(np.isclose(res, exp_res_probs, atol=0.05))
+ assert all(np.isclose(grad, exp_grad_probs, atol=0.05).flatten())
@pytest.mark.jax
def test_split_with_jax_multi_params(self):
@@ -358,8 +639,8 @@ def circuit(x, y):
x = jnp.array(0.5)
y = jnp.array(0.5)
- res = circuit(x, y)
- grad = jax.jacobian(circuit, argnums=[0, 1])(x, y)
+ res = split_non_commuting(circuit)(x, y)
+ grad = jax.jacobian(split_non_commuting(circuit), argnums=[0, 1])(x, y)
assert all(np.isclose(res, exp_res))
@@ -373,6 +654,49 @@ def circuit(x, y):
assert np.allclose(meas_grad, exp_grad[i], atol=1e-5)
+ @pytest.mark.jax
+ def test_split_with_jax_multi_params_probs(self):
+ """Test that results after splitting are still differentiable with jax
+ with multiple parameters"""
+
+ import jax
+ import jax.numpy as jnp
+
+ dev = qml.device("default.qubit.jax", wires=2)
+
+ @qml.qnode(dev)
+ def circuit(x, y):
+ qml.RX(x, wires=0)
+ qml.RY(y, wires=1)
+ return (qml.probs(wires=[0, 1]), qml.expval(qml.PauliX(0) @ qml.PauliX(1)))
+
+ x = jnp.array(0.5)
+ y = jnp.array(0.5)
+
+ res = split_non_commuting(circuit)(x, y)
+ res = jnp.concatenate([res[0]] + [jnp.stack(res[1:])], axis=0)
+ assert all(np.isclose(res, exp_res_probs))
+
+ grad = jax.jacobian(split_non_commuting(circuit), argnums=[0, 1])(x, y)
+
+ assert isinstance(grad, tuple)
+ assert len(grad) == 2
+
+ for meas_grad in grad:
+ assert isinstance(meas_grad, tuple)
+ assert len(meas_grad) == 2
+ assert all(isinstance(g, jnp.ndarray) for g in meas_grad)
+
+ # reshape the returned gradient to the right shape
+ grad = jnp.concatenate(
+ [
+ jnp.concatenate([grad[0][0].reshape(-1, 1), grad[0][1].reshape(-1, 1)], axis=1),
+ jnp.concatenate([grad[1][0].reshape(-1, 1), grad[1][1].reshape(-1, 1)], axis=1),
+ ],
+ axis=0,
+ )
+ assert all(np.isclose(grad, exp_grad_probs, atol=0.05).flatten())
+
@pytest.mark.jax
def test_split_with_jax_jit(self):
"""Test that results after splitting are still differentiable with jax-jit"""
@@ -399,6 +723,32 @@ def circuit(params):
assert all(np.isclose(res, exp_res))
assert all(np.isclose(grad, exp_grad, atol=1e-5).flatten())
+ @pytest.mark.jax
+ def test_split_with_jax_jit_probs(self):
+ """Test resulting after splitting non-commuting tapes with expval and probs measurements
+ are still differentiable with jax"""
+
+ import jax
+ import jax.numpy as jnp
+
+ dev = qml.device("default.qubit", wires=2)
+
+ @qml.qnode(dev)
+ def circuit(params):
+ qml.RX(params[0], wires=0)
+ qml.RY(params[1], wires=1)
+ return (qml.probs(wires=[0, 1]), qml.expval(qml.PauliX(0) @ qml.PauliX(1)))
+
+ params = jnp.array([0.5, 0.5])
+ res = split_non_commuting(circuit)(params)
+ res = jnp.concatenate([res[0]] + [jnp.stack(res[1:])], axis=0)
+
+ grad = jax.jacobian(split_non_commuting(circuit))(params)
+ grad = jnp.concatenate([grad[0]] + [jnp.stack(grad[1:])], axis=0)
+
+ assert all(np.isclose(res, exp_res_probs, atol=0.05))
+ assert all(np.isclose(grad, exp_grad_probs, atol=0.05).flatten())
+
@pytest.mark.torch
def test_split_with_torch(self):
"""Test that results after splitting are still differentiable with torch"""
@@ -419,7 +769,7 @@ def circuit(params):
)
def cost(params):
- res = circuit(params)
+ res = split_non_commuting(circuit)(params)
return qml.math.stack(res)
params = torch.tensor([0.5, 0.5], requires_grad=True)
@@ -428,6 +778,32 @@ def cost(params):
assert all(np.isclose(res.detach().numpy(), exp_res))
assert all(np.isclose(grad.detach().numpy(), exp_grad, atol=1e-5).flatten())
+ @pytest.mark.torch
+ def test_split_with_torch_probs(self):
+ """Test resulting after splitting non-commuting tapes with expval and probs measurements
+ are still differentiable with torch"""
+
+ import torch
+ from torch.autograd.functional import jacobian
+
+ dev = qml.device("default.qubit", wires=2)
+
+ @qml.qnode(dev)
+ def circuit(params):
+ qml.RX(params[0], wires=0)
+ qml.RY(params[1], wires=1)
+ return (qml.probs(wires=[0, 1]), qml.expval(qml.PauliX(0) @ qml.PauliX(1)))
+
+ def cost(params):
+ res = split_non_commuting(circuit)(params)
+ return qml.math.concatenate([res[0]] + [qml.math.stack(res[1:])], axis=0)
+
+ params = torch.tensor([0.5, 0.5], requires_grad=True)
+ res = cost(params)
+ grad = jacobian(cost, (params))
+ assert all(np.isclose(res.detach().numpy(), exp_res_probs))
+ assert all(np.isclose(grad.detach().numpy(), exp_grad_probs, atol=1e-5).flatten())
+
@pytest.mark.tf
def test_split_with_tf(self):
"""Test that results after splitting are still differentiable with tf"""
@@ -449,9 +825,32 @@ def circuit(params):
params = tf.Variable([0.5, 0.5])
res = circuit(params)
with tf.GradientTape() as tape:
- loss = circuit(params)
+ loss = split_non_commuting(circuit)(params)
loss = tf.stack(loss)
grad = tape.jacobian(loss, params)
assert all(np.isclose(res, exp_res))
assert all(np.isclose(grad, exp_grad, atol=1e-5).flatten())
+
+ @pytest.mark.tf
+ def test_split_with_tf_probs(self):
+ """Test that results after splitting are still differentiable with tf"""
+
+ import tensorflow as tf
+
+ dev = qml.device("default.qubit.tf", wires=2)
+
+ @qml.qnode(dev)
+ def circuit(params):
+ qml.RX(params[0], wires=0)
+ qml.RY(params[1], wires=1)
+ return (qml.probs(wires=[0, 1]), qml.expval(qml.PauliX(0) @ qml.PauliX(1)))
+
+ params = tf.Variable([0.5, 0.5])
+ with tf.GradientTape() as tape:
+ res = split_non_commuting(circuit)(params)
+ res = tf.concat([res[0]] + [tf.stack(res[1:])], axis=0)
+
+ grad = tape.jacobian(res, params)
+ assert all(np.isclose(res, exp_res_probs))
+ assert all(np.isclose(grad, exp_grad_probs, atol=1e-5).flatten())