Skip to content

Commit

Permalink
Add QNode config for mid-circuit measurement options (PennyLaneAI#5679)
Browse files Browse the repository at this point in the history
**Context:**
This PR adds qnode arguments to configure mid-circuit measurement
behaviour

**Description of the Change:**
* Added `postselect_mode` and `mcm_method` kwargs to qnode.
* `postselect_mode` is a string, and setting it to `"fill-shots"` will
return all samples regardless of validity. Setting it to `"hw-like"`
will scale shots with postselection. Note that `"hw-like"` with jax will
replace invalid samples with `INTEGER_MIN_VAL` with
`mcm_method="one-shot"`. An error will be raised with
`defer_measurements` with `postselect_mode="hw-like"`` and jax jit.
* `mcm_method` is a string and can be either `"deferred"` or
`"one-shot"`
* Update `qml.devices.preprocess.mid_circuit_measurements` to
accommodate mcm configuration options when deciding which transform to
use.
* Update `QNode._execution_component` to use the
`qml.devices.preprocess.mid_circuit_measurements` transform for old API
devices.
* Update `ExecutionConfig` to include an `mcm_config` and update `QNode`
and `qml.execute` to set this config.
* Added new section with details about the kwargs to the measurements
intro doc.

Note:
When using jax-jit, `postselect_mode="hw-like"` will add dummy values to
the samples with `dynamic_one_shot` and these won't be used for MPs
other than `qml.sample`. However, with `defer_measurements`, an error
will be raised. This is a limitation of the current implementation of
`defer_measurements`, and I've documented it as such

**Benefits:**
Users can easily configure how to apply and process mid-circuit
measurements.

**Possible Drawbacks:**
More kwargs to `QNode` and `qml.execute`.

**Related GitHub Issues:**

---------

Co-authored-by: Christina Lee <[email protected]>
  • Loading branch information
mudit2812 and albi3ro authored Jun 4, 2024
1 parent c14ce7b commit 4ff59f7
Show file tree
Hide file tree
Showing 27 changed files with 763 additions and 203 deletions.
85 changes: 85 additions & 0 deletions doc/introduction/measurements.rst
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ outcome of such mid-circuit measurements:
qml.cond(m_0, qml.RY)(y, wires=0)
return qml.probs(wires=[0])
.. _deferred_measurements:

Deferred measurements
*********************

Expand Down Expand Up @@ -312,6 +314,8 @@ tensor([0.90165331, 0.09834669], requires_grad=True)
and quantum hardware. The one-shot transform below does not have this limitation, but has
computational cost that scales with the number of shots used.

.. _one_shot_transform:

The one-shot transform
**********************

Expand Down Expand Up @@ -524,6 +528,87 @@ Collecting statistics for sequences of mid-circuit measurements is supported wit
When collecting statistics for a list of mid-circuit measurements, values manipulated using
arithmetic operators should not be used as this behaviour is not supported.

Configuring mid-circuit measurements
************************************

As seen above, there are multiple ways in which circuits with mid-circuit measurements can be executed with
PennyLane. For ease of use, we provide the following configuration options to users when initializing a
:class:`~pennylane.QNode`:

* ``mcm_method``: To set the method used for applying mid-circuit measurements. Use ``mcm_method="deferred"``
to use the :ref:`deferred measurements principle <deferred_measurements>` or ``mcm_method="one-shot"`` to use
the :ref:`one-shot transform <one_shot_transform>`. When executing with finite shots, ``mcm_method="one-shot"``
will be the default, and ``mcm_method="deferred"`` otherwise.

.. warning::

If the ``mcm_method`` argument is provided, the :func:`~pennylane.defer_measurements` or
:func:`~pennylane.dynamic_one_shot` transforms must not be applied directly to the :class:`~pennylane.QNode`
as it can lead to incorrect behaviour.

* ``postselect_mode``: To configure how invalid shots are handled when postselecting mid-circuit measurements
with finite-shot circuits. Use ``postselect_mode="hw-like"`` to discard invalid samples. In this case, the number
of samples that are used for processing results can be less than the total number of shots. If
``postselect_mode="fill-shots"`` is used, then the postselected value will be sampled unconditionally, and all
samples will be valid. This is equivalent to sampling until the number of valid samples matches the total number
of shots. The default behaviour is ``postselect_mode="hw-like"``.

.. code-block:: python3
import pennylane as qml
import numpy as np
dev = qml.device("default.qubit", wires=3, shots=10)
def circuit(x):
qml.RX(x, 0)
m0 = qml.measure(0, postselect=1)
qml.CNOT([0, 1])
return qml.sample(qml.PauliZ(0))
fill_shots_qnode = qml.QNode(circuit, dev, mcm_method="one-shot", postselect_mode="fill-shots")
hw_like_qnode = qml.QNode(circuit, dev, mcm_method="one-shot", postselect_mode="hw-like")
>>> fill_shots_qnode(np.pi / 2)
array([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.])
>>> hw_like_qnode(np.pi / 2)
array([-1., -1., -1., -1., -1., -1., -1.])

.. note::

When using the ``jax`` interface, ``postselect_mode="hw-like"`` will have different behaviour based on the
chosen ``mcm_method``.

* If ``mcm_method="one-shot"``, invalid shots will not be discarded. Instead, invalid samples will be replaced
by ``np.iinfo(np.int32).min``. These invalid samples will not be used for processing final results (like
expectation values), but will appear in the ``QNode`` output if samples are requested directly. Consider
the circuit below:

.. code-block:: python3
import pennylane as qml
import jax
import jax.numpy as jnp
dev = qml.device("default.qubit", wires=3, shots=10, seed=jax.random.PRNGKey(123))
@qml.qnode(dev, postselect_mode="hw-like", mcm_method="one-shot")
def circuit(x):
qml.RX(x, 0)
qml.measure(0, postselect=1)
return qml.sample(qml.PauliZ(0))
>>> x = jnp.array(1.8)
>>> f(x)
Array([-2.1474836e+09, -1.0000000e+00, -2.1474836e+09, -2.1474836e+09,
-1.0000000e+00, -2.1474836e+09, -1.0000000e+00, -2.1474836e+09,
-1.0000000e+00, -1.0000000e+00], dtype=float32, weak_type=True)

* When using ``jax.jit``, using ``mcm_method="deferred"`` is not supported with ``postselect_mode="hw-like"`` and
an error will be raised if this configuration is requested. This is due to limitations of the
:func:`~pennylane.defer_measurements` transform, and this behaviour will change in the future to be more
consistent with ``mcm_method="one-shot"``.

Changing the number of shots
----------------------------

Expand Down
15 changes: 14 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,19 @@

<h3>New features since last release</h3>

* `qml.QNode` and `qml.qnode` now accept two new keyword arguments: `postselect_mode` and `mcm_method`.
These keyword arguments can be used to configure how the device should behave when running circuits with
mid-circuit measurements.
[(#5679)](https://github.com/PennyLaneAI/pennylane/pull/5679)

* `postselect_mode="hw-like"` will indicate to devices to discard invalid shots when postselecting
mid-circuit measurements. Use `postselect_mode="fill-shots"` to unconditionally sample the postselected
value, thus making all samples valid. This is equivalent to sampling until the number of valid samples
matches the total number of shots.
* `mcm_method` will indicate which strategy to use for running circuits with mid-circuit measurements.
Use `mcm_method="deferred"` to use the deferred measurements principle, or `mcm_method="one-shot"`
to execute once for each shot.

* The `default.tensor` device is introduced to perform tensor network simulation of a quantum circuit.
[(#5699)](https://github.com/PennyLaneAI/pennylane/pull/5699)

Expand Down Expand Up @@ -40,7 +53,7 @@

* The `dynamic_one_shot` transform can be compiled with `jax.jit`.
[(#5557)](https://github.com/PennyLaneAI/pennylane/pull/5557)

* When using `defer_measurements` with postselecting mid-circuit measurements, operations
that will never be active due to the postselected state are skipped in the transformed
quantum circuit. In addition, postselected controls are skipped, as they are evaluated
Expand Down
7 changes: 5 additions & 2 deletions pennylane/_qubit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def _multi_meas_with_counts_shot_vec(self, circuit: QuantumTape, shot_tuple, r):

return new_r

def batch_execute(self, circuits):
def batch_execute(self, circuits, **kwargs):
"""Execute a batch of quantum circuits on the device.
The circuits are represented by tapes, and they are executed one-by-one using the
Expand All @@ -492,13 +492,16 @@ def batch_execute(self, circuits):
),
)

if self.capabilities().get("supports_mid_measure", False):
kwargs.setdefault("postselect_mode", None)

results = []
for circuit in circuits:
# we need to reset the device here, else it will
# not start the next computation in the zero state
self.reset()

res = self.execute(circuit)
res = self.execute(circuit, **kwargs)
results.append(res)

if self.tracker.active:
Expand Down
5 changes: 4 additions & 1 deletion pennylane/capture/capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
This submodule defines a capture compatible call to QNodes.
"""

from copy import copy
from functools import lru_cache, partial

import pennylane as qml
Expand Down Expand Up @@ -159,7 +160,9 @@ def f(x):
qfunc = partial(qnode.func, **kwargs) if kwargs else qnode.func

qfunc_jaxpr = jax.make_jaxpr(qfunc)(*args)
qnode_kwargs = {"diff_method": qnode.diff_method, **qnode.execute_kwargs}
execute_kwargs = copy(qnode.execute_kwargs)
mcm_config = execute_kwargs.pop("mcm_config")
qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs, **mcm_config}
qnode_prim = _get_qnode_prim()

return qnode_prim.bind(
Expand Down
3 changes: 2 additions & 1 deletion pennylane/devices/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
:toctree: api
ExecutionConfig
MCMConfig
Device
DefaultQubit
NullQubit
Expand Down Expand Up @@ -146,7 +147,7 @@ def execute(self, circuits, execution_config = qml.devices.DefaultExecutionConfi
"""

from .execution_config import ExecutionConfig, DefaultExecutionConfig
from .execution_config import ExecutionConfig, DefaultExecutionConfig, MCMConfig
from .device_api import Device
from .default_qubit import DefaultQubit

Expand Down
17 changes: 15 additions & 2 deletions pennylane/devices/default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,9 @@ def preprocess(
transform_program = TransformProgram()

transform_program.add_transform(validate_device_wires, self.wires, name=self.name)
transform_program.add_transform(mid_circuit_measurements, device=self)
transform_program.add_transform(
mid_circuit_measurements, device=self, mcm_config=config.mcm_config
)
transform_program.add_transform(
decompose,
stopping_condition=stopping_condition,
Expand Down Expand Up @@ -597,14 +599,22 @@ def execute(
"interface": interface,
"state_cache": self._state_cache,
"prng_key": _key,
"postselect_mode": execution_config.mcm_config.postselect_mode,
},
)
for c, _key in zip(circuits, prng_keys)
)

vanilla_circuits = convert_to_numpy_parameters(circuits)[0]
seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits))
simulate_kwargs = [{"rng": _rng, "prng_key": _key} for _rng, _key in zip(seeds, prng_keys)]
simulate_kwargs = [
{
"rng": _rng,
"prng_key": _key,
"postselect_mode": execution_config.mcm_config.postselect_mode,
}
for _rng, _key in zip(seeds, prng_keys)
]

with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
exec_map = executor.map(_simulate_wrapper, vanilla_circuits, simulate_kwargs)
Expand Down Expand Up @@ -848,20 +858,23 @@ def _simulate_wrapper(circuit, kwargs):


def _adjoint_jac_wrapper(c, debugger=None):
c = c.map_to_standard_wires()
state, is_state_batched = get_final_state(c, debugger=debugger)
jac = adjoint_jacobian(c, state=state)
res = measure_final_state(c, state, is_state_batched)
return res, jac


def _adjoint_jvp_wrapper(c, t, debugger=None):
c = c.map_to_standard_wires()
state, is_state_batched = get_final_state(c, debugger=debugger)
jvp = adjoint_jvp(c, t, state=state)
res = measure_final_state(c, state, is_state_batched)
return res, jvp


def _adjoint_vjp_wrapper(c, t, debugger=None):
c = c.map_to_standard_wires()
state, is_state_batched = get_final_state(c, debugger=debugger)
vjp = adjoint_vjp(c, t, state=state)
res = measure_final_state(c, state, is_state_batched)
Expand Down
38 changes: 37 additions & 1 deletion pennylane/devices/execution_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,39 @@
Contains the :class:`ExecutionConfig` data class.
"""
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Union

from pennylane.workflow import SUPPORTED_INTERFACES


@dataclass
class MCMConfig:
"""A class to store mid-circuit measurement configurations."""

mcm_method: Optional[str] = None
"""Which mid-circuit measurement strategy to use. Use ``deferred`` for the deferred
measurements principle and "one-shot" if using finite shots to execute the circuit
for each shot separately. If not specified, the device will decide which method to
use."""

postselect_mode: Optional[str] = None
"""Configuration for handling shots with mid-circuit measurement postselection. If
``"hw-like"``, invalid shots will be discarded and only results for valid shots will
be returned. If ``"fill-shots"``, results corresponding to the original number of
shots will be returned. If not specified, the device will decide which mode to use."""

def __post_init__(self):
"""
Validate the configured mid-circuit measurement options.
Note that this hook is automatically called after init via the dataclass integration.
"""
if self.mcm_method not in ("deferred", "one-shot", None):
raise ValueError(f"Invalid mid-circuit measurements method '{self.mcm_method}'.")
if self.postselect_mode not in ("hw-like", "fill-shots", None):
raise ValueError(f"Invalid postselection mode '{self.postselect_mode}'.")


# pylint: disable=too-many-instance-attributes
@dataclass
class ExecutionConfig:
Expand Down Expand Up @@ -67,6 +95,9 @@ class ExecutionConfig:
derivative_order: int = 1
"""The derivative order to compute while evaluating a gradient"""

mcm_config: Union[MCMConfig, dict] = MCMConfig()
"""Configuration options for handling mid-circuit measurements"""

def __post_init__(self):
"""
Validate the configured execution options.
Expand All @@ -89,5 +120,10 @@ def __post_init__(self):
if self.gradient_keyword_arguments is None:
self.gradient_keyword_arguments = {}

if isinstance(self.mcm_config, dict):
self.mcm_config = MCMConfig(**self.mcm_config)
elif not isinstance(self.mcm_config, MCMConfig):
raise ValueError(f"Got invalid type {type(self.mcm_config)} for 'mcm_config'")


DefaultExecutionConfig = ExecutionConfig()
Loading

0 comments on commit 4ff59f7

Please sign in to comment.