From ab8c4993cd0043ed04109895ef87dfa9d156c759 Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Wed, 14 Feb 2024 17:33:44 -0500 Subject: [PATCH] TransformDispatcher can dispatch onto a batch of tapes (#5163) **Context:** A primary goal for the transform project is that transforms are composable. With the current features, transforms are composable on the qnode, but they need to be more robust to compose when working with (batches of) tapes. **Description of the Change:** A new private method has been added to the `TransformDispatcher` class to dispatch transforms onto batches of tapes. Such a method is essentially the copy of `map_batch_transform`. **Benefits:** The proposed solution would lower cognitive overhead for developers working with transforms. **Possible Drawbacks:** The new method is called only if the class object is a sequence with all instances of `qml.tape.QuantumScript`. Otherwise, a previously implemented `TransformError` is raised. **Related GitHub Issues:** None. --------- Co-authored-by: Christina Lee --- doc/releases/changelog-dev.md | 5 ++ pennylane/transforms/batch_transform.py | 2 +- .../transforms/core/transform_dispatcher.py | 47 ++++++++++++++ .../test_transform_dispatcher.py | 61 +++++++++++++++++++ tests/transforms/test_split_non_commuting.py | 35 +++++++++++ 5 files changed, 149 insertions(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 943c28ae371..96b22233175 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -368,6 +368,10 @@

Deprecations 👋

+* `TransformDispatcher` can now dispatch onto a batch of tapes, so that it is easier to compose transforms + when working in the tape paradigm. + [(#5163)](https://github.com/PennyLaneAI/pennylane/pull/5163) + * `Operator.validate_subspace(subspace)` has been relocated to the `qml.ops.qutrit.parametric_ops` module and will be removed from the Operator class in an upcoming release. [(#5067)](https://github.com/PennyLaneAI/pennylane/pull/5067) @@ -517,6 +521,7 @@ Skylar Chan, Isaac De Vlugt, Diksha Dhawan, Lillian Frederiksen, +Pietropaolo Frisoni, Eugenio Gigante, Diego Guala, David Ittah, diff --git a/pennylane/transforms/batch_transform.py b/pennylane/transforms/batch_transform.py index 97a1d9f7929..e7cc8816c61 100644 --- a/pennylane/transforms/batch_transform.py +++ b/pennylane/transforms/batch_transform.py @@ -510,7 +510,7 @@ def map_batch_transform( tape_counts.append(len(new_tapes)) def processing_fn(res: ResultBatch) -> ResultBatch: - """Applies a batch of post-procesing functions to results. + """Applies a batch of post-processing functions to results. Args: res (ResultBatch): the results of executing a batch of circuits diff --git a/pennylane/transforms/core/transform_dispatcher.py b/pennylane/transforms/core/transform_dispatcher.py index 99a80098c5e..f9b6faeca96 100644 --- a/pennylane/transforms/core/transform_dispatcher.py +++ b/pennylane/transforms/core/transform_dispatcher.py @@ -20,7 +20,10 @@ import warnings import types +from typing import Sequence + import pennylane as qml +from pennylane.typing import ResultBatch class TransformError(Exception): @@ -121,6 +124,8 @@ def processing_fn(results): return self._device_transform(obj, targs, tkwargs) if callable(obj): return self._qfunc_transform(obj, targs, tkwargs) + if isinstance(obj, Sequence) and all(isinstance(q, qml.tape.QuantumScript) for q in obj): + return self._batch_transform(obj, targs, tkwargs) # Input is not a QNode nor a quantum tape nor a device. # Assume Python decorator syntax: @@ -313,6 +318,48 @@ def original_device(self): return TransformedDevice(original_device, self._transform) + def _batch_transform(self, original_batch, targs, tkwargs): + """Apply the transform on a batch of tapes""" + execution_tapes = [] + batch_fns = [] + tape_counts = [] + + for t in original_batch: + # Preprocess the tapes by applying batch transforms + # to each tape, and storing corresponding tapes + # for execution, processing functions, and list of tape lengths. + new_tapes, fn = self(t, *targs, **tkwargs) + execution_tapes.extend(new_tapes) + batch_fns.append(fn) + tape_counts.append(len(new_tapes)) + + def processing_fn(res: ResultBatch) -> ResultBatch: + """Applies a batch of post-processing functions to results. + + Args: + res (ResultBatch): the results of executing a batch of circuits + + Returns: + ResultBatch : results that have undergone classical post processing + + Closure variables: + tape_counts: the number of tapes outputted from each application of the transform + batch_fns: the post processing functions to apply to each sub-batch + + """ + count = 0 + final_results = [] + + for f, s in zip(batch_fns, tape_counts): + # apply any batch transform post-processing + new_res = f(res[count : count + s]) + final_results.append(new_res) + count += s + + return tuple(final_results) + + return tuple(execution_tapes), processing_fn + class TransformContainer: """Class to store a quantum transform with its ``args``, ``kwargs`` and classical co-transforms. Use diff --git a/tests/transforms/test_experimental/test_transform_dispatcher.py b/tests/transforms/test_experimental/test_transform_dispatcher.py index b1ef83bb65c..959b4b394ba 100644 --- a/tests/transforms/test_experimental/test_transform_dispatcher.py +++ b/tests/transforms/test_experimental/test_transform_dispatcher.py @@ -18,6 +18,7 @@ import pytest import pennylane as qml from pennylane.transforms.core import transform, TransformError, TransformContainer +from pennylane.typing import TensorLike dev = qml.device("default.qubit", wires=2) @@ -470,6 +471,66 @@ def test_dispatched_transform_attribute(self): assert dispatched_transform.expand_transform is None assert dispatched_transform.classical_cotransform is None + @pytest.mark.parametrize("valid_transform", valid_transforms) + @pytest.mark.parametrize("batch_type", (tuple, list)) + def test_batch_transform(self, valid_transform, batch_type, num_margin=1e-8): + """Test that dispatcher can dispatch onto a batch of tapes.""" + + def check_batch(batch): + return isinstance(batch, Sequence) and all( + isinstance(tape, qml.tape.QuantumScript) for tape in batch + ) + + def comb_postproc(results: TensorLike, fn1: Callable, fn2: Callable): + return fn1(fn2(results)) + + # Create a simple device and tape + tmp_dev = qml.device("default.qubit", wires=3) + H = qml.PauliY(2) @ qml.PauliZ(1) + 0.5 * qml.PauliZ(2) + qml.PauliZ(1) + measur = [qml.expval(H)] + ops = [qml.Hadamard(0), qml.RX(0.2, 0), qml.RX(0.6, 0), qml.CNOT((0, 1))] + tape = qml.tape.QuantumTape(ops, measur) + + ############################################################ + ### Test with two elementary user-defined transforms + ############################################################ + + dispatched_transform1 = transform(valid_transform) + dispatched_transform2 = transform(valid_transform) + + batch1, fn1 = dispatched_transform1(tape, index=0) + assert check_batch(batch1) + + batch2, fn2 = dispatched_transform2(batch1, index=0) + assert check_batch(batch2) + + result = tmp_dev.execute(batch2) + assert isinstance(result, TensorLike) + + ############################################################ + ### Test with two `concrete` transforms + ############################################################ + + tape = qml.tape.QuantumTape(ops, measur) + + batch1, fn1 = qml.transforms.hamiltonian_expand(tape) + assert check_batch(batch1) + + batch2, fn2 = qml.transforms.merge_rotations(batch1) + assert check_batch(batch2) + + result = tmp_dev.execute(batch2) + assert isinstance(result, TensorLike) + + # check that final batch and post-processing functions are what we expect after the two transforms + fin_ops = [qml.Hadamard(0), qml.RX(0.8, 0), qml.CNOT([0, 1])] + tp1 = qml.tape.QuantumTape(fin_ops, [qml.expval(qml.PauliZ(2)), qml.expval(qml.PauliZ(1))]) + tp2 = qml.tape.QuantumTape(fin_ops, [qml.expval(qml.PauliY(2) @ qml.PauliZ(1))]) + fin_batch = batch_type([tp1, tp2]) + + assert all(qml.equal(tapeA, tapeB) for tapeA, tapeB in zip(fin_batch, batch2)) + assert abs(comb_postproc(result, fn1, fn2).item() - 0.5) < num_margin + @pytest.mark.parametrize("valid_transform", valid_transforms) def test_custom_qnode_transform(self, valid_transform): """Test that the custom qnode transform is correctly executed""" diff --git a/tests/transforms/test_split_non_commuting.py b/tests/transforms/test_split_non_commuting.py index 820089fee4f..022e46a8cc9 100644 --- a/tests/transforms/test_split_non_commuting.py +++ b/tests/transforms/test_split_non_commuting.py @@ -293,6 +293,41 @@ def test_mixed_wires_obs_measurement_types(self, meas_type_1, meas_type_2): assert qml.equal(split[1].measurements[0], meas_type_2()) assert qml.equal(split[1].measurements[1], meas_type_2(wires=[0, 1])) + @pytest.mark.parametrize("batch_type", (tuple, list)) + def test_batch_of_tapes(self, batch_type): + """Test that `split_non_commuting` can transform a batch of tapes""" + + # create a batch with two simple tapes + tape1 = qml.tape.QuantumScript( + [qml.RX(1.2, 0)], [qml.expval(qml.X(0)), qml.expval(qml.Y(0)), qml.expval(qml.X(1))] + ) + tape2 = qml.tape.QuantumScript( + [qml.RY(0.5, 0)], [qml.expval(qml.Z(0)), qml.expval(qml.Y(0))] + ) + batch = batch_type([tape1, tape2]) + + # test transform on the batch + new_batch, post_proc_fn = split_non_commuting(batch) + + # test that transform has been applied correctly on the batch by explicitly comparing with splitted tapes + tp1 = qml.tape.QuantumScript([qml.RX(1.2, 0)], [qml.expval(qml.X(0)), qml.expval(qml.X(1))]) + tp2 = qml.tape.QuantumScript([qml.RX(1.2, 0)], [qml.expval(qml.Y(0))]) + tp3 = qml.tape.QuantumScript([qml.RY(0.5, 0)], [qml.expval(qml.Z(0))]) + tp4 = qml.tape.QuantumScript([qml.RY(0.5, 0)], [qml.expval(qml.Y(0))]) + + assert all(qml.equal(tapeA, tapeB) for tapeA, tapeB in zip(new_batch, [tp1, tp2, tp3, tp4])) + + # test postprocessing function applied to the transformed batch + assert all( + qml.equal(tapeA, tapeB) + for sublist1, sublist2 in zip(post_proc_fn(new_batch), ((tp1, tp2), (tp3, tp4))) + for tapeA, tapeB in zip(sublist1, sublist2) + ) + + # final (double) check: test postprocessing function on a fictitious results + result = ("tp1", "tp2", "tp3", "tp4") + assert post_proc_fn(result) == (("tp1", "tp2"), ("tp3", "tp4")) + # measurements that require shots=True required_shot_meas_fn = [qml.sample, qml.counts]