diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 56a8d5741ad..b403126a94e 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -70,8 +70,9 @@ * `QuantumScript.hash` is now cached, leading to performance improvements. [(#5919)](https://github.com/PennyLaneAI/pennylane/pull/5919) -* Applying `adjoint` to a quantum function can now be captured into plxpr. +* Applying `adjoint` and `ctrl` to a quantum function can now be captured into plxpr. [(#5966)](https://github.com/PennyLaneAI/pennylane/pull/5966) + [(#5967)](https://github.com/PennyLaneAI/pennylane/pull/5967) * Set operations are now supported by Wires. [(#5983)](https://github.com/PennyLaneAI/pennylane/pull/5983) diff --git a/pennylane/ops/op_math/controlled.py b/pennylane/ops/op_math/controlled.py index b065d99df74..a7c4b1a9935 100644 --- a/pennylane/ops/op_math/controlled.py +++ b/pennylane/ops/op_math/controlled.py @@ -19,6 +19,7 @@ from copy import copy from functools import wraps from inspect import signature +from typing import Any, Callable, Optional, Sequence, overload import numpy as np from scipy import sparse @@ -34,7 +35,21 @@ from .symbolicop import SymbolicOp -def ctrl(op, control, control_values=None, work_wires=None): +@overload +def ctrl( + op: Operator, + control: Any, + control_values: Optional[Sequence[bool]] = None, + work_wires: Optional[Any] = None, +) -> Operator: ... +@overload +def ctrl( + op: Callable, + control: Any, + control_values: Optional[Sequence[bool]] = None, + work_wires: Optional[Any] = None, +) -> Callable: ... +def ctrl(op, control: Any, control_values=None, work_wires=None): """Create a method that applies a controlled version of the provided op. :func:`~.qjit` compatible. @@ -177,7 +192,12 @@ def create_controlled_op(op, control, control_values=None, work_wires=None): "This error might occur if you apply ctrl to a list " "of operations instead of a function or Operator." ) + if qml.capture.enabled(): + return _capture_ctrl_transform(op, control, control_values, work_wires) + return _ctrl_transform(op, control, control_values, work_wires) + +def _ctrl_transform(op, control, control_values, work_wires): @wraps(op) def wrapper(*args, **kwargs): qscript = qml.tape.make_qscript(op)(*args, **kwargs) @@ -204,6 +224,61 @@ def wrapper(*args, **kwargs): return wrapper +@functools.lru_cache # only create the first time requested +def _get_ctrl_qfunc_prim(): + """See capture/explanations.md : Higher Order primitives for more information on this code.""" + # if capture is enabled, jax should be installed + import jax # pylint: disable=import-outside-toplevel + + ctrl_prim = jax.core.Primitive("ctrl_transform") + ctrl_prim.multiple_results = True + + @ctrl_prim.def_impl + def _(*args, n_control, jaxpr, control_values, work_wires, n_consts): + consts = args[:n_consts] + control_wires = args[-n_control:] + args = args[n_consts:-n_control] + + with qml.queuing.AnnotatedQueue() as q: + jax.core.eval_jaxpr(jaxpr, consts, *args) + ops, _ = qml.queuing.process_queue(q) + + for op in ops: + ctrl(op, control_wires, control_values, work_wires) + return [] + + @ctrl_prim.def_abstract_eval + def _(*_, **__): + return [] + + return ctrl_prim + + +def _capture_ctrl_transform(qfunc: Callable, control, control_values, work_wires) -> Callable: + """Capture compatible way of performing an ctrl transform.""" + # note that this logic is tested in `tests/capture/test_nested_plxpr.py` + import jax # pylint: disable=import-outside-toplevel + + ctrl_prim = _get_ctrl_qfunc_prim() + + @wraps(qfunc) + def new_qfunc(*args, **kwargs): + jaxpr = jax.make_jaxpr(functools.partial(qfunc, **kwargs))(*args) + control_wires = qml.wires.Wires(control) # make sure is iterable + ctrl_prim.bind( + *jaxpr.consts, + *args, + *control_wires, + jaxpr=jaxpr.jaxpr, + n_control=len(control_wires), + control_values=control_values, + work_wires=work_wires, + n_consts=len(jaxpr.consts), + ) + + return new_qfunc + + @functools.lru_cache() def _get_special_ops(): """Gets a list of special operations with custom controlled versions. diff --git a/tests/capture/test_nested_plxpr.py b/tests/capture/test_nested_plxpr.py index 9dfa867d1a6..c2b8ca09c6a 100644 --- a/tests/capture/test_nested_plxpr.py +++ b/tests/capture/test_nested_plxpr.py @@ -19,12 +19,14 @@ import pennylane as qml from pennylane.ops.op_math.adjoint import _get_adjoint_qfunc_prim +from pennylane.ops.op_math.controlled import _get_ctrl_qfunc_prim pytestmark = pytest.mark.jax jax = pytest.importorskip("jax") adjoint_prim = _get_adjoint_qfunc_prim() +ctrl_prim = _get_ctrl_qfunc_prim() @pytest.fixture(autouse=True) @@ -160,3 +162,146 @@ def qfunc(wire): # x is closure variable and a tracer assert len(q) == 1 qml.assert_equal(q.queue[0], qml.adjoint(qml.RX(2.5, 2))) + + +class TestCtrlQfunc: + """Tests for the ctrl primitive.""" + + def test_operator_type_input(self): + """Test that an operator type can be the callable.""" + + def f(x, w): + return qml.ctrl(qml.RX, 1)(x, w) + + plxpr = jax.make_jaxpr(f)(0.5, 0) + + with qml.queuing.AnnotatedQueue() as q: + out = jax.core.eval_jaxpr(plxpr.jaxpr, plxpr.consts, 1.2, 2) + + assert f(0.5, 0) is None + assert out == [] + expected = qml.ctrl(qml.RX(1.2, 2), 1) + qml.assert_equal(q.queue[0], expected) + + assert plxpr.eqns[0].primitive == ctrl_prim + assert plxpr.eqns[0].params["control_values"] == [True] + assert plxpr.eqns[0].params["n_control"] == 1 + assert plxpr.eqns[0].params["work_wires"] is None + assert plxpr.eqns[0].params["n_consts"] == 0 + + def test_dynamic_control_wires(self): + """Test that control wires can be dynamic.""" + + def f(w1, w2, w3): + return qml.ctrl(qml.X, (w2, w3))(w1) + + plxpr = jax.make_jaxpr(f)(4, 5, 6) + + with qml.queuing.AnnotatedQueue() as q: + out = jax.core.eval_jaxpr(plxpr.jaxpr, plxpr.consts, 1, 2, 3) + + assert out == [] + expected = qml.Toffoli(wires=(2, 3, 1)) + qml.assert_equal(q.queue[0], expected) + assert len(q) == 1 + + assert plxpr.eqns[0].primitive == ctrl_prim + assert plxpr.eqns[0].params["control_values"] == [True, True] + assert plxpr.eqns[0].params["n_control"] == 2 + assert plxpr.eqns[0].params["work_wires"] is None + + def test_work_wires(self): + """Test that work wires can be provided.""" + + def f(w): + return qml.ctrl(qml.S, (1, 2), work_wires="aux")(w) + + plxpr = jax.make_jaxpr(f)(6) + + with qml.queuing.AnnotatedQueue() as q: + out = jax.core.eval_jaxpr(plxpr.jaxpr, plxpr.consts, 5) + + assert out == [] + expected = qml.ctrl(qml.S(5), (1, 2), work_wires="aux") + qml.assert_equal(q.queue[0], expected) + assert len(q) == 1 + + assert plxpr.eqns[0].params["work_wires"] == "aux" + + def test_control_values(self): + """Test that control values can be provided.""" + + def f(z): + return qml.ctrl(qml.RZ, (3, 4), [False, True])(z, 0) + + plxpr = jax.make_jaxpr(f)(0.5) + + with qml.queuing.AnnotatedQueue() as q: + out = jax.core.eval_jaxpr(plxpr.jaxpr, plxpr.consts, 5.4) + + assert out == [] + expected = qml.ctrl(qml.RZ(5.4, 0), (3, 4), [False, True]) + qml.assert_equal(q.queue[0], expected) + assert len(q) == 1 + + assert plxpr.eqns[0].params["control_values"] == [False, True] + assert plxpr.eqns[0].params["n_control"] == 2 + + def test_nested_control(self): + """Test that control can be nested.""" + + def f(x, w1, w2): + f1 = qml.ctrl(qml.Rot, w1) + return qml.ctrl(f1, w2)(x, 0.5, 2 * x, 0) + + plxpr = jax.make_jaxpr(f)(-0.5, 1, 2) + + # First equation of plxpr is the multiplication of x by 2 + assert plxpr.eqns[1].params["n_consts"] == 1 # w1 is a const for the outer `ctrl` + assert ( + plxpr.eqns[1].invars[0] is plxpr.jaxpr.invars[1] + ) # first input is first control wire, const + assert plxpr.eqns[1].invars[1] is plxpr.jaxpr.invars[0] # second input is x, first arg + assert plxpr.eqns[1].invars[-1] is plxpr.jaxpr.invars[2] # second control wire + assert len(plxpr.eqns[1].invars) == 6 # one const, 4 args, one control wire + + with qml.queuing.AnnotatedQueue() as q: + jax.core.eval_jaxpr(plxpr.jaxpr, plxpr.consts, 1.2, 3, 4) + + target = qml.Rot(1.2, 0.5, jax.numpy.array(2 * 1.2), wires=0) + expected = qml.ctrl(qml.ctrl(target, 3), 4) + qml.assert_equal(q.queue[0], expected) + + @pytest.mark.parametrize("include_s", (True, False)) + def test_extended_qfunc(self, include_s): + """Test that the qfunc can contain multiple operations and classical processing.""" + + def qfunc(x, wire, include_s=True): + qml.RX(2 * x, wire) + qml.RY(x + 1, wire + 1) + if include_s: + qml.S(wire) + + def workflow(wire): + qml.ctrl(qfunc, 0)(0.5, wire, include_s=include_s) + + jaxpr = jax.make_jaxpr(workflow)(1) + + with qml.queuing.AnnotatedQueue() as q: + jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 2) + + expected0 = qml.ctrl(qml.RX(jax.numpy.array(1.0), 2), 0) + expected1 = qml.ctrl(qml.RY(jax.numpy.array(1.5), 3), 0) + assert len(q.queue) == 2 + include_s + qml.assert_equal(q.queue[0], expected0) + qml.assert_equal(q.queue[1], expected1) + if include_s: + qml.assert_equal(q.queue[2], qml.ctrl(qml.S(2), 0)) + + eqn = jaxpr.eqns[0] + assert eqn.params["control_values"] == [True] + assert eqn.params["n_consts"] == 0 + assert eqn.params["n_control"] == 1 + assert eqn.params["work_wires"] is None + + assert len(eqn.params["jaxpr"].eqns) == 5 + include_s