Skip to content

Commit

Permalink
TransformDispatcher can dispatch onto a batch of tapes (PennyLaneAI#5163
Browse files Browse the repository at this point in the history
)

**Context:** A primary goal for the transform project is that transforms
are composable. With the current features, transforms are composable on
the qnode, but they need to be more robust to compose when working with
(batches of) tapes.

**Description of the Change:** A new private method has been added to
the `TransformDispatcher` class to dispatch transforms onto batches of
tapes. Such a method is essentially the copy of `map_batch_transform`.

**Benefits:** The proposed solution would lower cognitive overhead for
developers working with transforms.

**Possible Drawbacks:** The new method is called only if the class
object is a sequence with all instances of `qml.tape.QuantumScript`.
Otherwise, a previously implemented `TransformError` is raised.

**Related GitHub Issues:** None.

---------

Co-authored-by: Christina Lee <[email protected]>
  • Loading branch information
PietropaoloFrisoni and albi3ro committed Feb 14, 2024
1 parent 274d560 commit ab8c499
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 1 deletion.
5 changes: 5 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,10 @@

<h3>Deprecations 👋</h3>

* `TransformDispatcher` can now dispatch onto a batch of tapes, so that it is easier to compose transforms
when working in the tape paradigm.
[(#5163)](https://github.com/PennyLaneAI/pennylane/pull/5163)

* `Operator.validate_subspace(subspace)` has been relocated to the `qml.ops.qutrit.parametric_ops`
module and will be removed from the Operator class in an upcoming release.
[(#5067)](https://github.com/PennyLaneAI/pennylane/pull/5067)
Expand Down Expand Up @@ -517,6 +521,7 @@ Skylar Chan,
Isaac De Vlugt,
Diksha Dhawan,
Lillian Frederiksen,
Pietropaolo Frisoni,
Eugenio Gigante,
Diego Guala,
David Ittah,
Expand Down
2 changes: 1 addition & 1 deletion pennylane/transforms/batch_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ def map_batch_transform(
tape_counts.append(len(new_tapes))

def processing_fn(res: ResultBatch) -> ResultBatch:
"""Applies a batch of post-procesing functions to results.
"""Applies a batch of post-processing functions to results.
Args:
res (ResultBatch): the results of executing a batch of circuits
Expand Down
47 changes: 47 additions & 0 deletions pennylane/transforms/core/transform_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
import warnings
import types

from typing import Sequence

import pennylane as qml
from pennylane.typing import ResultBatch


class TransformError(Exception):
Expand Down Expand Up @@ -121,6 +124,8 @@ def processing_fn(results):
return self._device_transform(obj, targs, tkwargs)
if callable(obj):
return self._qfunc_transform(obj, targs, tkwargs)
if isinstance(obj, Sequence) and all(isinstance(q, qml.tape.QuantumScript) for q in obj):
return self._batch_transform(obj, targs, tkwargs)

# Input is not a QNode nor a quantum tape nor a device.
# Assume Python decorator syntax:
Expand Down Expand Up @@ -313,6 +318,48 @@ def original_device(self):

return TransformedDevice(original_device, self._transform)

def _batch_transform(self, original_batch, targs, tkwargs):
"""Apply the transform on a batch of tapes"""
execution_tapes = []
batch_fns = []
tape_counts = []

for t in original_batch:
# Preprocess the tapes by applying batch transforms
# to each tape, and storing corresponding tapes
# for execution, processing functions, and list of tape lengths.
new_tapes, fn = self(t, *targs, **tkwargs)
execution_tapes.extend(new_tapes)
batch_fns.append(fn)
tape_counts.append(len(new_tapes))

def processing_fn(res: ResultBatch) -> ResultBatch:
"""Applies a batch of post-processing functions to results.
Args:
res (ResultBatch): the results of executing a batch of circuits
Returns:
ResultBatch : results that have undergone classical post processing
Closure variables:
tape_counts: the number of tapes outputted from each application of the transform
batch_fns: the post processing functions to apply to each sub-batch
"""
count = 0
final_results = []

for f, s in zip(batch_fns, tape_counts):
# apply any batch transform post-processing
new_res = f(res[count : count + s])
final_results.append(new_res)
count += s

return tuple(final_results)

return tuple(execution_tapes), processing_fn


class TransformContainer:
"""Class to store a quantum transform with its ``args``, ``kwargs`` and classical co-transforms. Use
Expand Down
61 changes: 61 additions & 0 deletions tests/transforms/test_experimental/test_transform_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pytest
import pennylane as qml
from pennylane.transforms.core import transform, TransformError, TransformContainer
from pennylane.typing import TensorLike

dev = qml.device("default.qubit", wires=2)

Expand Down Expand Up @@ -470,6 +471,66 @@ def test_dispatched_transform_attribute(self):
assert dispatched_transform.expand_transform is None
assert dispatched_transform.classical_cotransform is None

@pytest.mark.parametrize("valid_transform", valid_transforms)
@pytest.mark.parametrize("batch_type", (tuple, list))
def test_batch_transform(self, valid_transform, batch_type, num_margin=1e-8):
"""Test that dispatcher can dispatch onto a batch of tapes."""

def check_batch(batch):
return isinstance(batch, Sequence) and all(
isinstance(tape, qml.tape.QuantumScript) for tape in batch
)

def comb_postproc(results: TensorLike, fn1: Callable, fn2: Callable):
return fn1(fn2(results))

# Create a simple device and tape
tmp_dev = qml.device("default.qubit", wires=3)
H = qml.PauliY(2) @ qml.PauliZ(1) + 0.5 * qml.PauliZ(2) + qml.PauliZ(1)
measur = [qml.expval(H)]
ops = [qml.Hadamard(0), qml.RX(0.2, 0), qml.RX(0.6, 0), qml.CNOT((0, 1))]
tape = qml.tape.QuantumTape(ops, measur)

############################################################
### Test with two elementary user-defined transforms
############################################################

dispatched_transform1 = transform(valid_transform)
dispatched_transform2 = transform(valid_transform)

batch1, fn1 = dispatched_transform1(tape, index=0)
assert check_batch(batch1)

batch2, fn2 = dispatched_transform2(batch1, index=0)
assert check_batch(batch2)

result = tmp_dev.execute(batch2)
assert isinstance(result, TensorLike)

############################################################
### Test with two `concrete` transforms
############################################################

tape = qml.tape.QuantumTape(ops, measur)

batch1, fn1 = qml.transforms.hamiltonian_expand(tape)
assert check_batch(batch1)

batch2, fn2 = qml.transforms.merge_rotations(batch1)
assert check_batch(batch2)

result = tmp_dev.execute(batch2)
assert isinstance(result, TensorLike)

# check that final batch and post-processing functions are what we expect after the two transforms
fin_ops = [qml.Hadamard(0), qml.RX(0.8, 0), qml.CNOT([0, 1])]
tp1 = qml.tape.QuantumTape(fin_ops, [qml.expval(qml.PauliZ(2)), qml.expval(qml.PauliZ(1))])
tp2 = qml.tape.QuantumTape(fin_ops, [qml.expval(qml.PauliY(2) @ qml.PauliZ(1))])
fin_batch = batch_type([tp1, tp2])

assert all(qml.equal(tapeA, tapeB) for tapeA, tapeB in zip(fin_batch, batch2))
assert abs(comb_postproc(result, fn1, fn2).item() - 0.5) < num_margin

@pytest.mark.parametrize("valid_transform", valid_transforms)
def test_custom_qnode_transform(self, valid_transform):
"""Test that the custom qnode transform is correctly executed"""
Expand Down
35 changes: 35 additions & 0 deletions tests/transforms/test_split_non_commuting.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,41 @@ def test_mixed_wires_obs_measurement_types(self, meas_type_1, meas_type_2):
assert qml.equal(split[1].measurements[0], meas_type_2())
assert qml.equal(split[1].measurements[1], meas_type_2(wires=[0, 1]))

@pytest.mark.parametrize("batch_type", (tuple, list))
def test_batch_of_tapes(self, batch_type):
"""Test that `split_non_commuting` can transform a batch of tapes"""

# create a batch with two simple tapes
tape1 = qml.tape.QuantumScript(
[qml.RX(1.2, 0)], [qml.expval(qml.X(0)), qml.expval(qml.Y(0)), qml.expval(qml.X(1))]
)
tape2 = qml.tape.QuantumScript(
[qml.RY(0.5, 0)], [qml.expval(qml.Z(0)), qml.expval(qml.Y(0))]
)
batch = batch_type([tape1, tape2])

# test transform on the batch
new_batch, post_proc_fn = split_non_commuting(batch)

# test that transform has been applied correctly on the batch by explicitly comparing with splitted tapes
tp1 = qml.tape.QuantumScript([qml.RX(1.2, 0)], [qml.expval(qml.X(0)), qml.expval(qml.X(1))])
tp2 = qml.tape.QuantumScript([qml.RX(1.2, 0)], [qml.expval(qml.Y(0))])
tp3 = qml.tape.QuantumScript([qml.RY(0.5, 0)], [qml.expval(qml.Z(0))])
tp4 = qml.tape.QuantumScript([qml.RY(0.5, 0)], [qml.expval(qml.Y(0))])

assert all(qml.equal(tapeA, tapeB) for tapeA, tapeB in zip(new_batch, [tp1, tp2, tp3, tp4]))

# test postprocessing function applied to the transformed batch
assert all(
qml.equal(tapeA, tapeB)
for sublist1, sublist2 in zip(post_proc_fn(new_batch), ((tp1, tp2), (tp3, tp4)))
for tapeA, tapeB in zip(sublist1, sublist2)
)

# final (double) check: test postprocessing function on a fictitious results
result = ("tp1", "tp2", "tp3", "tp4")
assert post_proc_fn(result) == (("tp1", "tp2"), ("tp3", "tp4"))


# measurements that require shots=True
required_shot_meas_fn = [qml.sample, qml.counts]
Expand Down

0 comments on commit ab8c499

Please sign in to comment.