Skip to content

Commit

Permalink
Add generator checks and differentiability checks to assert_valid (#6282
Browse files Browse the repository at this point in the history
)

The purpose of this story is to check if every operator can be
differentiated properly with the parameter shift method. The gradient
calculated using parameter shift is compared with backprop to verify
that they are equal.

There are some cases where this does not apply:
1. Backprop does not produce the correct gradient in some cases such as
for `StatePrep` and for `QubitUnitary`, as it does not take into account
the constraint that the matrix must remain unitary, making this test
invalid.
2. Some operator takes integers as parameters, such as `BasisState`. In
this case, it does not make sense to take the gradient of integer
parameters.

For these cases, a `skip_differentiation` toggle is added to
`assert_valid` such that the differentiation check is skipped for these
operators.

Three bugs are found as a result of adding this check. The relevant
tests are xfailed:
#6331
#6333
#6340

Some other minor bug fixes are also included in this PR.

[sc-65197]
Fixes #6311
  • Loading branch information
astralcai authored and austingmhuang committed Oct 23, 2024
1 parent cd7d941 commit ca7eb65
Show file tree
Hide file tree
Showing 26 changed files with 164 additions and 60 deletions.
2 changes: 1 addition & 1 deletion pennylane/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1461,7 +1461,7 @@ def generator(self): # pylint: disable=no-self-use
0.5 * Y(0) + Z(0) @ X(1)
The generator may also be provided in the form of a dense or sparse Hamiltonian
(using :class:`.Hermitian` and :class:`.SparseHamiltonian` respectively).
(using :class:`.Hamiltonian` and :class:`.SparseHamiltonian` respectively).
The default value to return is ``None``, indicating that the operation has
no defined generator.
Expand Down
64 changes: 63 additions & 1 deletion pennylane/ops/functions/assert_valid.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,25 @@ def _check_eigendecomposition(op):
assert qml.math.allclose(decomp_mat, original_mat), failure_comment


def _check_generator(op):
"""Checks that if an operator's has_generator property is True, it has a generator."""

if op.has_generator:
gen = op.generator()
assert isinstance(gen, qml.operation.Operator)
new_op = qml.exp(gen, 1j * op.data[0])
assert qml.math.allclose(
qml.matrix(op, wire_order=op.wires), qml.matrix(new_op, wire_order=op.wires)
)
else:
failure_comment = (
"If has_generator is False, the matrix method must raise a ``GeneratorUndefinedError``."
)
_assert_error_raised(
op.generator, qml.operation.GeneratorUndefinedError, failure_comment=failure_comment
)()


def _check_copy(op):
"""Check that copies and deep copies give identical objects."""
copied_op = copy.copy(op)
Expand Down Expand Up @@ -276,6 +295,39 @@ def _check_bind_new_parameters(op):
assert qml.math.allclose(d1, d2), failure_comment


def _check_differentiation(op):
"""Checks that the operator can be executed and differentiated correctly."""

if op.num_params == 0:
return

data, struct = qml.pytrees.flatten(op)

def circuit(*args):
qml.apply(qml.pytrees.unflatten(args, struct))
return qml.probs(wires=op.wires)

qnode_ref = qml.QNode(circuit, qml.device("default.qubit"), diff_method="backprop")
qnode_ps = qml.QNode(circuit, qml.device("default.qubit"), diff_method="parameter-shift")

params = [x if isinstance(x, int) else qml.numpy.array(x) for x in data]

ps = qml.jacobian(qnode_ps)(*params)
expected_bp = qml.jacobian(qnode_ref)(*params)

error_msg = (
"Parameter-shift does not produce the same Jacobian as with backpropagation. "
"This might be a bug, or it might be expected due to the mathematical nature "
"of backpropagation, in which case, this test can be skipped for this operator."
)

if isinstance(ps, tuple):
for actual, expected in zip(ps, expected_bp):
assert qml.math.allclose(actual, expected), error_msg
else:
assert qml.math.allclose(ps, expected_bp), error_msg


