Skip to content

Commit

Permalink
Capture the ctrl transform into jaxpr (PennyLaneAI#5967)
Browse files Browse the repository at this point in the history
**Context:**

We have an ongoing experimental project to enable capture of quantum
functions into a jaxpr representation.

Following on from PennyLaneAI#5966 , this PR adds the ability to capture the `ctrl`
transform into jaxpr.

**Description of the Change:**

**Benefits:**

**Possible Drawbacks:**

**Related GitHub Issues:**

[sc-68090]

---------

Co-authored-by: Thomas R. Bromley <[email protected]>
Co-authored-by: Pietropaolo Frisoni <[email protected]>
Co-authored-by: David Ittah <[email protected]>
Co-authored-by: David Wierichs <[email protected]>
  • Loading branch information
5 people committed Jul 30, 2024
1 parent f9adf90 commit a5e21ce
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 2 deletions.
3 changes: 2 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@
* `QuantumScript.hash` is now cached, leading to performance improvements.
[(#5919)](https://github.com/PennyLaneAI/pennylane/pull/5919)

* Applying `adjoint` to a quantum function can now be captured into plxpr.
* Applying `adjoint` and `ctrl` to a quantum function can now be captured into plxpr.
[(#5966)](https://github.com/PennyLaneAI/pennylane/pull/5966)
[(#5967)](https://github.com/PennyLaneAI/pennylane/pull/5967)

* Set operations are now supported by Wires.
[(#5983)](https://github.com/PennyLaneAI/pennylane/pull/5983)
Expand Down
77 changes: 76 additions & 1 deletion pennylane/ops/op_math/controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from copy import copy
from functools import wraps
from inspect import signature
from typing import Any, Callable, Optional, Sequence, overload

import numpy as np
from scipy import sparse
Expand All @@ -34,7 +35,21 @@
from .symbolicop import SymbolicOp


def ctrl(op, control, control_values=None, work_wires=None):
@overload
def ctrl(
op: Operator,
control: Any,
control_values: Optional[Sequence[bool]] = None,
work_wires: Optional[Any] = None,
) -> Operator: ...
@overload
def ctrl(
op: Callable,
control: Any,
control_values: Optional[Sequence[bool]] = None,
work_wires: Optional[Any] = None,
) -> Callable: ...
def ctrl(op, control: Any, control_values=None, work_wires=None):
"""Create a method that applies a controlled version of the provided op.
:func:`~.qjit` compatible.
Expand Down Expand Up @@ -177,7 +192,12 @@ def create_controlled_op(op, control, control_values=None, work_wires=None):
"This error might occur if you apply ctrl to a list "
"of operations instead of a function or Operator."
)
if qml.capture.enabled():
return _capture_ctrl_transform(op, control, control_values, work_wires)
return _ctrl_transform(op, control, control_values, work_wires)


def _ctrl_transform(op, control, control_values, work_wires):
@wraps(op)
def wrapper(*args, **kwargs):
qscript = qml.tape.make_qscript(op)(*args, **kwargs)
Expand All @@ -204,6 +224,61 @@ def wrapper(*args, **kwargs):
return wrapper


@functools.lru_cache # only create the first time requested
def _get_ctrl_qfunc_prim():
"""See capture/explanations.md : Higher Order primitives for more information on this code."""
# if capture is enabled, jax should be installed
import jax # pylint: disable=import-outside-toplevel

ctrl_prim = jax.core.Primitive("ctrl_transform")
ctrl_prim.multiple_results = True

@ctrl_prim.def_impl
def _(*args, n_control, jaxpr, control_values, work_wires, n_consts):
consts = args[:n_consts]
control_wires = args[-n_control:]
args = args[n_consts:-n_control]

with qml.queuing.AnnotatedQueue() as q:
jax.core.eval_jaxpr(jaxpr, consts, *args)
ops, _ = qml.queuing.process_queue(q)

for op in ops:
ctrl(op, control_wires, control_values, work_wires)
return []

@ctrl_prim.def_abstract_eval
def _(*_, **__):
return []

return ctrl_prim


def _capture_ctrl_transform(qfunc: Callable, control, control_values, work_wires) -> Callable:
"""Capture compatible way of performing an ctrl transform."""
# note that this logic is tested in `tests/capture/test_nested_plxpr.py`
import jax # pylint: disable=import-outside-toplevel

ctrl_prim = _get_ctrl_qfunc_prim()

@wraps(qfunc)
def new_qfunc(*args, **kwargs):
jaxpr = jax.make_jaxpr(functools.partial(qfunc, **kwargs))(*args)
control_wires = qml.wires.Wires(control) # make sure is iterable
ctrl_prim.bind(
*jaxpr.consts,
*args,
*control_wires,
jaxpr=jaxpr.jaxpr,
n_control=len(control_wires),
control_values=control_values,
work_wires=work_wires,
n_consts=len(jaxpr.consts),
)

return new_qfunc


@functools.lru_cache()
def _get_special_ops():
"""Gets a list of special operations with custom controlled versions.
Expand Down
145 changes: 145 additions & 0 deletions tests/capture/test_nested_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@

import pennylane as qml
from pennylane.ops.op_math.adjoint import _get_adjoint_qfunc_prim
from pennylane.ops.op_math.controlled import _get_ctrl_qfunc_prim

pytestmark = pytest.mark.jax

jax = pytest.importorskip("jax")

adjoint_prim = _get_adjoint_qfunc_prim()
ctrl_prim = _get_ctrl_qfunc_prim()


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -160,3 +162,146 @@ def qfunc(wire): # x is closure variable and a tracer

assert len(q) == 1
qml.assert_equal(q.queue[0], qml.adjoint(qml.RX(2.5, 2)))


class TestCtrlQfunc:
"""Tests for the ctrl primitive."""

def test_operator_type_input(self):
"""Test that an operator type can be the callable."""

def f(x, w):
return qml.ctrl(qml.RX, 1)(x, w)

plxpr = jax.make_jaxpr(f)(0.5, 0)

with qml.queuing.AnnotatedQueue() as q:
out = jax.core.eval_jaxpr(plxpr.jaxpr, plxpr.consts, 1.2, 2)

assert f(0.5, 0) is None
assert out == []
expected = qml.ctrl(qml.RX(1.2, 2), 1)
qml.assert_equal(q.queue[0], expected)

assert plxpr.eqns[0].primitive == ctrl_prim
assert plxpr.eqns[0].params["control_values"] == [True]
assert plxpr.eqns[0].params["n_control"] == 1
assert plxpr.eqns[0].params["work_wires"] is None
assert plxpr.eqns[0].params["n_consts"] == 0

def test_dynamic_control_wires(self):
"""Test that control wires can be dynamic."""

def f(w1, w2, w3):
return qml.ctrl(qml.X, (w2, w3))(w1)

plxpr = jax.make_jaxpr(f)(4, 5, 6)

with qml.queuing.AnnotatedQueue() as q:
out = jax.core.eval_jaxpr(plxpr.jaxpr, plxpr.consts, 1, 2, 3)

assert out == []
expected = qml.Toffoli(wires=(2, 3, 1))
qml.assert_equal(q.queue[0], expected)
assert len(q) == 1

assert plxpr.eqns[0].primitive == ctrl_prim
assert plxpr.eqns[0].params["control_values"] == [True, True]
assert plxpr.eqns[0].params["n_control"] == 2
assert plxpr.eqns[0].params["work_wires"] is None

def test_work_wires(self):
"""Test that work wires can be provided."""

def f(w):
return qml.ctrl(qml.S, (1, 2), work_wires="aux")(w)

plxpr = jax.make_jaxpr(f)(6)

with qml.queuing.AnnotatedQueue() as q:
out = jax.core.eval_jaxpr(plxpr.jaxpr, plxpr.consts, 5)

assert out == []
expected = qml.ctrl(qml.S(5), (1, 2), work_wires="aux")
qml.assert_equal(q.queue[0], expected)
assert len(q) == 1

assert plxpr.eqns[0].params["work_wires"] == "aux"

def test_control_values(self):
"""Test that control values can be provided."""

def f(z):
return qml.ctrl(qml.RZ, (3, 4), [False, True])(z, 0)

plxpr = jax.make_jaxpr(f)(0.5)

with qml.queuing.AnnotatedQueue() as q:
out = jax.core.eval_jaxpr(plxpr.jaxpr, plxpr.consts, 5.4)

assert out == []
expected = qml.ctrl(qml.RZ(5.4, 0), (3, 4), [False, True])
qml.assert_equal(q.queue[0], expected)
assert len(q) == 1

assert plxpr.eqns[0].params["control_values"] == [False, True]
assert plxpr.eqns[0].params["n_control"] == 2

def test_nested_control(self):
"""Test that control can be nested."""

def f(x, w1, w2):
f1 = qml.ctrl(qml.Rot, w1)
return qml.ctrl(f1, w2)(x, 0.5, 2 * x, 0)

plxpr = jax.make_jaxpr(f)(-0.5, 1, 2)

# First equation of plxpr is the multiplication of x by 2
assert plxpr.eqns[1].params["n_consts"] == 1 # w1 is a const for the outer `ctrl`
assert (
plxpr.eqns[1].invars[0] is plxpr.jaxpr.invars[1]
) # first input is first control wire, const
assert plxpr.eqns[1].invars[1] is plxpr.jaxpr.invars[0] # second input is x, first arg
assert plxpr.eqns[1].invars[-1] is plxpr.jaxpr.invars[2] # second control wire
assert len(plxpr.eqns[1].invars) == 6 # one const, 4 args, one control wire

with qml.queuing.AnnotatedQueue() as q:
jax.core.eval_jaxpr(plxpr.jaxpr, plxpr.consts, 1.2, 3, 4)

target = qml.Rot(1.2, 0.5, jax.numpy.array(2 * 1.2), wires=0)
expected = qml.ctrl(qml.ctrl(target, 3), 4)
qml.assert_equal(q.queue[0], expected)

@pytest.mark.parametrize("include_s", (True, False))
def test_extended_qfunc(self, include_s):
"""Test that the qfunc can contain multiple operations and classical processing."""

def qfunc(x, wire, include_s=True):
qml.RX(2 * x, wire)
qml.RY(x + 1, wire + 1)
if include_s:
qml.S(wire)

def workflow(wire):
qml.ctrl(qfunc, 0)(0.5, wire, include_s=include_s)

jaxpr = jax.make_jaxpr(workflow)(1)

with qml.queuing.AnnotatedQueue() as q:
jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 2)

expected0 = qml.ctrl(qml.RX(jax.numpy.array(1.0), 2), 0)
expected1 = qml.ctrl(qml.RY(jax.numpy.array(1.5), 3), 0)
assert len(q.queue) == 2 + include_s
qml.assert_equal(q.queue[0], expected0)
qml.assert_equal(q.queue[1], expected1)
if include_s:
qml.assert_equal(q.queue[2], qml.ctrl(qml.S(2), 0))

eqn = jaxpr.eqns[0]
assert eqn.params["control_values"] == [True]
assert eqn.params["n_consts"] == 0
assert eqn.params["n_control"] == 1
assert eqn.params["work_wires"] is None

assert len(eqn.params["jaxpr"].eqns) == 5 + include_s

0 comments on commit a5e21ce

Please sign in to comment.