diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 55ab665c1de..8d3f43559c0 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -278,6 +278,10 @@ * A clear error message is added in `KerasLayer` when using the newest version of TensorFlow with Keras 3 (which is not currently compatible with `KerasLayer`), linking to instructions to enable Keras 2. [(#5488)](https://github.com/PennyLaneAI/pennylane/pull/5488) + + * Created the `DefaultQutritMixed` class, which inherits from `qml.devices.Device`, with an implementation + for `preprocess`. + [(#5451)](https://github.com/PennyLaneAI/pennylane/pull/5451) * Removed the warning that an observable might not be hermitian in `qnode` executions. This enables jit-compilation. [(#5506)](https://github.com/PennyLaneAI/pennylane/pull/5506) diff --git a/pennylane/devices/__init__.py b/pennylane/devices/__init__.py index 35125c8fd44..5541760307f 100644 --- a/pennylane/devices/__init__.py +++ b/pennylane/devices/__init__.py @@ -34,6 +34,7 @@ default_gaussian default_mixed default_qutrit + default_qutrit_mixed default_clifford null_qubit tests @@ -55,6 +56,7 @@ Device DefaultQubit NullQubit + DefaultQutritMixed Preprocessing Transforms ------------------------ @@ -155,3 +157,4 @@ def execute(self, circuits, execution_config = qml.devices.DefaultExecutionConfi from .default_mixed import DefaultMixed from .default_clifford import DefaultClifford from .null_qubit import NullQubit +from .default_qutrit_mixed import DefaultQutritMixed diff --git a/pennylane/devices/default_qutrit_mixed.py b/pennylane/devices/default_qutrit_mixed.py new file mode 100644 index 00000000000..e3ee54273d1 --- /dev/null +++ b/pennylane/devices/default_qutrit_mixed.py @@ -0,0 +1,247 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The default.qutrit.mixed device is PennyLane's standard qutrit simulator for mixed-state computations. +""" + +from dataclasses import replace +from typing import Union, Tuple, Sequence +import logging +import numpy as np + +import pennylane as qml +from pennylane.transforms.core import TransformProgram +from pennylane.tape import QuantumTape +from pennylane.typing import Result, ResultBatch + +from . import Device +from .preprocess import ( + decompose, + validate_observables, + validate_measurements, + validate_device_wires, + no_sampling, +) +from .execution_config import ExecutionConfig, DefaultExecutionConfig +from .default_qutrit import DefaultQutrit + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + +Result_or_ResultBatch = Union[Result, ResultBatch] +QuantumTapeBatch = Sequence[QuantumTape] +QuantumTape_or_Batch = Union[QuantumTape, QuantumTapeBatch] + +channels = set() + + +def observable_stopping_condition(obs: qml.operation.Operator) -> bool: + """Specifies whether an observable is accepted by DefaultQutritMixed.""" + return obs.name in DefaultQutrit.observables + + +def stopping_condition(op: qml.operation.Operator) -> bool: + """Specify whether an Operator object is supported by the device.""" + expected_set = DefaultQutrit.operations | {"Snapshot"} | channels + return op.name in expected_set + + +def stopping_condition_shots(op: qml.operation.Operator) -> bool: + """Specify whether an Operator object is supported by the device with shots.""" + return stopping_condition(op) + + +def accepted_sample_measurement(m: qml.measurements.MeasurementProcess) -> bool: + """Specifies whether a measurement is accepted when sampling.""" + return isinstance(m, qml.measurements.SampleMeasurement) + + +class DefaultQutritMixed(Device): + """A PennyLane device written in Python and capable of backpropagation derivatives. + + Args: + wires (int, Iterable[Number, str]): Number of wires present on the device, or iterable that + contains unique labels for the wires as numbers (i.e., ``[-1, 0, 2]``) or strings + (``['ancilla', 'q1', 'q2']``). Default ``None`` if not specified. + shots (int, Sequence[int], Sequence[Union[int, Sequence[int]]]): The default number of shots + to use in executions involving this device. + seed (Union[str, None, int, array_like[int], SeedSequence, BitGenerator, Generator, jax.random.PRNGKey]): A + seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``, or + a request to seed from numpy's global random number generator. + The default, ``seed="global"`` pulls a seed from NumPy's global generator. ``seed=None`` + will pull a seed from the OS entropy. + If a ``jax.random.PRNGKey`` is passed as the seed, a JAX-specific sampling function using + ``jax.random.choice`` and the ``PRNGKey`` will be used for sampling rather than + ``numpy.random.default_rng``. + + **Example:** + + .. code-block:: python + + n_wires = 5 + num_qscripts = 5 + qscripts = [] + for i in range(num_qscripts): + unitary = scipy.stats.unitary_group(dim=3**n_wires, seed=(42 + i)).rvs() + op = qml.QutritUnitary(unitary, wires=range(n_wires)) + qs = qml.tape.QuantumScript([op], [qml.expval(qml.GellMann(0, 3))]) + qscripts.append(qs) + + >>> dev = DefaultQutritMixed() + >>> program, execution_config = dev.preprocess() + >>> new_batch, post_processing_fn = program(qscripts) + >>> results = dev.execute(new_batch, execution_config=execution_config) + >>> post_processing_fn(results) + [0.08015701503959313, + 0.04521414211599359, + -0.0215232130089687, + 0.062120285032425865, + -0.0635052317625] + + This device currently supports backpropagation derivatives: + + >>> from pennylane.devices import ExecutionConfig + >>> dev.supports_derivatives(ExecutionConfig(gradient_method="backprop")) + True + + For example, we can use jax to jit computing the derivative: + + .. code-block:: python + + import jax + + @jax.jit + def f(x): + qs = qml.tape.QuantumScript([qml.TRX(x, 0)], [qml.expval(qml.GellMann(0, 3))]) + program, execution_config = dev.preprocess() + new_batch, post_processing_fn = program([qs]) + results = dev.execute(new_batch, execution_config=execution_config) + return post_processing_fn(results) + + >>> f(jax.numpy.array(1.2)) + DeviceArray(0.36235774, dtype=float32) + >>> jax.grad(f)(jax.numpy.array(1.2)) + DeviceArray(-0.93203914, dtype=float32, weak_type=True) + + .. details:: + :title: Tracking + + ``DefaultQutritMixed`` tracks: + + * ``executions``: the number of unique circuits that would be required on quantum hardware + * ``shots``: the number of shots + * ``resources``: the :class:`~.resource.Resources` for the executed circuit. + * ``simulations``: the number of simulations performed. One simulation can cover multiple QPU executions, such as for non-commuting measurements and batched parameters. + * ``batches``: The number of times :meth:`~.execute` is called. + * ``results``: The results of each call of :meth:`~.execute` + + + """ + + _device_options = ("rng", "prng_key") # tuple of string names for all the device options. + + @property + def name(self): + """The name of the device.""" + return "default.qutrit.mixed" + + def __init__( + self, + wires=None, + shots=None, + seed="global", + ) -> None: + super().__init__(wires=wires, shots=shots) + seed = np.random.randint(0, high=10000000) if seed == "global" else seed + if qml.math.get_interface(seed) == "jax": + self._prng_key = seed + self._rng = np.random.default_rng(None) + else: + self._prng_key = None + self._rng = np.random.default_rng(seed) + self._debugger = None + + def _setup_execution_config(self, execution_config: ExecutionConfig) -> ExecutionConfig: + """This is a private helper for ``preprocess`` that sets up the execution config. + + Args: + execution_config (ExecutionConfig) + + Returns: + ExecutionConfig: a preprocessed execution config + """ + updated_values = {} + for option in execution_config.device_options: + if option not in self._device_options: + raise qml.DeviceError(f"device option {option} not present on {self}") + + if execution_config.gradient_method == "best": + updated_values["gradient_method"] = "backprop" + updated_values["use_device_gradient"] = False + updated_values["grad_on_execution"] = False + updated_values["device_options"] = dict(execution_config.device_options) # copy + + for option in self._device_options: + if option not in updated_values["device_options"]: + updated_values["device_options"][option] = getattr(self, f"_{option}") + return replace(execution_config, **updated_values) + + def preprocess( + self, + execution_config: ExecutionConfig = DefaultExecutionConfig, + ) -> Tuple[TransformProgram, ExecutionConfig]: + """This function defines the device transform program to be applied and an updated device configuration. + + Args: + execution_config (Union[ExecutionConfig, Sequence[ExecutionConfig]]): A data structure describing the + parameters needed to fully describe the execution. + + Returns: + TransformProgram, ExecutionConfig: A transform program that when called returns QuantumTapes that the device + can natively execute as well as a postprocessing function to be called after execution, and a configuration with + unset specifications filled in. + + This device: + * Supports any qutrit operations that provide a matrix + * Supports any qutrit channel that provides Kraus matrices + """ + config = self._setup_execution_config(execution_config) + transform_program = TransformProgram() + + transform_program.add_transform(validate_device_wires, self.wires, name=self.name) + transform_program.add_transform( + decompose, + stopping_condition=stopping_condition, + stopping_condition_shots=stopping_condition_shots, + name=self.name, + ) + transform_program.add_transform( + validate_measurements, sample_measurements=accepted_sample_measurement, name=self.name + ) + transform_program.add_transform( + validate_observables, stopping_condition=observable_stopping_condition, name=self.name + ) + + if config.gradient_method == "backprop": + transform_program.add_transform(no_sampling, name="backprop + default.qutrit") + + return transform_program, config + + def execute( + self, + circuits: QuantumTape_or_Batch, + execution_config: ExecutionConfig = DefaultExecutionConfig, + ) -> Result_or_ResultBatch: + """Stub for execute.""" + return None diff --git a/tests/devices/qutrit_mixed/test_qutrit_mixed_preprocessing.py b/tests/devices/qutrit_mixed/test_qutrit_mixed_preprocessing.py new file mode 100644 index 00000000000..9094fc9164d --- /dev/null +++ b/tests/devices/qutrit_mixed/test_qutrit_mixed_preprocessing.py @@ -0,0 +1,247 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for qutrit mixed device preprocessing.""" +import pytest + +import numpy as np + +import pennylane as qml +from pennylane.devices import ExecutionConfig +from pennylane.devices.default_qutrit_mixed import DefaultQutritMixed, stopping_condition + + +class NoMatOp(qml.operation.Operation): + """Dummy operation for expanding circuit.""" + + # pylint: disable=arguments-renamed, invalid-overridden-method + @property + def has_matrix(self): + return False + + def decomposition(self): + return [qml.TShift(self.wires), qml.TClock(self.wires)] + + +# pylint: disable=too-few-public-methods +class NoMatNoDecompOp(qml.operation.Operation): + """Dummy operation for checking check_validity throws error when + expected.""" + + # pylint: disable=arguments-renamed, invalid-overridden-method + @property + def has_matrix(self): + return False + + +# pylint: disable=too-few-public-methods +class TestPreprocessing: + """Unit tests for the preprocessing method.""" + + def test_error_if_device_option_not_available(self): + """Test that an error is raised if a device option is requested but not a valid option.""" + dev = DefaultQutritMixed() + + config = ExecutionConfig(device_options={"bla": "val"}) + with pytest.raises(qml.DeviceError, match="device option bla"): + dev.preprocess(config) + + def test_chooses_best_gradient_method(self): + """Test that preprocessing chooses backprop as the best gradient method.""" + dev = DefaultQutritMixed() + + config = ExecutionConfig(gradient_method="best") + + _, new_config = dev.preprocess(config) + + assert new_config.gradient_method == "backprop" + assert not new_config.use_device_gradient + assert not new_config.grad_on_execution + + def test_circuit_wire_validation(self): + """Test that preprocessing validates wires on the circuits being executed.""" + dev = DefaultQutritMixed(wires=3) + + circuit_valid_0 = qml.tape.QuantumScript([qml.TShift(0)]) + program, _ = dev.preprocess() + circuits, _ = program([circuit_valid_0]) + assert circuits[0].circuit == circuit_valid_0.circuit + + circuit_valid_1 = qml.tape.QuantumScript([qml.TShift(1)]) + program, _ = dev.preprocess() + circuits, _ = program([circuit_valid_0, circuit_valid_1]) + assert circuits[0].circuit == circuit_valid_0.circuit + assert circuits[1].circuit == circuit_valid_1.circuit + + invalid_circuit = qml.tape.QuantumScript([qml.TShift(4)]) + program, _ = dev.preprocess() + + with pytest.raises(qml.wires.WireError, match=r"Cannot run circuit\(s\) on"): + program([invalid_circuit]) + + with pytest.raises(qml.wires.WireError, match=r"Cannot run circuit\(s\) on"): + program([circuit_valid_0, invalid_circuit]) + + @pytest.mark.parametrize( + "mp_fn,mp_cls,shots", + [ + (qml.sample, qml.measurements.SampleMP, 10), + (qml.state, qml.measurements.StateMP, None), + (qml.probs, qml.measurements.ProbabilityMP, None), + ], + ) + def test_measurement_is_swapped_out(self, mp_fn, mp_cls, shots): + """Test that preprocessing swaps out any MeasurementProcess with no wires or obs""" + dev = DefaultQutritMixed(wires=3) + original_mp = mp_fn() + exp_z = qml.expval(qml.GellMann(0, 3)) + qs = qml.tape.QuantumScript([qml.THadamard(0)], [original_mp, exp_z], shots=shots) + program, _ = dev.preprocess() + tapes, _ = program([qs]) + assert len(tapes) == 1 + tape = tapes[0] + assert tape.operations == qs.operations + assert tape.measurements != qs.measurements + assert qml.equal(tape.measurements[0], mp_cls(wires=[0, 1, 2])) + assert tape.measurements[1] is exp_z + + @pytest.mark.parametrize( + "op, expected", + [ + (qml.TShift(0), True), + (qml.GellMann(0, 1), False), + (qml.Snapshot(), True), + (qml.TRX(1.1, 0), True), + ], + ) + def test_accepted_operator(self, op, expected): # TODO: Add channel ops once added. + """Test that _accepted_operator works correctly""" + res = stopping_condition(op) + assert res == expected + + +class TestPreprocessingIntegration: + """Test preprocess produces output that can be executed by the device.""" + + def test_batch_transform_no_batching(self): + """Test that batch_transform does nothing when no batching is required.""" + ops = [qml.THadamard(0), qml.TAdd([0, 1]), qml.TRX(0.123, wires=1)] + measurements = [qml.expval(qml.GellMann(1, 3))] + tape = qml.tape.QuantumScript(ops=ops, measurements=measurements) + device = DefaultQutritMixed() + + program, _ = device.preprocess() + tapes, _ = program([tape]) + + assert len(tapes) == 1 + assert tapes[0].circuit == ops + measurements + + def test_batch_transform_broadcast(self): + """Test that batch_transform does nothing when batching is required but + internal PennyLane broadcasting can be used (diff method != adjoint)""" + ops = [qml.THadamard(0), qml.TAdd([0, 1]), qml.TRX([np.pi, np.pi / 2], wires=1)] + measurements = [qml.expval(qml.GellMann(1, 3))] + tape = qml.tape.QuantumScript(ops=ops, measurements=measurements) + device = DefaultQutritMixed() + + program, _ = device.preprocess() + tapes, _ = program([tape]) + + assert len(tapes) == 1 + assert tapes[0].circuit == ops + measurements + + def test_preprocess_batch_transform(self): + """Test that preprocess returns the correct tapes when a batch transform + is needed.""" + ops = [qml.THadamard(0), qml.TAdd([0, 1]), qml.TRX([np.pi, np.pi / 2], wires=1)] + measurements = [qml.expval(qml.GellMann(0, 4)), qml.expval(qml.GellMann(1, 3))] + tapes = [ + qml.tape.QuantumScript(ops=ops, measurements=[measurements[0]]), + qml.tape.QuantumScript(ops=ops, measurements=[measurements[1]]), + ] + + program, _ = DefaultQutritMixed().preprocess() + res_tapes, batch_fn = program(tapes) + + assert len(res_tapes) == 2 + for res_tape, measurement in zip(res_tapes, measurements): + for op, expected_op in zip(res_tape.operations, ops): + assert qml.equal(op, expected_op) + assert res_tape.measurements == [measurement] + + val = ([[1, 2], [3, 4]], [[5, 6], [7, 8]]) + assert np.array_equal(batch_fn(val), np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])) + + def test_preprocess_expand(self): + """Test that preprocess returns the correct tapes when expansion is needed.""" + ops = [qml.THadamard(0), NoMatOp(1), qml.TRZ(0.123, wires=1)] + measurements = [[qml.expval(qml.GellMann(0, 3))], [qml.expval(qml.GellMann(1, 1))]] + tapes = [ + qml.tape.QuantumScript(ops=ops, measurements=measurements[0]), + qml.tape.QuantumScript(ops=ops, measurements=measurements[1]), + ] + + program, _ = DefaultQutritMixed().preprocess() + res_tapes, batch_fn = program(tapes) + + expected = [qml.THadamard(0), qml.TShift(1), qml.TClock(1), qml.TRZ(0.123, wires=1)] + + assert len(res_tapes) == 2 + for i, t in enumerate(res_tapes): + for op, exp in zip(t.circuit, expected + measurements[i]): + assert qml.equal(op, exp) + + val = (("a", "b"), "c", "d") + assert batch_fn(val) == (("a", "b"), "c") + + def test_preprocess_batch_and_expand(self): + """Test that preprocess returns the correct tapes when batching and expanding + is needed.""" + ops = [qml.THadamard(0), NoMatOp(1), qml.TRX([np.pi, np.pi / 2], wires=1)] + measurements = [qml.expval(qml.GellMann(0, 1)), qml.expval(qml.GellMann(1, 3))] + tapes = [ + qml.tape.QuantumScript(ops=ops, measurements=[measurements[0]]), + qml.tape.QuantumScript(ops=ops, measurements=[measurements[1]]), + ] + + program, _ = DefaultQutritMixed().preprocess() + res_tapes, batch_fn = program(tapes) + expected_ops = [ + qml.THadamard(0), + qml.TShift(1), + qml.TClock(1), + qml.TRX([np.pi, np.pi / 2], wires=1), + ] + + assert len(res_tapes) == 2 + for res_tape, measurement in zip(res_tapes, measurements): + for op, expected_op in zip(res_tape.operations, expected_ops): + assert qml.equal(op, expected_op) + assert res_tape.measurements == [measurement] + + val = ([[1, 2], [3, 4]], [[5, 6], [7, 8]]) + assert np.array_equal(batch_fn(val), np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])) + + def test_preprocess_check_validity_fail(self): + """Test that preprocess throws an error if the batched and expanded tapes have + unsupported operators.""" + ops = [qml.THadamard(0), NoMatNoDecompOp(1), qml.TRZ(0.123, wires=1)] + measurements = [[qml.expval(qml.GellMann(0, 3))], [qml.expval(qml.GellMann(1, 1))]] + tapes = [ + qml.tape.QuantumScript(ops=ops, measurements=measurements[0]), + qml.tape.QuantumScript(ops=ops, measurements=measurements[1]), + ] + + program, _ = DefaultQutritMixed().preprocess() + with pytest.raises(qml.DeviceError, match="Operator NoMatNoDecompOp"): + program(tapes) diff --git a/tests/devices/test_default_qutrit_mixed.py b/tests/devices/test_default_qutrit_mixed.py new file mode 100644 index 00000000000..16c7e09465a --- /dev/null +++ b/tests/devices/test_default_qutrit_mixed.py @@ -0,0 +1,108 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for default qutrit mixed.""" +import pytest +import numpy as np + +import pennylane as qml +from pennylane.devices.default_qutrit_mixed import DefaultQutritMixed + + +def test_name(): + """Tests the name of DefaultQutritMixed.""" + assert DefaultQutritMixed().name == "default.qutrit.mixed" + + +def test_debugger_attribute(): + """Test that DefaultQutritMixed has a debugger attribute and that it is `None`""" + # pylint: disable=protected-access + dev = DefaultQutritMixed() + + assert hasattr(dev, "_debugger") + assert dev._debugger is None + + +class TestRandomSeed: + """Test that the device behaves correctly when provided with a random seed""" + + def test_global_seed_no_device_seed_by_default(self): + """Test that the global numpy seed initializes the rng if device seed is None.""" + np.random.seed(42) + dev = DefaultQutritMixed() + first_num = dev._rng.random() # pylint: disable=protected-access + + np.random.seed(42) + dev2 = DefaultQutritMixed() + second_num = dev2._rng.random() # pylint: disable=protected-access + + assert qml.math.allclose(first_num, second_num) + + np.random.seed(42) + dev2 = DefaultQutritMixed(seed="global") + third_num = dev2._rng.random() # pylint: disable=protected-access + + assert qml.math.allclose(third_num, first_num) + + def test_none_seed_not_using_global_rng(self): + """Test that if the seed is None, it is uncorrelated with the global rng.""" + np.random.seed(42) + dev = DefaultQutritMixed(seed=None) + first_nums = dev._rng.random(10) # pylint: disable=protected-access + + np.random.seed(42) + dev2 = DefaultQutritMixed(seed=None) + second_nums = dev2._rng.random(10) # pylint: disable=protected-access + + assert not qml.math.allclose(first_nums, second_nums) + + def test_rng_as_seed(self): + """Test that a PRNG can be passed as a seed.""" + rng1 = np.random.default_rng(42) + first_num = rng1.random() + + rng = np.random.default_rng(42) + dev = DefaultQutritMixed(seed=rng) + second_num = dev._rng.random() # pylint: disable=protected-access + + assert qml.math.allclose(first_num, second_num) + + +@pytest.mark.jax +class TestPRNGKeySeed: + """Test that the device behaves correctly when provided with a PRNG key and using the JAX interface""" + + # pylint: disable=too-few-public-methods + + def test_prng_key_as_seed(self): + """Test that a jax PRNG can be passed as a seed.""" + from jax.config import config + + config.update("jax_enable_x64", True) + + from jax import random + + key1 = random.key(123) + first_nums = random.uniform(key1, shape=(10,)) + + key = random.key(123) + dev = DefaultQutritMixed(seed=key) + + second_nums = random.uniform(dev._prng_key, shape=(10,)) # pylint: disable=protected-access + assert np.all(first_nums == second_nums) + + +def test_execute_stump(): + """Tests the stump for execute returns None, test is for Codecov.""" + dev = DefaultQutritMixed() + assert dev.execute(None, None) is None