def _check_wires(op, skip_wire_mapping):
"""Check that wires are a ``Wires`` class and can be mapped."""
assert isinstance(op.wires, qml.wires.Wires), "wires must be a wires instance"
Expand All @@ -288,7 +340,12 @@ def _check_wires(op, skip_wire_mapping):
assert mapped_op.wires == new_wires, "wires must be mappable with map_wires"


def assert_valid(op: qml.operation.Operator, skip_pickle=False, skip_wire_mapping=False) -> None:
def assert_valid(
op: qml.operation.Operator,
skip_pickle=False,
skip_wire_mapping=False,
skip_differentiation=False,
) -> None:
"""Runs basic validation checks on an :class:`~.operation.Operator` to make
sure it has been correctly defined.
Expand All @@ -298,6 +355,8 @@ def assert_valid(op: qml.operation.Operator, skip_pickle=False, skip_wire_mappin
Keyword Args:
skip_pickle=False : If ``True``, pickling tests are not run. Set to ``True`` when
testing a locally defined operator, as pickle cannot handle local objects
skip_differentiation: If ``True``, differentiation tests are not run. Set to `True` when
the operator is parametrized but not differentiable.
**Examples:**
Expand Down Expand Up @@ -352,4 +411,7 @@ def __init__(self, wires):
_check_matrix_matches_decomp(op)
_check_sparse_matrix(op)
_check_eigendecomposition(op)
_check_generator(op)
if not skip_differentiation:
_check_differentiation(op)
_check_capture(op)
2 changes: 1 addition & 1 deletion pennylane/ops/op_math/evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def generator(self):
f"The operator coefficient {self.coeff} is not imaginary; the expected format is exp(-ixG)."
f"The generator is not defined."
)
return self.base
return -1 * self.base

def __copy__(self):
copied = super().__copy__()
Expand Down
1 change: 0 additions & 1 deletion pennylane/ops/op_math/exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ def __init__(self, base, coeff=1, num_steps=None, id=None):
super().__init__(base, scalar=coeff, id=id)
self.grad_recipe = [None]
self.num_steps = num_steps

self.hyperparameters["num_steps"] = num_steps

def __repr__(self):
Expand Down
3 changes: 2 additions & 1 deletion pennylane/ops/op_math/prod.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def circuit(weights):

_op_symbol = "@"
_math_op = math.prod
grad_method = None

@property
def is_hermitian(self):
Expand Down Expand Up @@ -359,7 +360,7 @@ def arithmetic_depth(self) -> int:
def _build_pauli_rep(self):
"""PauliSentence representation of the Product of operations."""
if all(operand_pauli_reps := [op.pauli_rep for op in self.operands]):
return reduce(lambda a, b: a @ b, operand_pauli_reps)
return reduce(lambda a, b: a @ b, operand_pauli_reps) if operand_pauli_reps else None
return None

def _simplify_factors(self, factors: tuple[Operator]) -> tuple[complex, Operator]:
Expand Down
1 change: 1 addition & 0 deletions pennylane/templates/subroutines/prepselprep.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
def _get_new_terms(lcu):
"""Compute a new sum of unitaries with positive coefficients"""
coeffs, ops = lcu.terms()
coeffs = qml.math.stack(coeffs)
angles = qml.math.angle(coeffs)
new_ops = []

Expand Down
79 changes: 46 additions & 33 deletions tests/ops/functions/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,43 +35,55 @@
from pennylane.ops.op_math.pow import PowObs, PowOperation, PowOpObs

_INSTANCES_TO_TEST = [
qml.sum(qml.PauliX(0), qml.PauliZ(0)),
qml.sum(qml.X(0), qml.X(0), qml.Z(0), qml.Z(0)),
qml.BasisState([1], wires=[0]),
qml.ControlledQubitUnitary(np.eye(2), control_wires=1, wires=0),
qml.QubitChannel([np.array([[1, 0], [0, 0.8]]), np.array([[0, 0.6], [0, 0]])], wires=0),
qml.MultiControlledX(wires=[0, 1]),
qml.Projector([1], 0), # the state-vector version is already tested
qml.SpecialUnitary([1, 1, 1], 0),
qml.IntegerComparator(1, wires=[0, 1]),
qml.PauliRot(1.1, "X", wires=[0]),
qml.StatePrep([0, 1], 0),
qml.PCPhase(0.27, dim=2, wires=[0, 1]),
qml.BlockEncode([[0.1, 0.2], [0.3, 0.4]], wires=[0, 1]),
qml.adjoint(qml.PauliX(0)),
qml.adjoint(qml.RX(1.1, 0)),
Tensor(qml.PauliX(0), qml.PauliX(1)),
qml.ops.LinearCombination([1.1, 2.2], [qml.PauliX(0), qml.PauliZ(0)]),
qml.s_prod(1.1, qml.RX(1.1, 0)),
qml.prod(qml.PauliX(0), qml.PauliY(1), qml.PauliZ(0)),
qml.ctrl(qml.RX(1.1, 0), 1),
qml.exp(qml.PauliX(0), 1.1),
qml.pow(qml.IsingXX(1.1, [0, 1]), 2.5),
qml.ops.Evolution(qml.PauliX(0), 5.2),
qml.QutritBasisState([1, 2, 0], wires=[0, 1, 2]),
qml.resource.FirstQuantization(1, 2, 1),
qml.prod(qml.RX(1.1, 0), qml.RY(2.2, 0), qml.RZ(3.3, 1)),
qml.Snapshot(measurement=qml.expval(qml.Z(0)), tag="hi"),
qml.Snapshot(tag="tag"),
(qml.sum(qml.PauliX(0), qml.PauliZ(0)), {}),
(qml.sum(qml.X(0), qml.X(0), qml.Z(0), qml.Z(0)), {}),
(qml.BasisState([1], wires=[0]), {"skip_differentiation": True}),
(
qml.ControlledQubitUnitary(np.eye(2), control_wires=1, wires=0),
{"skip_differentiation": True},
),
(
qml.QubitChannel([np.array([[1, 0], [0, 0.8]]), np.array([[0, 0.6], [0, 0]])], wires=0),
{"skip_differentiation": True},
),
(qml.MultiControlledX(wires=[0, 1]), {}),
(qml.Projector([1], 0), {"skip_differentiation": True}),
(qml.Projector([1, 0], 0), {"skip_differentiation": True}),
(qml.DiagonalQubitUnitary([1, 1, 1, 1], wires=[0, 1]), {"skip_differentiation": True}),
(qml.QubitUnitary(np.eye(2), wires=[0]), {"skip_differentiation": True}),
(qml.SpecialUnitary([1, 1, 1], 0), {"skip_differentiation": True}),
(qml.IntegerComparator(1, wires=[0, 1]), {"skip_differentiation": True}),
(qml.PauliRot(1.1, "X", wires=[0]), {}),
(qml.StatePrep([0, 1], 0), {"skip_differentiation": True}),
(qml.PCPhase(0.27, dim=2, wires=[0, 1]), {}),
(qml.BlockEncode([[0.1, 0.2], [0.3, 0.4]], wires=[0, 1]), {"skip_differentiation": True}),
(qml.adjoint(qml.PauliX(0)), {}),
(qml.adjoint(qml.RX(1.1, 0)), {}),
(Tensor(qml.PauliX(0), qml.PauliX(1)), {}),
(qml.ops.LinearCombination([1.1, 2.2], [qml.PauliX(0), qml.PauliZ(0)]), {}),
(qml.s_prod(1.1, qml.RX(1.1, 0)), {}),
(qml.prod(qml.PauliX(0), qml.PauliY(1), qml.PauliZ(0)), {}),
(qml.ctrl(qml.RX(1.1, 0), 1), {}),
(qml.exp(qml.PauliX(0), 1.1), {}),
(qml.pow(qml.IsingXX(1.1, [0, 1]), 2.5), {}),
(qml.ops.Evolution(qml.PauliX(0), 5.2), {}),
(qml.QutritBasisState([1, 2, 0], wires=[0, 1, 2]), {"skip_differentiation": True}),
(qml.resource.FirstQuantization(1, 2, 1), {}),
(qml.prod(qml.RX(1.1, 0), qml.RY(2.2, 0), qml.RZ(3.3, 1)), {}),
(qml.Snapshot(measurement=qml.expval(qml.Z(0)), tag="hi"), {}),
(qml.Snapshot(tag="tag"), {}),
]
"""Valid operator instances that could not be auto-generated."""

with warnings.catch_warnings():
warnings.filterwarnings("ignore", "qml.ops.Hamiltonian uses", qml.PennyLaneDeprecationWarning)
_INSTANCES_TO_TEST.append(
qml.operation.convert_to_legacy_H(
qml.Hamiltonian([1.1, 2.2], [qml.PauliX(0), qml.PauliZ(0)])
),
(
qml.operation.convert_to_legacy_H(
qml.Hamiltonian([1.1, 2.2], [qml.PauliX(0), qml.PauliZ(0)])
),
{},
)
)


Expand Down Expand Up @@ -130,6 +142,7 @@
Operation,
Observable,
Channel,
qml.ops.Projector,
qml.ops.SymbolicOp,
qml.ops.ScalarSymbolicOp,
qml.ops.Pow,
Expand Down Expand Up @@ -164,7 +177,7 @@ def get_all_classes(c):
_CLASSES_TO_TEST = (
set(get_all_classes(Operator))
- {i[1] for i in getmembers(qml.templates) if isclass(i[1]) and issubclass(i[1], Operator)}
- {type(op) for op in _INSTANCES_TO_TEST}
- {type(op) for (op, _) in _INSTANCES_TO_TEST}
- {type(op) for (op, _) in _INSTANCES_TO_FAIL}
)
"""All operators, except those tested manually, abstract/meta classes, and templates."""
Expand All @@ -176,7 +189,7 @@ def class_to_validate(request):


@pytest.fixture(params=_INSTANCES_TO_TEST)
def valid_instance(request):
def valid_instance_and_kwargs(request):
yield request.param


Expand Down
30 changes: 27 additions & 3 deletions tests/ops/functions/test_assert_valid.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import pennylane as qml
from pennylane.operation import Operator
from pennylane.ops.functions import assert_valid
from pennylane.ops.functions.assert_valid import _check_capture


class TestDecompositionErrors:
Expand Down Expand Up @@ -303,6 +304,28 @@ def _unflatten(cls, data, _):
assert_valid(op, skip_pickle=True)


@pytest.mark.jax
def test_bad_capture():
"""Tests that the correct error is raised when something goes wrong with program capture."""

class MyBadOp(qml.operation.Operator):

def _flatten(self):
return (self.hyperparameters["target_op"], self.data[0]), ()

@classmethod
def _unflatten(cls, data, metadata):
return cls(*data)

def __init__(self, target_op, val):
super().__init__(val, wires=target_op.wires)
self.hyperparameters["target_op"] = target_op

op = MyBadOp(qml.X(0), 2)
with pytest.raises(ValueError, match=r"The capture of the operation into jaxpr failed"):
_check_capture(op)


def test_data_is_tuple():
"""Check that the data property is a tuple."""

Expand Down Expand Up @@ -376,13 +399,14 @@ def test_generated_list_of_ops(class_to_validate, str_wires):


@pytest.mark.jax
def test_explicit_list_of_ops(valid_instance):
def test_explicit_list_of_ops(valid_instance_and_kwargs):
"""Test the validity of operators that could not be auto-generated."""
valid_instance, kwargs = valid_instance_and_kwargs
if valid_instance.name == "Hamiltonian":
with qml.operation.disable_new_opmath_cm(warn=False):
assert_valid(valid_instance)
assert_valid(valid_instance, **kwargs)
else:
assert_valid(valid_instance)
assert_valid(valid_instance, **kwargs)


@pytest.mark.jax
Expand Down
4 changes: 2 additions & 2 deletions tests/ops/op_math/test_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_has_generator_false(self):

def test_generator(self):
U = Evolution(qml.PauliX(0), 3)
assert U.base == U.generator()
assert U.generator() == -1 * U.base

@pytest.mark.usefixtures("legacy_opmath_only")
def test_num_params_for_parametric_base_legacy_opmath(self):
Expand Down Expand Up @@ -206,7 +206,7 @@ def test_generator_not_observable_class(self, base):
"""Test that qml.generator will return generator if it is_hermitian, but is not a subclass of Observable"""
op = Evolution(base, 1)
gen, c = qml.generator(op)
qml.assert_equal(gen if c == 1 else qml.s_prod(c, gen), base)
qml.assert_equal(gen if c == 1 else qml.s_prod(c, gen), -1 * base)

def test_generator_error_if_not_hermitian(self):
"""Tests that an error is raised if the generator is not hermitian."""
Expand Down
2 changes: 1 addition & 1 deletion tests/ops/qubit/test_observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def test_basisstate_projector(self):
second_projector = qml.Projector(basis_state, wires)
qml.assert_equal(second_projector, basis_state_projector)

qml.ops.functions.assert_valid(basis_state_projector)
qml.ops.functions.assert_valid(basis_state_projector, skip_differentiation=True)

def test_statevector_projector(self):
"""Test that we obtain a _StateVectorProjector when input is a state vector."""
Expand Down
2 changes: 1 addition & 1 deletion tests/templates/test_embeddings/test_amplitude.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_standard_validity():

op = qml.AmplitudeEmbedding(features=FEATURES[0], wires=range(2))

qml.ops.functions.assert_valid(op)
qml.ops.functions.assert_valid(op, skip_differentiation=True)


class TestDecomposition:
Expand Down
2 changes: 1 addition & 1 deletion tests/templates/test_embeddings/test_angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

def test_standard_validity():
"""Check the operation using the assert_valid function."""
op = qml.AngleEmbedding(features=[1, 2, 3], wires=range(3), rotation="Z")
op = qml.AngleEmbedding(features=[1.0, 2.0, 3.0], wires=range(3), rotation="Z")
qml.ops.functions.assert_valid(op)


Expand Down
2 changes: 1 addition & 1 deletion tests/templates/test_embeddings/test_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_standard_validity():
"""Check the operation using the assert_valid function."""
wires = qml.wires.Wires((0, 1, 2))
op = qml.BasisEmbedding(features=np.array([1, 1, 1]), wires=wires)
qml.ops.functions.assert_valid(op)
qml.ops.functions.assert_valid(op, skip_differentiation=True)


# pylint: disable=protected-access
Expand Down
4 changes: 2 additions & 2 deletions tests/templates/test_embeddings/test_displacement_emb.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@

def test_standard_validity():
"""Check the operation using the assert_valid function."""
feature_vector = [1, 2, 3]
feature_vector = [1.0, 2.0, 3.0]
op = qml.DisplacementEmbedding(features=feature_vector, wires=range(3), method="phase", c=0.5)
qml.ops.functions.assert_valid(op)
qml.ops.functions.assert_valid(op, skip_differentiation=True) # Skip because it's CV op.


def test_flatten_unflatten_methods():
Expand Down
2 changes: 1 addition & 1 deletion tests/templates/test_embeddings/test_iqp_emb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

def test_standard_validity():
"""Check the operation using the assert_valid function."""
features = (0, 1, 2)
features = (0.0, 1.0, 2.0)

op = qml.IQPEmbedding(features, wires=(0, 1, 2))
qml.ops.functions.assert_valid(op)
Expand Down
Loading

0 comments on commit ca7eb65

Please sign in to comment.