Skip to content

Commit

Permalink
Automatically test plxpr capture integration with assert_valid (Pen…
Browse files Browse the repository at this point in the history
…nyLaneAI#5686)

**Context:**

Following on from PennyLaneAI#5523, we need a good way of ensuring operations will
integrate well with program capture.

**Description of the Change:**

This PR adds a check to the `assert_valid` function to ensure capture is
working as expected.

It then starts to fix up some operations that were failing to be
captured correctly.

**Benefits:**

More robust integration with pl capture.

**Possible Drawbacks:**

I'm realizing a lot more operators fail to be immediately captured than
I initially assumed 😢

**Related GitHub Issues:**

[sc-63310]

---------

Co-authored-by: dwierichs <[email protected]>
Co-authored-by: Thomas R. Bromley <[email protected]>
  • Loading branch information
3 people authored Jun 4, 2024
1 parent 4ff59f7 commit ddfb019
Show file tree
Hide file tree
Showing 13 changed files with 204 additions and 26 deletions.
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@
[(#5511)](https://github.com/PennyLaneAI/pennylane/pull/5511)
[(#5708)](https://github.com/PennyLaneAI/pennylane/pull/5708)
[(#5523)](https://github.com/PennyLaneAI/pennylane/pull/5523)
[(#5686)](https://github.com/PennyLaneAI/pennylane/pull/5686)

* The `decompose` transform has an `error` kwarg to specify the type of error that should be raised,
allowing error types to be more consistent with the context the `decompose` function is used in.
Expand Down
5 changes: 4 additions & 1 deletion pennylane/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,6 @@ def _primitive_bind_call(cls, *args, **kwargs):
if cls._primitive is None:
# guard against this being called when primitive is not defined.
return type.__call__(cls, *args, **kwargs)

iterable_wires_types = (list, tuple, qml.wires.Wires, range, set)

# process wires so that we can handle them either as a final argument or as a keyword argument.
Expand Down Expand Up @@ -2092,6 +2091,10 @@ def _flatten(self):
def _unflatten(cls, data, _):
return cls(*data)

@classmethod
def _primitive_bind_call(cls, *args, **kwargs):
return cls._primitive.bind(*args)

def __init__(self, *args): # pylint: disable=super-init-not-called
self._eigvals_cache = None
self.obs: List[Observable] = []
Expand Down
17 changes: 17 additions & 0 deletions pennylane/ops/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,11 @@ def __init__(self, K_list, wires=None, id=None):
def _flatten(self):
return (self.data,), (self.wires, ())

# pylint: disable=arguments-differ, unused-argument
@classmethod
def _primitive_bind_call(cls, K_list, wires=None, id=None):
return super()._primitive_bind_call(*K_list, wires=wires)

@staticmethod
def compute_kraus_matrices(*kraus_matrices): # pylint:disable=arguments-differ
"""Kraus matrices representing the QubitChannel channel.
Expand All @@ -762,6 +767,18 @@ def compute_kraus_matrices(*kraus_matrices): # pylint:disable=arguments-differ
return list(kraus_matrices)


# The primitive will be None if jax is not installed in the environment
# If defined, we need to update the implementation to repack matrices
# See capture module for more information
if QubitChannel._primitive is not None: # pylint: disable=protected-access

@QubitChannel._primitive.def_impl # pylint: disable=protected-access
def _(*args, n_wires):
K_list = args[:-n_wires]
wires = args[-n_wires:]
return type.__call__(QubitChannel, K_list, wires=wires)


class ThermalRelaxationError(Channel):
r"""
Thermal relaxation error channel.
Expand Down
29 changes: 29 additions & 0 deletions pennylane/ops/functions/assert_valid.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,34 @@ def _check_pytree(op):
), f"data must be the terminal leaves of the pytree. Got {d1}, {d2}"


def _check_capture(op):
try:
import jax
except ImportError:
return

if not all(isinstance(w, int) for w in op.wires):
return

qml.capture.enable()
try:
jaxpr = jax.make_jaxpr(lambda obj: obj)(op)
data, _ = jax.tree_util.tree_flatten(op)
new_op = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *data)[0]
assert op == new_op
except Exception as e:
raise ValueError(
(
"The capture of the operation into jaxpr failed somehow."
" This capture mechanism is currently experimental and not a core"
" requirement, but will be necessary in the future."
" Please see the capture module documentation for more information."
)
) from e
finally:
qml.capture.disable()


def _check_pickle(op):
"""Check that an operation can be dumped and reloaded with pickle."""
pickled = pickle.dumps(op)
Expand Down Expand Up @@ -303,3 +331,4 @@ def __init__(self, wires):
_check_matrix(op)
_check_matrix_matches_decomp(op)
_check_eigendecomposition(op)
_check_capture(op)
21 changes: 19 additions & 2 deletions pennylane/ops/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,12 @@ def circuit():
num_params = 0
grad_method = None

@classmethod
def _primitive_bind_call(cls, tag=None, measurement=None):
if measurement is None:
return cls._primitive.bind(measurement=measurement, tag=tag)
return cls._primitive.bind(measurement, tag=tag)

def __init__(self, tag=None, measurement=None):
self.tag = tag
if measurement:
Expand All @@ -212,17 +218,18 @@ def __init__(self, tag=None, measurement=None):
f"an instance of {qml.measurements.StateMeasurement}"
)
self.hyperparameters["measurement"] = measurement
self.hyperparameters["tag"] = tag
super().__init__(wires=[])

def label(self, decimals=None, base_label=None, cache=None):
return "|Snap|"

def _flatten(self):
return (), (self.tag, self.hyperparameters["measurement"])
return (self.hyperparameters["measurement"],), (self.tag,)

@classmethod
def _unflatten(cls, data, metadata):
return cls(tag=metadata[0], measurement=metadata[1])
return cls(tag=metadata[0], measurement=data[0])

# pylint: disable=W0613
@staticmethod
Expand All @@ -234,3 +241,13 @@ def _controlled(self, _):

def adjoint(self):
return Snapshot(tag=self.tag)


# Since measurements are captured as variables in plxpr with the capture module,
# the measurement is treated as a traceable argument.
# This step is mandatory for fixing the order of arguments overwritten by ``Snapshot._primitive_bind_call``.
if Snapshot._primitive: # pylint: disable=protected-access

@Snapshot._primitive.def_impl # pylint: disable=protected-access
def _(measurement, tag=None):
return type.__call__(Snapshot, tag=tag, measurement=measurement)
93 changes: 77 additions & 16 deletions pennylane/ops/op_math/controlled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""
This submodule contains controlled operators based on the ControlledOp class.
"""
# pylint: disable=no-value-for-parameter
# pylint: disable=no-value-for-parameter, arguments-differ, arguments-renamed
import warnings
from functools import lru_cache
from typing import Iterable
Expand Down Expand Up @@ -210,6 +210,10 @@ def _flatten(self):
def _unflatten(cls, data, metadata):
return cls(metadata[0])

@classmethod
def _primitive_bind_call(cls, wires, id=None):
return cls._primitive.bind(*wires, n_wires=2)

def __init__(self, wires, id=None):
control_wires = wires[:1]
target_wires = wires[1:]
Expand Down Expand Up @@ -320,6 +324,10 @@ def _flatten(self):
def _unflatten(cls, data, metadata):
return cls(metadata[0])

@classmethod
def _primitive_bind_call(cls, wires, id=None):
return cls._primitive.bind(*wires, n_wires=2)

def __init__(self, wires, id=None):
control_wire, wire = wires
super().__init__(qml.Y(wire), control_wire, id=id)
Expand Down Expand Up @@ -423,6 +431,10 @@ def _flatten(self):
def _unflatten(cls, data, metadata):
return cls(metadata[0])

@classmethod
def _primitive_bind_call(cls, wires, id=None):
return cls._primitive.bind(*wires, n_wires=2)

def __init__(self, wires, id=None):
control_wire, wire = wires
super().__init__(qml.Z(wires=wire), control_wire, id=id)
Expand Down Expand Up @@ -505,6 +517,10 @@ def _flatten(self):
def _unflatten(cls, data, metadata):
return cls(metadata[0])

@classmethod
def _primitive_bind_call(cls, wires, id=None):
return cls._primitive.bind(*wires, n_wires=3)

def __init__(self, wires, id=None):
control_wires = wires[:1]
target_wires = wires[1:]
Expand Down Expand Up @@ -607,6 +623,10 @@ class CCZ(ControlledOp):
wires (Sequence[int]): the subsystem the gate acts on
"""

@classmethod
def _primitive_bind_call(cls, wires, id=None):
return cls._primitive.bind(*wires, n_wires=3)

def _flatten(self):
return tuple(), (self.wires,)

Expand Down Expand Up @@ -768,6 +788,10 @@ def _flatten(self):
def _unflatten(cls, data, metadata):
return cls(metadata[0])

@classmethod
def _primitive_bind_call(cls, wires, id=None):
return cls._primitive.bind(*wires, n_wires=2)

def __init__(self, wires, id=None):
control_wire, wire = wires
super().__init__(qml.PauliX(wires=wire), control_wire, id=id)
Expand Down Expand Up @@ -850,6 +874,10 @@ def _flatten(self):
def _unflatten(cls, _, metadata):
return cls(metadata[0])

@classmethod
def _primitive_bind_call(cls, wires, id=None):
return cls._primitive.bind(*wires, n_wires=3)

def __init__(self, wires, id=None):
control_wires = wires[:2]
target_wires = wires[2:]
Expand Down Expand Up @@ -1035,6 +1063,13 @@ def _flatten(self):
def _unflatten(cls, _, metadata):
return cls(wires=metadata[0], control_values=metadata[1], work_wires=metadata[2])

# pylint: disable=arguments-differ
@classmethod
def _primitive_bind_call(cls, wires, control_values=None, work_wires=None, id=None):
return cls._primitive.bind(
*wires, n_wires=len(wires), control_values=control_values, work_wires=work_wires
)

# pylint: disable=too-many-arguments
def __init__(self, control_wires=None, wires=None, control_values=None, work_wires=None):

Expand Down Expand Up @@ -1250,11 +1285,16 @@ def __init__(self, phi, wires, id=None):
def __repr__(self):
return f"CRX({self.data[0]}, wires={self.wires.tolist()})"

def _flatten(self):
return self.data, (self.wires,)

@classmethod
def _unflatten(cls, data, metadata):
base = data[0]
control_wires = metadata[0]
return cls(*base.data, wires=control_wires + base.wires)
return cls(*data, wires=metadata[0])

@classmethod
def _primitive_bind_call(cls, phi, wires, id=None):
return cls._primitive.bind(phi, *wires, n_wires=len(wires))

@staticmethod
def compute_matrix(theta): # pylint: disable=arguments-differ
Expand Down Expand Up @@ -1398,11 +1438,16 @@ def __init__(self, phi, wires, id=None):
def __repr__(self):
return f"CRY({self.data[0]}, wires={self.wires.tolist()}))"

def _flatten(self):
return self.data, (self.wires,)

@classmethod
def _unflatten(cls, data, metadata):
base = data[0]
control_wires = metadata[0]
return cls(*base.data, wires=control_wires + base.wires)
return cls(*data, wires=metadata[0])

@classmethod
def _primitive_bind_call(cls, phi, wires, id=None):
return cls._primitive.bind(phi, *wires, n_wires=len(wires))

@staticmethod
def compute_matrix(theta): # pylint: disable=arguments-differ
Expand Down Expand Up @@ -1546,11 +1591,16 @@ def __init__(self, phi, wires, id=None):
def __repr__(self):
return f"CRZ({self.data[0]}, wires={self.wires})"

def _flatten(self):
return self.data, (self.wires,)

@classmethod
def _unflatten(cls, data, metadata):
base = data[0]
control_wires = metadata[0]
return cls(*base.data, wires=control_wires + base.wires)
return cls(*data, wires=metadata[0])

@classmethod
def _primitive_bind_call(cls, phi, wires, id=None):
return cls._primitive.bind(phi, *wires, n_wires=len(wires))

@staticmethod
def compute_matrix(theta): # pylint: disable=arguments-differ
Expand Down Expand Up @@ -1727,11 +1777,17 @@ def __repr__(self):
params = ", ".join([repr(p) for p in self.parameters])
return f"CRot({params}, wires={self.wires})"

def _flatten(self):
return self.data, (self.wires,)

@classmethod
def _unflatten(cls, data, metadata):
base = data[0]
control_wires = metadata[0]
return cls(*base.data, wires=control_wires + base.wires)
return cls(*data, wires=metadata[0])

# pylint: disable=too-many-arguments
@classmethod
def _primitive_bind_call(cls, phi, theta, omega, wires, id=None):
return cls._primitive.bind(phi, theta, omega, *wires, n_wires=len(wires))

@staticmethod
def compute_matrix(phi, theta, omega): # pylint: disable=arguments-differ
Expand Down Expand Up @@ -1886,11 +1942,16 @@ def __init__(self, phi, wires, id=None):
def __repr__(self):
return f"ControlledPhaseShift({self.data[0]}, wires={self.wires})"

def _flatten(self):
return self.data, (self.wires,)

@classmethod
def _unflatten(cls, data, metadata):
base = data[0]
control_wires = metadata[0]
return cls(*base.data, wires=control_wires + base.wires)
return cls(*data, wires=metadata[0])

@classmethod
def _primitive_bind_call(cls, phi, wires, id=None):
return cls._primitive.bind(phi, *wires, n_wires=len(wires))

@staticmethod
def compute_matrix(phi): # pylint: disable=arguments-differ
Expand Down
14 changes: 14 additions & 0 deletions pennylane/ops/op_math/linear_combination.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ def _flatten(self):
def _unflatten(cls, data, metadata):
return cls(data[0], data[1], _grouping_indices=metadata[0])

# pylint: disable=arguments-differ
@classmethod
def _primitive_bind_call(cls, coeffs, observables, _pauli_rep=None, **kwargs):
return cls._primitive.bind(*coeffs, *observables, **kwargs, n_obs=len(observables))

def __init__(
self,
coeffs,
Expand Down Expand Up @@ -568,3 +573,12 @@ def map_wires(self, wire_map: dict):
new_op = LinearCombination(coeffs, new_ops)
new_op.grouping_indices = self._grouping_indices
return new_op


if LinearCombination._primitive is not None:

@LinearCombination._primitive.def_impl
def _(*args, n_obs, **kwargs):
coeffs = args[:n_obs]
observables = args[n_obs:]
return type.__call__(LinearCombination, coeffs, observables, **kwargs)
Loading

0 comments on commit ddfb019

Please sign in to comment.