diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 4d252aed799..4d73c06c314 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -8,6 +8,10 @@ * Update `tests/ops/functions/conftest.py` to ensure all operator types are tested for validity. [(#4978)](https://github.com/PennyLaneAI/pennylane/pull/4978) + +

Community contributions 🥳

+ +* The transform ``split_non_commuting`` now accepts measurements of type `probs`, `sample` and `counts` which accept both wires and observables. [(#4972)](https://github.com/PennyLaneAI/pennylane/pull/4972)

Breaking changes 💔

@@ -21,4 +25,5 @@ This release contains contributions from (in alphabetical order): -Matthew Silverman +Abhishek Abhishek, +Matthew Silverman. \ No newline at end of file diff --git a/pennylane/transforms/split_non_commuting.py b/pennylane/transforms/split_non_commuting.py index 06719ff8013..3dec7b19c5a 100644 --- a/pennylane/transforms/split_non_commuting.py +++ b/pennylane/transforms/split_non_commuting.py @@ -19,7 +19,6 @@ from functools import reduce import pennylane as qml -from pennylane.measurements import ProbabilityMP, SampleMP from pennylane.transforms import transform @@ -34,13 +33,13 @@ def split_non_commuting(tape: qml.tape.QuantumTape) -> (Sequence[qml.tape.Quantu non-commuting observables to measure. Returns: - qnode (QNode) or tuple[List[QuantumTape], function]: The transformed circuit as described in :func:`qml.transform `. + qnode (QNode) or tuple[List[QuantumTape], function]: The transformed circuit as described in + :func:`qml.transform `. **Example** - This transform allows us to transform a QNode that measures - non-commuting observables to *multiple* circuit executions - with qubit-wise commuting groups: + This transform allows us to transform a QNode that measures non-commuting observables to + *multiple* circuit executions with qubit-wise commuting groups: .. code-block:: python3 @@ -52,7 +51,8 @@ def circuit(x): qml.RX(x,wires=0) return [qml.expval(qml.PauliX(0)), qml.expval(qml.PauliZ(0))] - Instead of decorating the QNode, we can also create a new function that yields the same result in the following way: + Instead of decorating the QNode, we can also create a new function that yields the same result + in the following way: .. code-block:: python3 @@ -70,8 +70,10 @@ def circuit(x): \ 0: ──RX(0.50)─┤ - Note that while internally multiple QNodes are created, the end result has the same ordering as the user provides in the return statement. - Here is a more involved example where we can see the different ordering at the execution level but restoring the original ordering in the output: + Note that while internally multiple QNodes are created, the end result has the same ordering as + the user provides in the return statement. + Here is a more involved example where we can see the different ordering at the execution level + but restoring the original ordering in the output: .. code-block:: python3 @@ -95,8 +97,10 @@ def circuit0(x): 0: ──RY(0.79)──RX(0.79)─┤ ╭ 1: ─────────────────────┤ ╰ - Yet, executing it returns the original ordering of the expectation values. The outputs correspond to - :math:`(\langle \sigma_x^0 \rangle, \langle \sigma_z^0 \rangle, \langle \sigma_y^1 \rangle, \langle \sigma_z^0\sigma_z^1 \rangle)`. + Yet, executing it returns the original ordering of the expectation values. The outputs + correspond to + :math:`(\langle \sigma_x^0 \rangle, \langle \sigma_z^0 \rangle, \langle \sigma_y^1 \rangle, + \langle \sigma_z^0\sigma_z^1 \rangle)`. >>> circuit0([np.pi/4, np.pi/4]) [0.7071067811865475, 0.49999999999999994, 0.0, 0.49999999999999994] @@ -105,7 +109,8 @@ def circuit0(x): .. details:: :title: Usage Details - Internally, this function works with tapes. We can create a tape with non-commuting observables: + Internally, this function works with tapes. We can create a tape with non-commuting + observables: .. code-block:: python3 @@ -119,7 +124,8 @@ def circuit0(x): >>> [t.observables for t in tapes] [[expval(PauliZ(wires=[0]))], [expval(PauliY(wires=[0]))]] - The processing function becomes important when creating the commuting groups as the order of the inputs has been modified: + The processing function becomes important when creating the commuting groups as the order + of the inputs has been modified: .. code-block:: python3 @@ -133,7 +139,8 @@ def circuit0(x): tapes, processing_fn = qml.transforms.split_non_commuting(tape) - In this example, the groupings are ``group_coeffs = [[0,2], [1,3]]`` and ``processing_fn`` makes sure that the final output is of the same shape and ordering: + In this example, the groupings are ``group_coeffs = [[0,2], [1,3]]`` and ``processing_fn`` + makes sure that the final output is of the same shape and ordering: >>> processing_fn([t.measurements for t in tapes]) (expval(PauliZ(wires=[0]) @ PauliZ(wires=[1])), @@ -141,25 +148,54 @@ def circuit0(x): expval(PauliZ(wires=[0])), expval(PauliX(wires=[0]))) - """ + Measurements that accept both observables and ``wires`` so that e.g. ``qml.counts``, + ``qml.probs`` and ``qml.sample`` can also be used. When initialized using only ``wires``, + these measurements are interpreted as measuring with respect to the observable + ``qml.PauliZ(wires[0])@qml.PauliZ(wires[1])@...@qml.PauliZ(wires[len(wires)-1])`` + + .. code-block:: python3 - # TODO: allow for samples and probs - if any(isinstance(m, (SampleMP, ProbabilityMP)) for m in tape.measurements): - raise NotImplementedError( - "When non-commuting observables are used, only `qml.expval` and `qml.var` are supported." - ) + measurements = [ + qml.expval(qml.PauliX(0)), + qml.probs(wires=[1]), + qml.probs(wires=[0, 1]) + ] + tape = qml.tape.QuantumTape(measurements=measurements) + + tapes, processing_fn = qml.transforms.split_non_commuting(tape) + + This results in two tapes, each with commuting measurements: + + >>> [t.measurements for t in tapes] + [[expval(PauliX(wires=[0])), probs(wires=[1])], [probs(wires=[0, 1])]] + """ - obs_list = tape.observables + # Construct a list of observables to group based on the measurements in the tape + obs_list = [] + for obs in tape.observables: + # observable provided for a measurement + if isinstance(obs, qml.operation.Observable): + obs_list.append(obs) + # measurements using wires instead of observables + else: + # create the PauliZ tensor product observable when only wires are provided for the + # measurements + # TODO: Revisit when qml.prod is compatible with qml.pauli.group_observables + pauliz_obs = qml.PauliZ(obs.wires[0]) + for wire in obs.wires[1:]: + pauliz_obs = pauliz_obs @ qml.PauliZ(wire) + + obs_list.append(pauliz_obs) # If there is more than one group of commuting observables, split tapes - groups, group_coeffs = qml.pauli.group_observables(obs_list, range(len(obs_list))) - if len(groups) > 1: + _, group_coeffs = qml.pauli.group_observables(obs_list, range(len(obs_list))) + if len(group_coeffs) > 1: # make one tape per commuting group tapes = [] - for group, indices in zip(groups, group_coeffs): + for indices in group_coeffs: new_tape = tape.__class__( tape.operations, - (tape.measurements[i].__class__(obs=o) for o, i in zip(group, indices)), + (tape.measurements[i] for i in indices), ) tapes.append(new_tape) diff --git a/tests/transforms/test_split_non_commuting.py b/tests/transforms/test_split_non_commuting.py index fb0399e48ba..4b2ae1d7368 100644 --- a/tests/transforms/test_split_non_commuting.py +++ b/tests/transforms/test_split_non_commuting.py @@ -13,6 +13,7 @@ # limitations under the License. """ Tests for the transform ``qml.transform.split_non_commuting()`` """ # pylint: disable=no-self-use, import-outside-toplevel, no-member, import-error +import itertools import pytest import numpy as np import pennylane as qml @@ -20,67 +21,122 @@ from pennylane.transforms import split_non_commuting -### example tape with 3 commuting groups [[0,3],[1,4],[2,5]] -with qml.queuing.AnnotatedQueue() as q3: - qml.PauliZ(0) - qml.Hadamard(0) - qml.CNOT((0, 1)) - qml.expval(qml.PauliZ(0) @ qml.PauliZ(1)) - qml.expval(qml.PauliX(0) @ qml.PauliX(1)) - qml.expval(qml.PauliY(0) @ qml.PauliY(1)) - qml.expval(qml.PauliZ(0)) - qml.expval(qml.PauliX(0)) - qml.expval(qml.PauliY(0)) - -non_commuting_tape3 = qml.tape.QuantumScript.from_queue(q3) -### example tape with 2 -commuting groups [[0,2],[1,3]] -with qml.queuing.AnnotatedQueue() as q2: - qml.PauliZ(0) - qml.Hadamard(0) - qml.CNOT((0, 1)) - qml.expval(qml.PauliZ(0) @ qml.PauliZ(1)) - qml.expval(qml.PauliX(0) @ qml.PauliX(1)) - qml.expval(qml.PauliZ(0)) - qml.expval(qml.PauliX(0)) - -non_commuting_tape2 = qml.tape.QuantumScript.from_queue(q2) -# For testing different observable types -obs_fn = [qml.expval, qml.var] +# list of observables with 2 commuting groups [[1, 3], [0, 2, 4]] +obs_list_2 = [ + qml.PauliZ(0) @ qml.PauliZ(1), + qml.PauliX(0) @ qml.PauliX(1), + qml.PauliZ(0), + qml.PauliX(0), + qml.PauliZ(1), +] + +# list of observables with 3 commuting groups [[0,3], [1,4], [2,5]] +obs_list_3 = [ + qml.PauliZ(0) @ qml.PauliZ(1), + qml.PauliX(0) @ qml.PauliX(1), + qml.PauliY(0) @ qml.PauliY(1), + qml.PauliZ(0), + qml.PauliX(0), + qml.PauliY(0), +] + +# measurements that accept observables as arguments +obs_meas_fn = [qml.expval, qml.var, qml.probs, qml.counts, qml.sample] + +# measurements that accept wires as arguments +wire_meas_fn = [qml.probs, qml.counts, qml.sample] class TestUnittestSplitNonCommuting: """Unit tests on ``qml.transforms.split_non_commuting()``""" - def test_commuting_group_no_split(self, mocker): - """Testing that commuting groups are not split""" + @pytest.mark.parametrize("meas_type", obs_meas_fn) + def test_commuting_group_no_split(self, mocker, meas_type): + """Testing that commuting groups are not split for all supported measurement types""" with qml.queuing.AnnotatedQueue() as q: qml.PauliZ(0) qml.Hadamard(0) qml.CNOT((0, 1)) - qml.expval(qml.PauliZ(0)) - qml.expval(qml.PauliZ(0)) - qml.expval(qml.PauliX(1)) - qml.expval(qml.PauliZ(2)) - qml.expval(qml.PauliZ(0) @ qml.PauliZ(3)) + meas_type(op=qml.PauliZ(0)) + meas_type(op=qml.PauliZ(0)) + meas_type(op=qml.PauliX(1)) + meas_type(op=qml.PauliZ(2)) + meas_type(op=qml.PauliZ(0) @ qml.PauliZ(3)) + # test transform on tape tape = qml.tape.QuantumScript.from_queue(q) split, fn = split_non_commuting(tape) spy = mocker.spy(qml.math, "concatenate") + assert len(split) == 1 assert all(isinstance(t, qml.tape.QuantumScript) for t in split) assert fn([0.5]) == 0.5 + # test transform on qscript qs = qml.tape.QuantumScript(tape.operations, tape.measurements) split, fn = split_non_commuting(qs) + + assert len(split) == 1 assert all(isinstance(i_qs, qml.tape.QuantumScript) for i_qs in split) assert fn([0.5]) == 0.5 spy.assert_not_called() - @pytest.mark.parametrize("tape,expected", [(non_commuting_tape2, 2), (non_commuting_tape3, 3)]) - def test_non_commuting_group_right_number(self, tape, expected): - """Test that the output is of the correct size""" + @pytest.mark.parametrize("meas_type", wire_meas_fn) + def test_wire_commuting_group_no_split(self, mocker, meas_type): + """Testing that commuting MPs initialized using wires or observables are not split""" + with qml.queuing.AnnotatedQueue() as q: + qml.PauliZ(0) + qml.Hadamard(0) + qml.CNOT((0, 1)) + meas_type(wires=[0]) + meas_type(wires=[1]) + meas_type(wires=[0, 1]) + meas_type(op=qml.PauliZ(0)) + meas_type(op=qml.PauliZ(0) @ qml.PauliZ(2)) + + # test transform on tape + tape = qml.tape.QuantumScript.from_queue(q) + split, fn = split_non_commuting(tape) + + spy = mocker.spy(qml.math, "concatenate") + + assert len(split) == 1 + assert all(isinstance(t, qml.tape.QuantumScript) for t in split) + assert fn([0.5]) == 0.5 + + # test transform on qscript + qs = qml.tape.QuantumScript(tape.operations, tape.measurements) + split, fn = split_non_commuting(qs) + + assert len(split) == 1 + assert all(isinstance(i_qs, qml.tape.QuantumScript) for i_qs in split) + assert fn([0.5]) == 0.5 + + spy.assert_not_called() + + @pytest.mark.parametrize("meas_type", obs_meas_fn) + @pytest.mark.parametrize("obs_list, expected", [(obs_list_2, 2), (obs_list_3, 3)]) + def test_non_commuting_group_right_number(self, meas_type, obs_list, expected): + """Test that the no. of tapes after splitting into commuting groups is of the right size""" + + # create a queue with several measurements of same type but with differnent non-commuting + # observables + with qml.queuing.AnnotatedQueue() as q: + qml.PauliZ(0) + qml.Hadamard(0) + qml.CNOT((0, 1)) + for obs in obs_list: + meas_type(op=obs) + + # if MP type can accept wires, then add two extra measurements using wires and test no. + # of tapes after splitting commuting groups + if meas_type in wire_meas_fn: + meas_type(wires=[0]) + meas_type(wires=[0, 1]) + + tape = qml.tape.QuantumScript.from_queue(q) split, _ = split_non_commuting(tape) assert len(split) == expected @@ -88,30 +144,75 @@ def test_non_commuting_group_right_number(self, tape, expected): split, _ = split_non_commuting(qs) assert len(split) == expected + @pytest.mark.parametrize("meas_type", obs_meas_fn) @pytest.mark.parametrize( - "tape,group_coeffs", - [(non_commuting_tape2, [[0, 2], [1, 3]]), (non_commuting_tape3, [[0, 3], [1, 4], [2, 5]])], + "obs_list, group_coeffs", + [(obs_list_2, [[1, 3], [0, 2, 4]]), (obs_list_3, [[0, 3], [1, 4], [2, 5]])], ) - def test_non_commuting_group_right_reorder(self, tape, group_coeffs): + def test_non_commuting_group_right_reorder(self, meas_type, obs_list, group_coeffs): """Test that the output is of the correct order""" - split, fn = split_non_commuting(tape) - assert all(np.array(fn(group_coeffs)) == np.arange(len(split) * 2)) + # create a queue with several measurements of same type but with differnent non-commuting + # observables + with qml.queuing.AnnotatedQueue() as q: + qml.PauliZ(0) + qml.Hadamard(0) + qml.CNOT((0, 1)) + for obs in obs_list: + meas_type(op=obs) + + tape = qml.tape.QuantumScript.from_queue(q) + _, fn = split_non_commuting(tape) + assert all(np.array(fn(group_coeffs)) == np.arange(len(tape.measurements))) qs = qml.tape.QuantumScript(tape.operations, tape.measurements) - split, fn = split_non_commuting(qs) - assert all(np.array(fn(group_coeffs)) == np.arange(len(split) * 2)) + _, fn = split_non_commuting(qs) + assert all(np.array(fn(group_coeffs)) == np.arange(len(tape.measurements))) - @pytest.mark.parametrize("meas_type", obs_fn) + @pytest.mark.parametrize("meas_type", wire_meas_fn) + @pytest.mark.parametrize( + "obs_list, group_coeffs", + [(obs_list_2, [[1, 3], [0, 2, 4, 5]]), (obs_list_3, [[1, 4], [2, 5], [0, 3, 6]])], + ) + def test_wire_non_commuting_group_right_reorder(self, meas_type, obs_list, group_coeffs): + """Test that the output is of the correct order with wire MPs initialized using a + combination of wires and observables""" + # create a queue with several measurements of same type but with differnent non-commuting + # observables + with qml.queuing.AnnotatedQueue() as q: + qml.PauliZ(0) + qml.Hadamard(0) + qml.CNOT((0, 1)) + for obs in obs_list: + meas_type(op=obs) + + # initialize measurements using wires + meas_type(wires=[0]) + + tape = qml.tape.QuantumScript.from_queue(q) + _, fn = split_non_commuting(tape) + assert all(np.array(fn(group_coeffs)) == np.arange(len(tape.measurements))) + + qs = qml.tape.QuantumScript(tape.operations, tape.measurements) + _, fn = split_non_commuting(qs) + assert all(np.array(fn(group_coeffs)) == np.arange(len(tape.measurements))) + + @pytest.mark.parametrize("meas_type", obs_meas_fn) def test_different_measurement_types(self, meas_type): - """Test that expval, var and sample are correctly reproduced""" + """Test that the measurements types of the split tapes are correct""" with qml.queuing.AnnotatedQueue() as q: qml.PauliZ(0) qml.Hadamard(0) qml.CNOT((0, 1)) - meas_type(qml.PauliZ(0) @ qml.PauliZ(1)) - meas_type(qml.PauliX(0) @ qml.PauliX(1)) - meas_type(qml.PauliZ(0)) - meas_type(qml.PauliX(0)) + meas_type(op=qml.PauliZ(0) @ qml.PauliZ(1)) + meas_type(op=qml.PauliX(0) @ qml.PauliX(1)) + meas_type(op=qml.PauliZ(0)) + meas_type(op=qml.PauliX(0)) + + # if the MP can also accept wires as arguments, add extra measurements to test + if meas_type in wire_meas_fn: + meas_type(wires=[0]) + meas_type(wires=[0, 1]) + tape = qml.tape.QuantumScript.from_queue(q) the_return_type = tape.measurements[0].return_type split, _ = split_non_commuting(tape) @@ -125,45 +226,72 @@ def test_different_measurement_types(self, meas_type): for meas in new_tape.measurements: assert meas.return_type == the_return_type - def test_mixed_measurement_types(self): - """Test that mixing expval and var works correctly.""" + @pytest.mark.parametrize("meas_type_1, meas_type_2", itertools.combinations(obs_meas_fn, 2)) + def test_mixed_measurement_types(self, meas_type_1, meas_type_2): + """Test that mixing different combintations of MPs initialized using obs works correctly.""" with qml.queuing.AnnotatedQueue() as q: qml.Hadamard(0) qml.Hadamard(1) - qml.expval(qml.PauliX(0)) - qml.expval(qml.PauliZ(1)) - qml.var(qml.PauliZ(0)) + meas_type_1(op=qml.PauliX(0)) + meas_type_1(op=qml.PauliZ(1)) + meas_type_2(op=qml.PauliZ(0)) tape = qml.tape.QuantumScript.from_queue(q) split, _ = split_non_commuting(tape) assert len(split) == 2 + assert qml.equal(split[0].measurements[0], meas_type_1(op=qml.PauliX(0))) + assert qml.equal(split[0].measurements[1], meas_type_1(op=qml.PauliZ(1))) + assert qml.equal(split[1].measurements[0], meas_type_2(op=qml.PauliZ(0))) + + @pytest.mark.parametrize("meas_type_1, meas_type_2", itertools.combinations(wire_meas_fn, 2)) + def test_mixed_wires_measurement_types(self, meas_type_1, meas_type_2): + """Test that mixing different combinations of MPs initialized using wires works correctly""" with qml.queuing.AnnotatedQueue() as q: qml.Hadamard(0) qml.Hadamard(1) - qml.expval(qml.PauliX(0)) - qml.var(qml.PauliZ(0)) - qml.expval(qml.PauliZ(1)) + meas_type_1(op=qml.PauliX(0)) + meas_type_1(wires=[1]) + meas_type_2(wires=[0]) tape = qml.tape.QuantumScript.from_queue(q) split, _ = split_non_commuting(tape) assert len(split) == 2 - assert qml.equal(split[0].measurements[0], qml.expval(qml.PauliX(0))) - assert qml.equal(split[0].measurements[1], qml.expval(qml.PauliZ(1))) - assert qml.equal(split[1].measurements[0], qml.var(qml.PauliZ(0))) + assert qml.equal(split[0].measurements[0], meas_type_1(op=qml.PauliX(0))) + assert qml.equal(split[0].measurements[1], meas_type_1(wires=[1])) + assert qml.equal(split[1].measurements[0], meas_type_2(wires=[0])) + + @pytest.mark.parametrize( + "meas_type_1, meas_type_2", itertools.product(obs_meas_fn, wire_meas_fn) + ) + def test_mixed_wires_obs_measurement_types(self, meas_type_1, meas_type_2): + """Test that mixing different combinations of MPs initialized using wires and obs works + correctly""" - def test_raise_not_supported(self): - """Test that NotImplementedError is raised when probabilities or samples are called""" with qml.queuing.AnnotatedQueue() as q: - qml.expval(qml.PauliZ(0)) - qml.probs(wires=0) + qml.Hadamard(0) + qml.Hadamard(1) + meas_type_1(op=qml.PauliX(0)) + meas_type_2(wires=[1]) + meas_type_2(wires=[0, 1]) tape = qml.tape.QuantumScript.from_queue(q) - with pytest.raises(NotImplementedError, match="non-commuting observables are used"): - split_non_commuting(tape) + split, _ = split_non_commuting(tape) + + assert len(split) == 2 + assert qml.equal(split[0].measurements[0], meas_type_1(op=qml.PauliX(0))) + assert qml.equal(split[0].measurements[1], meas_type_2(wires=[1])) + assert qml.equal(split[1].measurements[0], meas_type_2(wires=[0, 1])) + + +# measurements that require shots=True +required_shot_meas_fn = [qml.sample, qml.counts] + +# measurements that can optionally have shots=True +optional_shot_meas_fn = [qml.probs, qml.expval, qml.var] class TestIntegration: @@ -201,7 +329,7 @@ def circuit(): assert all(np.isclose(res, np.array([0.0, -1.0, 0.0, 0.0, 1.0, 1 / np.sqrt(2)]))) def test_expval_non_commuting_observables_qnode(self): - """Test expval with multiple non-commuting operators as a tranform program on the qnode.""" + """Test expval with multiple non-commuting operators as a transform program on the qnode.""" dev = qml.device("default.qubit", wires=6) @qml.qnode(dev) @@ -232,6 +360,41 @@ def circuit(): assert all(np.isclose(res, np.array([0.0, -1.0, 0.0, 0.0, 1.0, 1 / np.sqrt(2)]))) + def test_expval_probs_non_commuting_observables_qnode(self): + """Test expval with multiple non-commuting operators and probs with non-commuting wires as a + transform program on the qnode.""" + dev = qml.device("default.qubit", wires=6) + + @qml.qnode(dev) + def circuit(): + qml.Hadamard(1) + qml.Hadamard(0) + qml.PauliZ(0) + qml.Hadamard(3) + qml.Hadamard(5) + qml.T(5) + return ( + qml.probs(wires=[0, 1]), + qml.probs(wires=[1]), + qml.expval(qml.PauliZ(0)), + qml.expval(qml.PauliX(1) @ qml.PauliX(4)), + qml.expval(qml.PauliX(3)), + qml.expval(qml.PauliY(5)), + ) + + res = split_non_commuting(circuit)() + + assert isinstance(res, tuple) + assert len(res) == 6 + assert all(isinstance(r, np.ndarray) for r in res) + + res_probs = qml.math.concatenate(res[:2], axis=0) + res_expval = qml.math.stack(res[2:]) + + assert all(np.isclose(res_probs, np.array([0.25, 0.25, 0.25, 0.25, 0.5, 0.5]))) + + assert all(np.isclose(res_expval, np.array([0.0, 0.0, 1.0, 1 / np.sqrt(2)]))) + def test_shot_vector_support(self): """Test output is correct when using shot vectors""" @@ -274,6 +437,66 @@ def circuit(): res, np.array([0.0, -1.0, 0.0, 0.0, 0.0, 1.0, 1 / np.sqrt(2)]), atol=0.05 ) + def test_shot_vector_support_sample(self): + """Test output is correct when using shots and sample and expval measurements""" + + dev = qml.device("default.qubit", wires=2, shots=(10, 20)) + + @qml.qnode(dev) + def circuit(): + qml.PauliZ(0) + return (qml.sample(wires=[0, 1]), qml.expval(qml.PauliZ(0))) + + res = split_non_commuting(circuit)() + assert isinstance(res, tuple) + assert len(res) == 2 + assert all(isinstance(shot_res, tuple) for shot_res in res) + assert all(len(shot_res) == 2 for shot_res in res) + # pylint:disable=not-an-iterable + assert all(all(list(isinstance(r, np.ndarray) for r in shot_res)) for shot_res in res) + + assert all( + shot_res[0].shape in [(10, 2), (20, 2)] and shot_res[1].shape == () for shot_res in res + ) + + # check all the wire samples are as expected + sample_res = qml.math.concatenate( + [qml.math.concatenate(shot_res[0], axis=0) for shot_res in res], axis=0 + ) + assert np.allclose(sample_res, 0.0, atol=0.05) + + expval_res = qml.math.stack([shot_res[1] for shot_res in res]) + assert np.allclose(expval_res, np.array([1.0, 1.0]), atol=0.05) + + def test_shot_vector_support_counts(self): + """Test output is correct when using shots, counts and expval measurements""" + + dev = qml.device("default.qubit", wires=2, shots=(10, 20)) + + @qml.qnode(dev) + def circuit(): + qml.PauliZ(0) + return (qml.counts(wires=[0, 1]), qml.expval(qml.PauliZ(0))) + + res = split_non_commuting(circuit)() + assert isinstance(res, tuple) + assert len(res) == 2 + assert all(isinstance(shot_res, tuple) for shot_res in res) + assert all(len(shot_res) == 2 for shot_res in res) + # pylint:disable=not-an-iterable + assert all( + isinstance(shot_res[0], dict) and isinstance(shot_res[1], np.ndarray) + for shot_res in res + ) + + assert all(shot_res[1].shape == () for shot_res in res) + + # check all the wire counts are as expected + assert all(shot_res[0]["00"] in [10, 20] for shot_res in res) + + expval_res = qml.math.stack([shot_res[1] for shot_res in res]) + assert np.allclose(expval_res, np.array([1.0, 1.0]), atol=0.05) + # Autodiff tests exp_res = np.array([0.77015115, -0.47942554, 0.87758256]) @@ -281,6 +504,17 @@ def circuit(): [[-4.20735492e-01, -4.20735492e-01], [-8.77582562e-01, 0.0], [-4.79425539e-01, 0.0]] ) +exp_res_probs = np.array([0.88132907, 0.05746221, 0.05746221, 0.00374651, 0.0]) +exp_grad_probs = np.array( + [ + [-0.22504026, -0.22504026], + [-0.01467251, 0.22504026], + [0.22504026, -0.01467251], + [0.01467251, 0.01467251], + [0.0, 0.0], + ] +) + class TestAutodiffSplitNonCommuting: """Autodiff tests for all frameworks""" @@ -310,6 +544,28 @@ def cost(params): assert all(np.isclose(res, exp_res)) assert all(np.isclose(grad, exp_grad).flatten()) + @pytest.mark.autograd + def test_split_with_autograd_probs(self): + """Test resulting after splitting non-commuting tapes with expval and probs measurements + are still differentiable with autograd""" + dev = qml.device("default.qubit", wires=2) + + @qml.qnode(dev) + def circuit(params): + qml.RX(params[0], wires=0) + qml.RY(params[1], wires=1) + return (qml.probs(wires=[0, 1]), qml.expval(qml.PauliX(0) @ qml.PauliX(1))) + + def cost(params): + res = split_non_commuting(circuit)(params) + return qml.math.concatenate([res[0]] + [qml.math.stack(res[1:])], axis=0) + + params = pnp.array([0.5, 0.5]) + res = cost(params) + grad = qml.jacobian(cost)(params) + assert all(np.isclose(res, exp_res_probs)) + assert all(np.isclose(grad, exp_grad_probs).flatten()) + @pytest.mark.jax def test_split_with_jax(self): """Test that results after splitting are still differentiable with jax""" @@ -330,10 +586,35 @@ def circuit(params): ) params = jnp.array([0.5, 0.5]) - res = circuit(params) - grad = jax.jacobian(circuit)(params) - assert all(np.isclose(res, exp_res)) - assert all(np.isclose(grad, exp_grad, atol=1e-5).flatten()) + res = split_non_commuting(circuit)(params) + grad = jax.jacobian(split_non_commuting(circuit))(params) + assert all(np.isclose(res, exp_res, atol=0.05)) + assert all(np.isclose(grad, exp_grad, atol=0.05).flatten()) + + @pytest.mark.jax + def test_split_with_jax_probs(self): + """Test resulting after splitting non-commuting tapes with expval and probs measurements + are still differentiable with jax""" + import jax + import jax.numpy as jnp + + dev = qml.device("default.qubit.jax", wires=2) + + @qml.qnode(dev) + def circuit(params): + qml.RX(params[0], wires=0) + qml.RY(params[1], wires=1) + return (qml.probs(wires=[0, 1]), qml.expval(qml.PauliX(0) @ qml.PauliX(1))) + + params = jnp.array([0.5, 0.5]) + res = split_non_commuting(circuit)(params) + res = jnp.concatenate([res[0]] + [jnp.stack(res[1:])], axis=0) + + grad = jax.jacobian(split_non_commuting(circuit))(params) + grad = jnp.concatenate([grad[0]] + [jnp.stack(grad[1:])], axis=0) + + assert all(np.isclose(res, exp_res_probs, atol=0.05)) + assert all(np.isclose(grad, exp_grad_probs, atol=0.05).flatten()) @pytest.mark.jax def test_split_with_jax_multi_params(self): @@ -358,8 +639,8 @@ def circuit(x, y): x = jnp.array(0.5) y = jnp.array(0.5) - res = circuit(x, y) - grad = jax.jacobian(circuit, argnums=[0, 1])(x, y) + res = split_non_commuting(circuit)(x, y) + grad = jax.jacobian(split_non_commuting(circuit), argnums=[0, 1])(x, y) assert all(np.isclose(res, exp_res)) @@ -373,6 +654,49 @@ def circuit(x, y): assert np.allclose(meas_grad, exp_grad[i], atol=1e-5) + @pytest.mark.jax + def test_split_with_jax_multi_params_probs(self): + """Test that results after splitting are still differentiable with jax + with multiple parameters""" + + import jax + import jax.numpy as jnp + + dev = qml.device("default.qubit.jax", wires=2) + + @qml.qnode(dev) + def circuit(x, y): + qml.RX(x, wires=0) + qml.RY(y, wires=1) + return (qml.probs(wires=[0, 1]), qml.expval(qml.PauliX(0) @ qml.PauliX(1))) + + x = jnp.array(0.5) + y = jnp.array(0.5) + + res = split_non_commuting(circuit)(x, y) + res = jnp.concatenate([res[0]] + [jnp.stack(res[1:])], axis=0) + assert all(np.isclose(res, exp_res_probs)) + + grad = jax.jacobian(split_non_commuting(circuit), argnums=[0, 1])(x, y) + + assert isinstance(grad, tuple) + assert len(grad) == 2 + + for meas_grad in grad: + assert isinstance(meas_grad, tuple) + assert len(meas_grad) == 2 + assert all(isinstance(g, jnp.ndarray) for g in meas_grad) + + # reshape the returned gradient to the right shape + grad = jnp.concatenate( + [ + jnp.concatenate([grad[0][0].reshape(-1, 1), grad[0][1].reshape(-1, 1)], axis=1), + jnp.concatenate([grad[1][0].reshape(-1, 1), grad[1][1].reshape(-1, 1)], axis=1), + ], + axis=0, + ) + assert all(np.isclose(grad, exp_grad_probs, atol=0.05).flatten()) + @pytest.mark.jax def test_split_with_jax_jit(self): """Test that results after splitting are still differentiable with jax-jit""" @@ -399,6 +723,32 @@ def circuit(params): assert all(np.isclose(res, exp_res)) assert all(np.isclose(grad, exp_grad, atol=1e-5).flatten()) + @pytest.mark.jax + def test_split_with_jax_jit_probs(self): + """Test resulting after splitting non-commuting tapes with expval and probs measurements + are still differentiable with jax""" + + import jax + import jax.numpy as jnp + + dev = qml.device("default.qubit", wires=2) + + @qml.qnode(dev) + def circuit(params): + qml.RX(params[0], wires=0) + qml.RY(params[1], wires=1) + return (qml.probs(wires=[0, 1]), qml.expval(qml.PauliX(0) @ qml.PauliX(1))) + + params = jnp.array([0.5, 0.5]) + res = split_non_commuting(circuit)(params) + res = jnp.concatenate([res[0]] + [jnp.stack(res[1:])], axis=0) + + grad = jax.jacobian(split_non_commuting(circuit))(params) + grad = jnp.concatenate([grad[0]] + [jnp.stack(grad[1:])], axis=0) + + assert all(np.isclose(res, exp_res_probs, atol=0.05)) + assert all(np.isclose(grad, exp_grad_probs, atol=0.05).flatten()) + @pytest.mark.torch def test_split_with_torch(self): """Test that results after splitting are still differentiable with torch""" @@ -419,7 +769,7 @@ def circuit(params): ) def cost(params): - res = circuit(params) + res = split_non_commuting(circuit)(params) return qml.math.stack(res) params = torch.tensor([0.5, 0.5], requires_grad=True) @@ -428,6 +778,32 @@ def cost(params): assert all(np.isclose(res.detach().numpy(), exp_res)) assert all(np.isclose(grad.detach().numpy(), exp_grad, atol=1e-5).flatten()) + @pytest.mark.torch + def test_split_with_torch_probs(self): + """Test resulting after splitting non-commuting tapes with expval and probs measurements + are still differentiable with torch""" + + import torch + from torch.autograd.functional import jacobian + + dev = qml.device("default.qubit", wires=2) + + @qml.qnode(dev) + def circuit(params): + qml.RX(params[0], wires=0) + qml.RY(params[1], wires=1) + return (qml.probs(wires=[0, 1]), qml.expval(qml.PauliX(0) @ qml.PauliX(1))) + + def cost(params): + res = split_non_commuting(circuit)(params) + return qml.math.concatenate([res[0]] + [qml.math.stack(res[1:])], axis=0) + + params = torch.tensor([0.5, 0.5], requires_grad=True) + res = cost(params) + grad = jacobian(cost, (params)) + assert all(np.isclose(res.detach().numpy(), exp_res_probs)) + assert all(np.isclose(grad.detach().numpy(), exp_grad_probs, atol=1e-5).flatten()) + @pytest.mark.tf def test_split_with_tf(self): """Test that results after splitting are still differentiable with tf""" @@ -449,9 +825,32 @@ def circuit(params): params = tf.Variable([0.5, 0.5]) res = circuit(params) with tf.GradientTape() as tape: - loss = circuit(params) + loss = split_non_commuting(circuit)(params) loss = tf.stack(loss) grad = tape.jacobian(loss, params) assert all(np.isclose(res, exp_res)) assert all(np.isclose(grad, exp_grad, atol=1e-5).flatten()) + + @pytest.mark.tf + def test_split_with_tf_probs(self): + """Test that results after splitting are still differentiable with tf""" + + import tensorflow as tf + + dev = qml.device("default.qubit.tf", wires=2) + + @qml.qnode(dev) + def circuit(params): + qml.RX(params[0], wires=0) + qml.RY(params[1], wires=1) + return (qml.probs(wires=[0, 1]), qml.expval(qml.PauliX(0) @ qml.PauliX(1))) + + params = tf.Variable([0.5, 0.5]) + with tf.GradientTape() as tape: + res = split_non_commuting(circuit)(params) + res = tf.concat([res[0]] + [tf.stack(res[1:])], axis=0) + + grad = tape.jacobian(res, params) + assert all(np.isclose(res, exp_res_probs)) + assert all(np.isclose(grad, exp_grad_probs, atol=1e-5).flatten())