From a0b1f17735c64a049b86201d1f5e221e6d11b6a5 Mon Sep 17 00:00:00 2001 From: gabrielLydian Date: Thu, 27 Jun 2024 16:53:42 -0700 Subject: [PATCH 01/26] Need to push quick --- doing_a_test.py | 117 ++++++++++++++++ pennylane/devices/qtcorgi_helper/__init__.py | 0 .../qtcorgi_helper/apply_operations.py | 110 +++++++++++++++ .../qtcorgi_helper/qtcorgi_simulator.py | 125 ++++++++++++++++++ 4 files changed, 352 insertions(+) create mode 100644 doing_a_test.py create mode 100644 pennylane/devices/qtcorgi_helper/__init__.py create mode 100644 pennylane/devices/qtcorgi_helper/apply_operations.py create mode 100644 pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py diff --git a/doing_a_test.py b/doing_a_test.py new file mode 100644 index 00000000000..ec60c54ce61 --- /dev/null +++ b/doing_a_test.py @@ -0,0 +1,117 @@ +# import jax +import pennylane as qml +import numpy as np +# import scipy +# +# n_wires = 2 +# num_qscripts = 2 +# qscripts = [] +# for i in range(num_qscripts): +# unitary = scipy.stats.unitary_group(dim=3**n_wires, seed=(42 + i)).rvs() +# op = qml.QutritUnitary(unitary, wires=range(n_wires)) +# qs = qml.tape.QuantumScript([op], [qml.expval(qml.GellMann(0, 3))]) +# qscripts.append(qs) +# +# dev = qml.devices.DefaultQutritMixed() +# program, execution_config = dev.preprocess() +# new_batch, post_processing_fn = program(qscripts) +# results = dev.execute(new_batch, execution_config=execution_config) +# print(post_processing_fn(results)) +# +# @jax.jit +# def f(x): +# qs = qml.tape.QuantumScript([qml.TRX(x, 0)], [qml.expval(qml.GellMann(0, 3))]) +# program, execution_config = dev.preprocess() +# new_batch, post_processing_fn = program([qs]) +# results = dev.execute(new_batch, execution_config=execution_config) +# return post_processing_fn(results)[0] +# +# jax.config.update("jax_enable_x64", True) +# print(f(jax.numpy.array(1.2))) +# print(jax.grad(f)(jax.numpy.array(1.2))) + +# 1/2, 1/3, 1/6 + +# 1/2, 1/3, 1/6 +# +# 1/2+1/3-2/6 +# +# 2/6-3/6 +# +# +# gellMann_8_coeffs = np.array([1/np.sqrt(3), -1, 1/np.sqrt(3), 1, 1/np.sqrt(3), -1]) / 2 +# gellMann_8_obs = [qml.GellMann(0, i) for i in [1, 2, 4, 5, 6, 7]] +# H1 = qml.Hamiltonian(gellMann_8_coeffs, gellMann_8_obs).matrix() +# +# G8 = qml.GellMann.compute_matrix(8) +# Had = qml.THadamard.compute_matrix() +# aHad = np.conj(Had).T +# +# H2 = np.round(aHad@G8@Had, 5) +# +# +# # print(np.allclose(H1, H2)) +# # print(H1) +# # print(H2) +# # print(np.round(H2*np.sqrt(3), 5)) +# obs = aHad@G8@Had +# +# diag_gates = qml.THermitian(obs, 0).diagonalizing_gates()#[0].matrix() +# print(len(diag_gates)) +# diag_gates = diag_gates[0].matrix() +# +# print(np.round((np.conj(diag_gates).T)@obs@diag_gates, 5)) + + +def setup_state(nr_wires): + """Sets up a basic state used for testing.""" + setup_unitary = np.array( + [ + [1 / np.sqrt(2), 1 / np.sqrt(3), 1 / np.sqrt(6)], + [np.sqrt(2 / 29), np.sqrt(3 / 29), -2 * np.sqrt(6 / 29)], + [-5 / np.sqrt(58), 7 / np.sqrt(87), 1 / np.sqrt(174)], + ] + ).T + qml.QutritUnitary(setup_unitary, wires=0) + qml.QutritUnitary(setup_unitary, wires=1) + if nr_wires == 3: + qml.TAdd(wires=(0, 2)) + + +dev = qml.device( + "default.qutrit.mixed", + wires=2, + damping_measurement_gammas=(0.2, 0.1, 0.4), + trit_flip_measurement_probs=(0.1, 0.2, 0.5), + ) +# Create matricies for the observables with diagonalizing matrix :math:`THadamard^\dag` +inv_sqrt_3_i = 1j / np.sqrt(3) +non_commuting_obs_one = np.array( + [ + [0, -1 + inv_sqrt_3_i, -1 - inv_sqrt_3_i], + [-1 - inv_sqrt_3_i, 0, -1 + inv_sqrt_3_i], + [-1 + inv_sqrt_3_i, -1 - inv_sqrt_3_i, 0], + ] +) +non_commuting_obs_one /= 2 + +@qml.qnode(dev) +def circuit(): + setup_state(2) + + qml.THadamard(wires=0) + qml.THadamard(wires=1, subspace=(0, 1)) + + return qml.expval(qml.THermitian(non_commuting_obs_one, 0)) + + + + + +print(my_test()) + + + + + + diff --git a/pennylane/devices/qtcorgi_helper/__init__.py b/pennylane/devices/qtcorgi_helper/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/pennylane/devices/qtcorgi_helper/apply_operations.py b/pennylane/devices/qtcorgi_helper/apply_operations.py new file mode 100644 index 00000000000..c115e4ba54a --- /dev/null +++ b/pennylane/devices/qtcorgi_helper/apply_operations.py @@ -0,0 +1,110 @@ +import time +import jax +import jax.numpy as jnp +from jax.lax import scan +jax.config.update("jax_enable_x64", True) +jax.config.update("jax_platforms", "cpu") +import pennylane as qml +from string import ascii_letters as alphabet + +import numpy as np +alphabet_array = np.array(list(alphabet)) + +def get_einsum_mapping(wires, state): + r"""Finds the indices for einsum to apply kraus operators to a mixed state + + Args: + wires + state (array[complex]): Input quantum state + + Returns: + str: Indices mapping that defines the einsum + """ + num_ch_wires = len(wires) + num_wires = int(len(qml.math.shape(state)) / 2) + rho_dim = 2 * num_wires + + # Tensor indices of the state. For each qutrit, need an index for rows *and* columns + state_indices = alphabet[:rho_dim] + + # row indices of the quantum state affected by this operation + row_wires_list = wires.tolist() + row_indices = "".join(alphabet_array[row_wires_list].tolist()) + + # column indices are shifted by the number of wires + col_wires_list = [w + num_wires for w in row_wires_list] + col_indices = "".join(alphabet_array[col_wires_list].tolist()) + + # indices in einsum must be replaced with new ones + new_row_indices = alphabet[rho_dim : rho_dim + num_ch_wires] + new_col_indices = alphabet[rho_dim + num_ch_wires : rho_dim + 2 * num_ch_wires] + + # index for summation over Kraus operators + kraus_index = alphabet[rho_dim + 2 * num_ch_wires : rho_dim + 2 * num_ch_wires + 1] + + # apply mapping function + op_1_indices = f"{kraus_index}{new_row_indices}{row_indices}" + op_2_indices = f"{kraus_index}{col_indices}{new_col_indices}" + + new_state_indices = get_new_state_einsum_indices( + old_indices=col_indices + row_indices, + new_indices=new_col_indices + new_row_indices, + state_indices=state_indices, + ) + # index mapping for einsum, e.g., '...iga,...abcdef,...idh->...gbchef' + return ( + f"...{op_1_indices},...{state_indices},...{op_2_indices}->...{new_state_indices}" + ) + +def get_new_state_einsum_indices(old_indices, new_indices, state_indices): + """Retrieves the einsum indices string for the new state + + Args: + old_indices (str): indices that are summed + new_indices (str): indices that must be replaced with sums + state_indices (str): indices of the original state + + Returns: + str: The einsum indices of the new state + """ + return functools.reduce( + lambda old_string, idx_pair: old_string.replace(idx_pair[0], idx_pair[1]), + zip(old_indices, new_indices), + state_indices, + ) + +QUDIT_DIM = 3 +def apply_operation_einsum(kraus, wires, state): + r"""Apply a quantum channel specified by a list of Kraus operators to subsystems of the + quantum state. For a unitary gate, there is a single Kraus operator. + + Args: + kraus (??): TODO + state (array[complex]): Input quantum state + + Returns: + array[complex]: output_state + """ + einsum_indices = get_einsum_mapping(wires, state) + + num_ch_wires = len(wires) + + # Shape kraus operators + kraus_shape = [len(kraus)] + [QUDIT_DIM] * num_ch_wires * 2 + + kraus = jnp.stack(kraus) + kraus_transpose = jnp.stack(jnp.moveaxis(kraus, source=-1, destination=-2)) + # Torch throws error if math.conj is used before stack + kraus_dagger = jnp.conj(kraus_transpose) + + kraus = jnp.reshape(kraus, kraus_shape) + kraus_dagger = jnp.reshape(kraus_dagger, kraus_shape) + + return jnp.einsum(einsum_indices, kraus, state, kraus_dagger) + + +def f(carry, x): + if x[0] == 0: + qml.TRX().matrix + +flag, wires, inputs \ No newline at end of file diff --git a/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py b/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py new file mode 100644 index 00000000000..237be5412bc --- /dev/null +++ b/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py @@ -0,0 +1,125 @@ + +def simulate( + circuit: qml.tape.QuantumScript, rng=None, prng_key=None, debugger=None, interface=None +) -> Result: + """TODO + + Args: + circuit (QuantumTape): The single circuit to simulate + rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]): A + seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``. + If no value is provided, a default RNG will be used. + prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is + the key to the JAX pseudo random number generator. If None, a random key will be + generated. Only for simulation using JAX. + debugger (_Debugger): The debugger to use + interface (str): The machine learning interface to create the initial state with + + Returns: + tuple(TensorLike): The results of the simulation + + Note that this function can return measurements for non-commuting observables simultaneously. + + This function assumes that all operations provide matrices. + + >>> qs = qml.tape.QuantumScript([qml.TRX(1.2, wires=0)], [qml.expval(qml.GellMann(0, 3)), qml.probs(wires=(0,1))]) + >>> simulate(qs) + (0.36235775447667357, + tensor([0.68117888, 0. , 0. , 0.31882112, 0. , 0. ], requires_grad=True)) + + """ + state, is_state_batched = get_final_state( + circuit, debugger=debugger, interface=interface, rng=rng, prng_key=prng_key + ) + return measure_final_state(circuit, state, is_state_batched, rng=rng, prng_key=prng_key) + + +def get_final_state(circuit, debugger=None, interface=None, **kwargs): + """ + TODO + + Args: + circuit (.QuantumScript): The single circuit to simulate + debugger (._Debugger): The debugger to use + interface (str): The machine learning interface to create the initial state with + + Returns: + Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and + whether the state has a batch dimension. + + """ + circuit = circuit.map_to_standard_wires() + state = create_initial_state(sorted(circuit.op_wires), like="jax") + + op_array = [] + for op in circuit.operations: + + # op_index = None + # for i, op_type in enumerate(op_types): + # if isinstance(op, op_type): + # op_index = i + # if op_index is None: + # raise ValueError("This simulator only supports") + + op_index = op_types.index(type(op)) + + wires = op.wires() + if len(wires) == 1: + wires = [-1, wires[0]] + if len(wires) == 2: + wires = list(wires) + params = op.parameters + ([0] * (3-op.num_params)) + op_array.append([[op_index] + wires, params]) + + return state + + +def get_final_state(circuit, debugger=None, interface=None, **kwargs): + """ + TODO + + Args: + circuit (.QuantumScript): The single circuit to simulate + debugger (._Debugger): The debugger to use + interface (str): The machine learning interface to create the initial state with + + Returns: + Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and + whether the state has a batch dimension. + + """ + circuit = circuit.map_to_standard_wires() + + prep = None + if len(circuit) > 0 and isinstance(circuit[0], qml.operation.StatePrepBase): + prep = circuit[0] + + state = create_initial_state(sorted(circuit.op_wires), prep, like=INTERFACE_TO_LIKE[interface]) + + # initial state is batched only if the state preparation (if it exists) is batched + is_state_batched = bool(prep and prep.batch_size is not None) + for op in circuit.operations[bool(prep) :]: + state = apply_operation( + op, + state, + is_state_batched=is_state_batched, + debugger=debugger, + tape_shots=circuit.shots, + **kwargs, + ) + + # new state is batched if i) the old state is batched, or ii) the new op adds a batch dim + is_state_batched = is_state_batched or op.batch_size is not None + + num_operated_wires = len(circuit.op_wires) + for i in range(len(circuit.wires) - num_operated_wires): + # If any measured wires are not operated on, we pad the density matrix with zeros. + # We know they belong at the end because the circuit is in standard wire-order + # Since it is a dm, we must pad it with 0s on the last row and last column + current_axis = num_operated_wires + i + is_state_batched + state = qml.math.stack( + ([state] + [qml.math.zeros_like(state)] * (QUDIT_DIM - 1)), axis=current_axis + ) + state = qml.math.stack(([state] + [qml.math.zeros_like(state)] * (QUDIT_DIM - 1)), axis=-1) + + return state, is_state_batched \ No newline at end of file From 89d9e4e7b8b256044adf51224e488780604a87e8 Mon Sep 17 00:00:00 2001 From: Gabriel Bottrill Date: Sat, 29 Jun 2024 21:03:42 -0700 Subject: [PATCH 02/26] Worked on getting simulate setup, need to do apply gates --- my_testing.py | 97 +++ pennylane/devices/jittable_mixed.py | 797 ++++++++++++++++++ pennylane/devices/jittable_qutrit_mixed.py | 311 +++++++ .../qtcorgi_helper/apply_operations.py | 8 +- .../qtcorgi_helper/qtcorgi_simulator.py | 150 ++-- 5 files changed, 1281 insertions(+), 82 deletions(-) create mode 100644 my_testing.py create mode 100644 pennylane/devices/jittable_mixed.py create mode 100644 pennylane/devices/jittable_qutrit_mixed.py diff --git a/my_testing.py b/my_testing.py new file mode 100644 index 00000000000..611cf1e626e --- /dev/null +++ b/my_testing.py @@ -0,0 +1,97 @@ +import numpy as np +import pennylane as qml +# +# U = np.array([[1/np.sqrt(2), 1/np.sqrt(3), 1/np.sqrt(6)], [np.sqrt(2/29), np.sqrt(3/29), -2 * np.sqrt(6/29)], [-5/np.sqrt(58), 7/np.sqrt(87), 1/np.sqrt(174)]]) +# +# +# print(np.linalg.norm(U[1])) + +# inv_sqrt_3 = 1 / np.sqrt(3) +# gellMann_8_coeffs = np.array([inv_sqrt_3, -1, inv_sqrt_3, 1, inv_sqrt_3, -1]) / 2 +# gellMann_8_obs = [qml.GellMann(0, i) for i in [1, 2, 4, 5, 6, 7]] + +# obs = qml.expval(qml.Hamiltonian(gellMann_8_coeffs, gellMann_8_obs)) +# +# obs.matrix() +# +# ham = qml.Hamiltonian() + +# from scipy.stats import unitary_group +# X = qml.PauliX.compute_matrix() +# Y = qml.PauliX.compute_matrix() +# Z = qml.PauliX.compute_matrix() +# I = np.eye(2) + + + +#gellMann_8_obs = [qml.Hermitian(sum([np.random.rand()*m for m in [X, Y, Z, I]]), wires=0)] +# for obs in gellMann_8_obs: +# diagm = obs.diagonalizing_gates()[0].matrix() +# diagm_adj = np.conj(diagm).T +# obsm = obs.matrix() +# +# print(np.round(diagm@obsm@diagm_adj, 5), "\n") +# print(np.round(diagm_adj @ diagm, 5), "\n===============================================================\n") +had = qml.THadamard.compute_matrix() +ahad = np.conj(had).T +G8 = qml.GellMann.compute_matrix(8) +G3 = qml.GellMann.compute_matrix(3) + + +# print(np.round(had@G8@ahad, 5)) +# +# print() +# print(np.round(had@G3@ahad, 5)) + +inv_sqrt_3 = 1 / np.sqrt(3) +inv_sqrt_3_i = inv_sqrt_3 * 1j +# +gellMann_3_equivalent = ( + np.array( + [[0, 1+inv_sqrt_3_i, 1-inv_sqrt_3_i], + [1-inv_sqrt_3_i, 0, 1 + inv_sqrt_3_i], + [1+inv_sqrt_3_i, 1-inv_sqrt_3_i, 0]] + ) + / 2 + ) +gellMann_8_equivalent = ( + np.array( + [[0, (inv_sqrt_3 - 1j), (inv_sqrt_3 + 1j)], + [inv_sqrt_3 + 1j, 0, inv_sqrt_3 - 1j], + [inv_sqrt_3 - 1j, inv_sqrt_3 + 1j, 0]] + ) + / 2 + ) + +dg = qml.THermitian(gellMann_8_equivalent, 0).diagonalizing_gates()[0].matrix() +# dga = np.conj(dg).T +# print(np.round(dg@gellMann_8_equivalent@dga, 5)) +# +# print(np.abs((dg@had@(np.array([1/2,1/3,1/6])**(1/2))))**2) +#print(qml.GellMann(0, 1).diagonalizing_gates()[0].matrix()) + +# print(np.round(dg@had, 4)) +obs = np.diag([1, 2, 3]) +print(np.round(had@obs@ahad, 4)) + +obs = np.diag([-2, -1, 1]) + +non_commuting_obs_two = np.array( + [ + [-2/3, -2/3 + inv_sqrt_3_i, -2/3 - inv_sqrt_3_i], + [-2/3 - inv_sqrt_3_i, -2/3, -2/3 + inv_sqrt_3_i], + [-2/3 + inv_sqrt_3_i, -2/3 - inv_sqrt_3_i, -2/3], + ] + ) + +print(np.round(had@obs@ahad, 4)) + +print(np.allclose(had@obs@ahad, non_commuting_obs_two)) +# print(np.allclose(had@G8@ahad, gellMann_8_equivalent)) +# #print(had@G8@ahad) +# print(np.allclose(ahad, dg)) +import jax +print(jax.numpy.array([jax.numpy.nan, 1., 2.])) + + + diff --git a/pennylane/devices/jittable_mixed.py b/pennylane/devices/jittable_mixed.py new file mode 100644 index 00000000000..0e2d0fbfcf0 --- /dev/null +++ b/pennylane/devices/jittable_mixed.py @@ -0,0 +1,797 @@ +# Copyright 2018-2021 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r""" +The default.mixed device is PennyLane's standard qubit simulator for mixed-state computations. + +It implements the necessary :class:`~pennylane.Device` methods as well as some built-in +qubit :doc:`operations `, providing a simple mixed-state simulation of +qubit-based quantum circuits. +""" + +import functools +import itertools +import logging +from collections import defaultdict +from string import ascii_letters as ABC + +import numpy as np + +import pennylane as qml +import pennylane.math as qnp +from pennylane import BasisState, DeviceError, QubitDensityMatrix, QubitDevice, Snapshot, StatePrep +from pennylane.logging import debug_logger, debug_logger_init +from pennylane.measurements import ( + CountsMP, + DensityMatrixMP, + ExpectationMP, + MutualInfoMP, + ProbabilityMP, + PurityMP, + SampleMP, + StateMP, + VarianceMP, + VnEntropyMP, +) +from pennylane.operation import Channel +from pennylane.ops.qubit.attributes import diagonal_in_z_basis +from pennylane.wires import Wires + +from .._version import __version__ + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + +ABC_ARRAY = np.array(list(ABC)) +tolerance = 1e-10 + + +class DefaultMixed(QubitDevice): + """Default qubit device for performing mixed-state computations in PennyLane. + + .. warning:: + + The API of ``DefaultMixed`` will be updated soon to follow a new device interface described + in :class:`pennylane.devices.Device`. + + This change will not alter device behaviour for most workflows, but may have implications for + plugin developers and users who directly interact with device methods. Please consult + :class:`pennylane.devices.Device` and the implementation in + :class:`pennylane.devices.DefaultQubit` for more information on what the new + interface will look like and be prepared to make updates in a coming release. If you have any + feedback on these changes, please create an + `issue `_ or post in our + `discussion forum `_. + + Args: + wires (int, Iterable[Number, str]): Number of subsystems represented by the device, + or iterable that contains unique labels for the subsystems as numbers + (i.e., ``[-1, 0, 2]``) or strings (``['ancilla', 'q1', 'q2']``). + shots (None, int): Number of times the circuit should be evaluated (or sampled) to estimate + the expectation values. Defaults to ``None`` if not specified, which means that + outputs are computed exactly. + readout_prob (None, int, float): Probability for adding readout error to the measurement + outcomes of observables. Defaults to ``None`` if not specified, which means that the outcomes are + without any readout error. + """ + + name = "Default mixed-state qubit PennyLane plugin" + short_name = "default.mixed" + pennylane_requires = __version__ + version = __version__ + author = "Xanadu Inc." + + operations = { + "Identity", + "Snapshot", + "BasisState", + "QubitStateVector", + "StatePrep", + "QubitDensityMatrix", + "QubitUnitary", + "ControlledQubitUnitary", + "BlockEncode", + "MultiControlledX", + "DiagonalQubitUnitary", + "SpecialUnitary", + "PauliX", + "PauliY", + "PauliZ", + "MultiRZ", + "Hadamard", + "S", + "T", + "SX", + "CNOT", + "SWAP", + "ISWAP", + "CSWAP", + "Toffoli", + "CCZ", + "CY", + "CZ", + "CH", + "PhaseShift", + "PCPhase", + "ControlledPhaseShift", + "CPhaseShift00", + "CPhaseShift01", + "CPhaseShift10", + "RX", + "RY", + "RZ", + "Rot", + "CRX", + "CRY", + "CRZ", + "CRot", + "AmplitudeDamping", + "GeneralizedAmplitudeDamping", + "PhaseDamping", + "DepolarizingChannel", + "BitFlip", + "PhaseFlip", + "PauliError", + "ResetError", + "QubitChannel", + "SingleExcitation", + "SingleExcitationPlus", + "SingleExcitationMinus", + "DoubleExcitation", + "DoubleExcitationPlus", + "DoubleExcitationMinus", + "QubitCarry", + "QubitSum", + "OrbitalRotation", + "FermionicSWAP", + "QFT", + "ThermalRelaxationError", + "ECR", + "ParametrizedEvolution", + "GlobalPhase", + } + + _reshape = staticmethod(qnp.reshape) + _flatten = staticmethod(qnp.flatten) + _transpose = staticmethod(qnp.transpose) + # Allow for the `axis` keyword argument for integration with broadcasting-enabling + # code in QubitDevice. However, it is not used as DefaultMixed does not support broadcasting + # pylint: disable=unnecessary-lambda + _gather = staticmethod(lambda *args, axis=0, **kwargs: qnp.gather(*args, **kwargs)) + _dot = staticmethod(qnp.dot) + + measurement_map = defaultdict(lambda: "") + measurement_map[PurityMP] = "purity" + + @staticmethod + def _reduce_sum(array, axes): + return qnp.sum(array, tuple(axes)) + + @staticmethod + def _asarray(array, dtype=None): + # Support float + if not hasattr(array, "__len__"): + return np.asarray(array, dtype=dtype) + + res = qnp.cast(array, dtype=dtype) + return res + + @debug_logger_init + def __init__( + self, + wires, + *, + r_dtype=np.float64, + c_dtype=np.complex128, + shots=None, + analytic=None, + readout_prob=None, + ): + if isinstance(wires, int) and wires > 23: + raise ValueError( + "This device does not currently support computations on more than 23 wires" + ) + + self.readout_err = readout_prob + # Check that the readout error probability, if entered, is either integer or float in [0,1] + if self.readout_err is not None: + if not isinstance(self.readout_err, float) and not isinstance(self.readout_err, int): + raise TypeError( + "The readout error probability should be an integer or a floating-point number in [0,1]." + ) + if self.readout_err < 0 or self.readout_err > 1: + raise ValueError("The readout error probability should be in the range [0,1].") + + # call QubitDevice init + super().__init__(wires, shots, r_dtype=r_dtype, c_dtype=c_dtype, analytic=analytic) + self._debugger = None + + # Create the initial state. + self._state = self._create_basis_state(0) + self._pre_rotated_state = self._state + self.measured_wires = [] + """List: during execution, stores the list of wires on which measurements are acted for + applying the readout error to them when readout_prob is non-zero.""" + + def _create_basis_state(self, index): + """Return the density matrix representing a computational basis state over all wires. + + Args: + index (int): integer representing the computational basis state. + + Returns: + array[complex]: complex array of shape ``[2] * (2 * num_wires)`` + representing the density matrix of the basis state. + """ + rho = qnp.zeros((2**self.num_wires, 2**self.num_wires), dtype=self.C_DTYPE) + rho[index, index] = 1 + return qnp.reshape(rho, [2] * (2 * self.num_wires)) + + @classmethod + def capabilities(cls): + capabilities = super().capabilities().copy() + capabilities.update( + returns_state=True, + passthru_devices={ + "autograd": "default.mixed", + "tf": "default.mixed", + "torch": "default.mixed", + "jax": "default.mixed", + }, + ) + return capabilities + + @property + def state(self): + """Returns the state density matrix of the circuit prior to measurement""" + dim = 2**self.num_wires + # User obtains state as a matrix + return qnp.reshape(self._pre_rotated_state, (dim, dim)) + + @debug_logger + def density_matrix(self, wires): + """Returns the reduced density matrix over the given wires. + + Args: + wires (Wires): wires of the reduced system + + Returns: + array[complex]: complex array of shape ``(2 ** len(wires), 2 ** len(wires))`` + representing the reduced density matrix of the state prior to measurement. + """ + state = getattr(self, "state", None) + wires = self.map_wires(wires) + return qml.math.reduce_dm(state, indices=wires, c_dtype=self.C_DTYPE) + + @debug_logger + def purity(self, mp, **kwargs): # pylint: disable=unused-argument + """Returns the purity of the final state""" + state = getattr(self, "state", None) + wires = self.map_wires(mp.wires) + return qml.math.purity(state, indices=wires, c_dtype=self.C_DTYPE) + + @debug_logger + def reset(self): + """Resets the device""" + super().reset() + + self._state = self._create_basis_state(0) + self._pre_rotated_state = self._state + + @debug_logger + def analytic_probability(self, wires=None): + if self._state is None: + return None + + # convert rho from tensor to matrix + rho = qnp.reshape(self._state, (2**self.num_wires, 2**self.num_wires)) + + # probs are diagonal elements + probs = self.marginal_prob(qnp.diagonal(rho), wires) + + # take the real part so probabilities are not shown as complex numbers + probs = qnp.real(probs) + return qnp.where(probs < 0, -probs, probs) + + def _get_kraus(self, operation): # pylint: disable=no-self-use + """Return the Kraus operators representing the operation. + + Args: + operation (.Operation): a PennyLane operation + + Returns: + list[array[complex]]: Returns a list of 2D matrices representing the Kraus operators. If + the operation is unitary, returns a single Kraus operator. In the case of a diagonal + unitary, returns a 1D array representing the matrix diagonal. + """ + if operation in diagonal_in_z_basis: + return operation.eigvals() + + if isinstance(operation, Channel): + return operation.kraus_matrices() + + return [operation.matrix()] + + def _apply_channel(self, kraus, wires): + r"""Apply a quantum channel specified by a list of Kraus operators to subsystems of the + quantum state. For a unitary gate, there is a single Kraus operator. + + Args: + kraus (list[array]): Kraus operators + wires (Wires): target wires + """ + channel_wires = self.map_wires(wires) + rho_dim = 2 * self.num_wires + num_ch_wires = len(channel_wires) + + # Computes K^\dagger, needed for the transformation K \rho K^\dagger + kraus_dagger = [qnp.conj(qnp.transpose(k)) for k in kraus] + + kraus = qnp.stack(kraus) + kraus_dagger = qnp.stack(kraus_dagger) + + # Shape kraus operators + kraus_shape = [len(kraus)] + [2] * num_ch_wires * 2 + kraus = qnp.cast(qnp.reshape(kraus, kraus_shape), dtype=self.C_DTYPE) + kraus_dagger = qnp.cast(qnp.reshape(kraus_dagger, kraus_shape), dtype=self.C_DTYPE) + + # Tensor indices of the state. For each qubit, need an index for rows *and* columns + state_indices = ABC[:rho_dim] + + # row indices of the quantum state affected by this operation + row_wires_list = channel_wires.tolist() + row_indices = "".join(ABC_ARRAY[row_wires_list].tolist()) + + # column indices are shifted by the number of wires + col_wires_list = [w + self.num_wires for w in row_wires_list] + col_indices = "".join(ABC_ARRAY[col_wires_list].tolist()) + + # indices in einsum must be replaced with new ones + new_row_indices = ABC[rho_dim : rho_dim + num_ch_wires] + new_col_indices = ABC[rho_dim + num_ch_wires : rho_dim + 2 * num_ch_wires] + + # index for summation over Kraus operators + kraus_index = ABC[rho_dim + 2 * num_ch_wires : rho_dim + 2 * num_ch_wires + 1] + + # new state indices replace row and column indices with new ones + new_state_indices = functools.reduce( + lambda old_string, idx_pair: old_string.replace(idx_pair[0], idx_pair[1]), + zip(col_indices + row_indices, new_col_indices + new_row_indices), + state_indices, + ) + + # index mapping for einsum, e.g., 'iga,abcdef,idh->gbchef' + einsum_indices = ( + f"{kraus_index}{new_row_indices}{row_indices}, {state_indices}," + f"{kraus_index}{col_indices}{new_col_indices}->{new_state_indices}" + ) + + self._state = qnp.einsum(einsum_indices, kraus, self._state, kraus_dagger) + + def _apply_channel_tensordot(self, kraus, wires): + r"""Apply a quantum channel specified by a list of Kraus operators to subsystems of the + quantum state. For a unitary gate, there is a single Kraus operator. + + Args: + kraus (list[array]): Kraus operators + wires (Wires): target wires + """ + channel_wires = self.map_wires(wires) + num_ch_wires = len(channel_wires) + + # Shape kraus operators and cast them to complex data type + kraus_shape = [2] * (num_ch_wires * 2) + kraus = [qnp.cast(qnp.reshape(k, kraus_shape), dtype=self.C_DTYPE) for k in kraus] + + # row indices of the quantum state affected by this operation + row_wires_list = channel_wires.tolist() + # column indices are shifted by the number of wires + col_wires_list = [w + self.num_wires for w in row_wires_list] + + channel_col_ids = list(range(num_ch_wires, 2 * num_ch_wires)) + axes_left = [channel_col_ids, row_wires_list] + # Use column indices instead or rows to incorporate transposition of K^\dagger + axes_right = [col_wires_list, channel_col_ids] + + # Apply the Kraus operators, and sum over all Kraus operators afterwards + def _conjugate_state_with(k): + """Perform the double tensor product k @ self._state @ k.conj(). + The `axes_left` and `axes_right` arguments are taken from the ambient variable space + and `axes_right` is assumed to incorporate the tensor product and the transposition + of k.conj() simultaneously.""" + return qnp.tensordot(qnp.tensordot(k, self._state, axes_left), qnp.conj(k), axes_right) + + if len(kraus) == 1: + _state = _conjugate_state_with(kraus[0]) + else: + _state = qnp.sum(qnp.stack([_conjugate_state_with(k) for k in kraus]), axis=0) + + # Permute the affected axes to their destination places. + # The row indices of the kraus operators are moved from the beginning to the original + # target row locations, the column indices from the end to the target column locations + source_left = list(range(num_ch_wires)) + dest_left = row_wires_list + source_right = list(range(-num_ch_wires, 0)) + dest_right = col_wires_list + self._state = qnp.moveaxis(_state, source_left + source_right, dest_left + dest_right) + + def _apply_diagonal_unitary(self, eigvals, wires): + r"""Apply a diagonal unitary gate specified by a list of eigenvalues. This method uses + the fact that the unitary is diagonal for a more efficient implementation. + + Args: + eigvals (array): eigenvalues (phases) of the diagonal unitary + wires (Wires): target wires + """ + + channel_wires = self.map_wires(wires) + + eigvals = qnp.stack(eigvals) + + # reshape vectors + eigvals = qnp.cast(qnp.reshape(eigvals, [2] * len(channel_wires)), dtype=self.C_DTYPE) + + # Tensor indices of the state. For each qubit, need an index for rows *and* columns + state_indices = ABC[: 2 * self.num_wires] + + # row indices of the quantum state affected by this operation + row_wires_list = channel_wires.tolist() + row_indices = "".join(ABC_ARRAY[row_wires_list].tolist()) + + # column indices are shifted by the number of wires + col_wires_list = [w + self.num_wires for w in row_wires_list] + col_indices = "".join(ABC_ARRAY[col_wires_list].tolist()) + + einsum_indices = f"{row_indices},{state_indices},{col_indices}->{state_indices}" + + self._state = qnp.einsum(einsum_indices, eigvals, self._state, qnp.conj(eigvals)) + + def _apply_basis_state(self, state, wires): + """Initialize the device in a specified computational basis state. + + Args: + state (array[int]): computational basis state of shape ``(wires,)`` + consisting of 0s and 1s. + wires (Wires): wires that the provided computational state should be initialized on + """ + # translate to wire labels used by device + device_wires = self.map_wires(wires) + + # length of basis state parameter + n_basis_state = len(state) + + if not set(state).issubset({0, 1}): + raise ValueError("BasisState parameter must consist of 0 or 1 integers.") + + if n_basis_state != len(device_wires): + raise ValueError("BasisState parameter and wires must be of equal length.") + + # get computational basis state number + basis_states = 2 ** (self.num_wires - 1 - device_wires.toarray()) + num = int(qnp.dot(state, basis_states)) + + self._state = self._create_basis_state(num) + + def _apply_state_vector(self, state, device_wires): + """Initialize the internal state in a specified pure state. + + Args: + state (array[complex]): normalized input state of length + ``2**len(wires)`` + device_wires (Wires): wires that get initialized in the state + """ + + # translate to wire labels used by device + device_wires = self.map_wires(device_wires) + + state = qnp.asarray(state, dtype=self.C_DTYPE) + n_state_vector = state.shape[0] + + if state.ndim != 1 or n_state_vector != 2 ** len(device_wires): + raise ValueError("State vector must be of length 2**wires.") + + if not qnp.allclose(qnp.linalg.norm(state, ord=2), 1.0, atol=tolerance): + raise ValueError("Sum of amplitudes-squared does not equal one.") + + if len(device_wires) == self.num_wires and sorted(device_wires.labels) == list( + device_wires.labels + ): + # Initialize the entire wires with the state + rho = qnp.outer(state, qnp.conj(state)) + self._state = qnp.reshape(rho, [2] * 2 * self.num_wires) + + else: + # generate basis states on subset of qubits via the cartesian product + basis_states = qnp.asarray( + list(itertools.product([0, 1], repeat=len(device_wires))), dtype=int + ) + + # get basis states to alter on full set of qubits + unravelled_indices = qnp.zeros((2 ** len(device_wires), self.num_wires), dtype=int) + unravelled_indices[:, device_wires] = basis_states + + # get indices for which the state is changed to input state vector elements + ravelled_indices = qnp.ravel_multi_index(unravelled_indices.T, [2] * self.num_wires) + + state = qnp.scatter(ravelled_indices, state, [2**self.num_wires]) + rho = qnp.outer(state, qnp.conj(state)) + rho = qnp.reshape(rho, [2] * 2 * self.num_wires) + self._state = qnp.asarray(rho, dtype=self.C_DTYPE) + + def _apply_density_matrix(self, state, device_wires): + r"""Initialize the internal state in a specified mixed state. + If not all the wires are specified in the full state :math:`\rho`, remaining subsystem is filled by + `\mathrm{tr}_in(\rho)`, which results in the full system state :math:`\mathrm{tr}_{in}(\rho) \otimes \rho_{in}`, + where :math:`\rho_{in}` is the argument `state` of this function and :math:`\mathrm{tr}_{in}` is a partial + trace over the subsystem to be replaced by this operation. + + Args: + state (array[complex]): density matrix of length + ``(2**len(wires), 2**len(wires))`` + device_wires (Wires): wires that get initialized in the state + """ + + # translate to wire labels used by device + device_wires = self.map_wires(device_wires) + + state = qnp.asarray(state, dtype=self.C_DTYPE) + state = qnp.reshape(state, (-1,)) + + state_dim = 2 ** len(device_wires) + dm_dim = state_dim**2 + if dm_dim != state.shape[0]: + raise ValueError("Density matrix must be of length (2**wires, 2**wires)") + + if not qml.math.is_abstract(state) and not qnp.allclose( + qnp.trace(qnp.reshape(state, (state_dim, state_dim))), 1.0, atol=tolerance + ): + raise ValueError("Trace of density matrix is not equal one.") + + if len(device_wires) == self.num_wires and sorted(device_wires.labels) == list( + device_wires.labels + ): + # Initialize the entire wires with the state + + self._state = qnp.reshape(state, [2] * 2 * self.num_wires) + self._pre_rotated_state = self._state + + else: + # Initialize tr_in(ρ) ⊗ ρ_in with transposed wires where ρ is the density matrix before this operation. + + complement_wires = list(sorted(list(set(range(self.num_wires)) - set(device_wires)))) + sigma = self.density_matrix(Wires(complement_wires)) + rho = qnp.kron(sigma, state.reshape(state_dim, state_dim)) + rho = rho.reshape([2] * 2 * self.num_wires) + + # Construct transposition axis to revert back to the original wire order + left_axes = [] + right_axes = [] + complement_wires_count = len(complement_wires) + for i in range(self.num_wires): + if i in device_wires: + index = device_wires.index(i) + left_axes.append(complement_wires_count + index) + right_axes.append(complement_wires_count + index + self.num_wires) + elif i in complement_wires: + index = complement_wires.index(i) + left_axes.append(index) + right_axes.append(index + self.num_wires) + transpose_axes = left_axes + right_axes + rho = qnp.transpose(rho, axes=transpose_axes) + assert qml.math.is_abstract(rho) or qnp.allclose( + qnp.trace(qnp.reshape(rho, (2**self.num_wires, 2**self.num_wires))), + 1.0, + atol=tolerance, + ) + + self._state = qnp.asarray(rho, dtype=self.C_DTYPE) + self._pre_rotated_state = self._state + + def _snapshot_measurements(self, density_matrix, measurement): + """Perform state-based snapshot measurement""" + meas_wires = self.wires if not measurement.wires else measurement.wires + + pre_rotated_state = self._state + if isinstance(measurement, (ProbabilityMP, ExpectationMP, VarianceMP)): + for diag_gate in measurement.diagonalizing_gates(): + self._apply_operation(diag_gate) + + if isinstance(measurement, (StateMP, DensityMatrixMP)): + map_wires = self.map_wires(meas_wires) + snap_result = qml.math.reduce_dm( + density_matrix, indices=map_wires, c_dtype=self.C_DTYPE + ) + + elif isinstance(measurement, PurityMP): + map_wires = self.map_wires(meas_wires) + snap_result = qml.math.purity(density_matrix, indices=map_wires, c_dtype=self.C_DTYPE) + + elif isinstance(measurement, ProbabilityMP): + snap_result = self.analytic_probability(wires=meas_wires) + + elif isinstance(measurement, ExpectationMP): + eigvals = self._asarray(measurement.obs.eigvals(), dtype=self.R_DTYPE) + probs = self.analytic_probability(wires=meas_wires) + snap_result = self._dot(probs, eigvals) + + elif isinstance(measurement, VarianceMP): + eigvals = self._asarray(measurement.obs.eigvals(), dtype=self.R_DTYPE) + probs = self.analytic_probability(wires=meas_wires) + snap_result = self._dot(probs, (eigvals**2)) - self._dot(probs, eigvals) ** 2 + + elif isinstance(measurement, VnEntropyMP): + base = measurement.log_base + map_wires = self.map_wires(meas_wires) + snap_result = qml.math.vn_entropy( + density_matrix, indices=map_wires, c_dtype=self.C_DTYPE, base=base + ) + + elif isinstance(measurement, MutualInfoMP): + base = measurement.log_base + wires0, wires1 = list(map(self.map_wires, measurement.raw_wires)) + snap_result = qml.math.mutual_info( + density_matrix, + indices0=wires0, + indices1=wires1, + c_dtype=self.C_DTYPE, + base=base, + ) + + else: + raise DeviceError( + f"Snapshots of {type(measurement)} are not yet supported on default.mixed" + ) + + self._state = pre_rotated_state + self._pre_rotated_state = self._state + + return snap_result + + def _apply_snapshot(self, operation): + """Applies the snapshot operation""" + measurement = operation.hyperparameters["measurement"] + + if self._debugger and self._debugger.active: + dim = 2**self.num_wires + density_matrix = qnp.reshape(self._state, (dim, dim)) + + snapshot_result = self._snapshot_measurements(density_matrix, measurement) + + if operation.tag: + self._debugger.snapshots[operation.tag] = snapshot_result + else: + self._debugger.snapshots[len(self._debugger.snapshots)] = snapshot_result + + def _apply_operation(self, operation): + """Applies operations to the internal device state. + + Args: + operation (.Operation): operation to apply on the device + """ + wires = operation.wires + if operation.name == "Identity": + return + + if isinstance(operation, StatePrep): + self._apply_state_vector(operation.parameters[0], wires) + return + + if isinstance(operation, BasisState): + self._apply_basis_state(operation.parameters[0], wires) + return + + if isinstance(operation, QubitDensityMatrix): + self._apply_density_matrix(operation.parameters[0], wires) + return + + if isinstance(operation, Snapshot): + self._apply_snapshot(operation) + return + + matrices = self._get_kraus(operation) + + if operation in diagonal_in_z_basis: + self._apply_diagonal_unitary(matrices, wires) + else: + num_op_wires = len(wires) + interface = qml.math.get_interface(self._state, *matrices) + # Use tensordot for Autograd and Numpy if there are more than 2 wires + # Use tensordot in any case for more than 7 wires, as einsum does not support this case + if (num_op_wires > 2 and interface in {"autograd", "numpy"}) or num_op_wires > 7: + self._apply_channel_tensordot(matrices, wires) + else: + self._apply_channel(matrices, wires) + + # pylint: disable=arguments-differ + + @debug_logger + def execute(self, circuit, **kwargs): + """Execute a queue of quantum operations on the device and then + measure the given observables. + + Applies a readout error to the measurement outcomes of any observable if + readout_prob is non-zero. This is done by finding the list of measured wires on which + BitFlip channels are applied in the :meth:`apply`. + + For plugin developers: instead of overwriting this, consider + implementing a suitable subset of + + * :meth:`apply` + + * :meth:`~.generate_samples` + + * :meth:`~.probability` + + Additional keyword arguments may be passed to this method + that can be utilised by :meth:`apply`. An example would be passing + the ``QNode`` hash that can be used later for parametric compilation. + + Args: + circuit (QuantumTape): circuit to execute on the device + + Raises: + QuantumFunctionError: if the value of :attr:`~.Observable.return_type` is not supported + + Returns: + array[float]: measured value(s) + """ + if self.readout_err: + wires_list = [] + for m in circuit.measurements: + if isinstance(m, StateMP): + # State: This returns pre-rotated state, so no readout error. + # Assumed to only be allowed if it's the only measurement. + self.measured_wires = [] + return super().execute(circuit, **kwargs) + if isinstance(m, (SampleMP, CountsMP)) and m.wires in ( + qml.wires.Wires([]), + self.wires, + ): + # Sample, Counts: Readout error applied to all device wires when wires + # not specified or all wires specified. + self.measured_wires = self.wires + return super().execute(circuit, **kwargs) + if isinstance(m, (VnEntropyMP, MutualInfoMP)): + # VnEntropy, MutualInfo: Computed for the state prior to measurement. So, readout + # error need not be applied on the corresponding device wires. + continue + wires_list.append(m.wires) + self.measured_wires = qml.wires.Wires.all_wires(wires_list) + return super().execute(circuit, **kwargs) + + @debug_logger + def apply(self, operations, rotations=None, **kwargs): + rotations = rotations or [] + + # apply the circuit operations + for i, operation in enumerate(operations): + if i > 0 and isinstance(operation, (StatePrep, BasisState)): + raise DeviceError( + f"Operation {operation.name} cannot be used after other Operations have already been applied " + f"on a {self.short_name} device." + ) + + for operation in operations: + self._apply_operation(operation) + + # store the pre-rotated state + self._pre_rotated_state = self._state + + # apply the circuit rotations + for operation in rotations: + self._apply_operation(operation) + + if self.readout_err: + for k in self.measured_wires: + bit_flip = qml.BitFlip(self.readout_err, wires=k) + self._apply_operation(bit_flip) diff --git a/pennylane/devices/jittable_qutrit_mixed.py b/pennylane/devices/jittable_qutrit_mixed.py new file mode 100644 index 00000000000..0ade57c0c73 --- /dev/null +++ b/pennylane/devices/jittable_qutrit_mixed.py @@ -0,0 +1,311 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The default.qutrit.mixed device is PennyLane's standard qutrit simulator for mixed-state +computations.""" +import logging +from dataclasses import replace +from typing import Callable, Optional, Sequence, Tuple, Union + +import numpy as np + +import pennylane as qml +from pennylane.logging import debug_logger, debug_logger_init +from pennylane.ops import _qutrit__channel__ops__ as channels +from pennylane.tape import QuantumTape +from pennylane.transforms.core import TransformProgram +from pennylane.typing import Result, ResultBatch + +from . import Device +from .default_qutrit import DefaultQutrit +from .execution_config import DefaultExecutionConfig, ExecutionConfig +from .modifiers import simulator_tracking, single_tape_support +from .preprocess import ( + decompose, + no_sampling, + validate_device_wires, + validate_measurements, + validate_observables, +) +from .qutrit_mixed.simulate import simulate + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + +Result_or_ResultBatch = Union[Result, ResultBatch] +QuantumTapeBatch = Sequence[QuantumTape] +QuantumTape_or_Batch = Union[QuantumTape, QuantumTapeBatch] + +# always a function from a resultbatch to either a result or a result batch +PostprocessingFn = Callable[[ResultBatch], Result_or_ResultBatch] + +observables = { + "THermitian", + "GellMann", +} + + +def observable_stopping_condition(obs: qml.operation.Operator) -> bool: + """Specifies whether an observable is accepted by DefaultQutritMixed.""" + if isinstance(obs, qml.operation.Tensor): + return all(observable_stopping_condition(observable) for observable in obs.obs) + if obs.name in {"Prod", "Sum"}: + return all(observable_stopping_condition(observable) for observable in obs.operands) + if obs.name in {"LinearCombination", "Hamiltonian"}: + return all(observable_stopping_condition(observable) for observable in obs.terms()[1]) + if obs.name == "SProd": + return observable_stopping_condition(obs.base) + + return obs.name in observables + + +def stopping_condition(op: qml.operation.Operator) -> bool: + """Specify whether an Operator object is supported by the device.""" + expected_set = DefaultQutrit.operations | {"Snapshot"} | channels + return op.name in expected_set + + +def stopping_condition_shots(op: qml.operation.Operator) -> bool: + """Specify whether an Operator object is supported by the device with shots.""" + return stopping_condition(op) + + +def accepted_sample_measurement(m: qml.measurements.MeasurementProcess) -> bool: + """Specifies whether a measurement is accepted when sampling.""" + return isinstance(m, qml.measurements.SampleMeasurement) + + +@simulator_tracking +@single_tape_support +class DefaultQutritMixed(Device): + """A PennyLane Python-based device for mixed-state qutrit simulation. + + Args: + wires (int, Iterable[Number, str]): Number of wires present on the device, or iterable that + contains unique labels for the wires as numbers (i.e., ``[-1, 0, 2]``) or strings + (``['ancilla', 'q1', 'q2']``). Default ``None`` if not specified. + shots (int, Sequence[int], Sequence[Union[int, Sequence[int]]]): The default number of shots + to use in executions involving this device. + seed (Union[str, None, int, array_like[int], SeedSequence, BitGenerator, Generator, jax.random.PRNGKey]): A + seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``, or + a request to seed from numpy's global random number generator. + The default, ``seed="global"`` pulls a seed from NumPy's global generator. ``seed=None`` + will pull a seed from the OS entropy. + If a ``jax.random.PRNGKey`` is passed as the seed, a JAX-specific sampling function using + ``jax.random.choice`` and the ``PRNGKey`` will be used for sampling rather than + ``numpy.random.default_rng``. + + **Example:** + + .. code-block:: python + + n_wires = 5 + num_qscripts = 5 + qscripts = [] + for i in range(num_qscripts): + unitary = scipy.stats.unitary_group(dim=3**n_wires, seed=(42 + i)).rvs() + op = qml.QutritUnitary(unitary, wires=range(n_wires)) + qs = qml.tape.QuantumScript([op], [qml.expval(qml.GellMann(0, 3))]) + qscripts.append(qs) + + >>> dev = DefaultQutritMixed() + >>> program, execution_config = dev.preprocess() + >>> new_batch, post_processing_fn = program(qscripts) + >>> results = dev.execute(new_batch, execution_config=execution_config) + >>> post_processing_fn(results) + [0.08015701503959313, + 0.04521414211599359, + -0.0215232130089687, + 0.062120285032425865, + -0.0635052317625] + + This device currently supports backpropagation derivatives: + + >>> from pennylane.devices import ExecutionConfig + >>> dev.supports_derivatives(ExecutionConfig(gradient_method="backprop")) + True + + For example, we can use jax to jit computing the derivative: + + .. code-block:: python + + import jax + + @jax.jit + def f(x): + qs = qml.tape.QuantumScript([qml.TRX(x, 0)], [qml.expval(qml.GellMann(0, 3))]) + program, execution_config = dev.preprocess() + new_batch, post_processing_fn = program([qs]) + results = dev.execute(new_batch, execution_config=execution_config) + return post_processing_fn(results)[0] + + >>> f(jax.numpy.array(1.2)) + DeviceArray(0.36235774, dtype=float32) + >>> jax.grad(f)(jax.numpy.array(1.2)) + DeviceArray(-0.93203914, dtype=float32, weak_type=True) + + .. details:: + :title: Tracking + + ``DefaultQutritMixed`` tracks: + + * ``executions``: the number of unique circuits that would be required on quantum hardware + * ``shots``: the number of shots + * ``resources``: the :class:`~.resource.Resources` for the executed circuit. + * ``simulations``: the number of simulations performed. One simulation can cover multiple QPU executions, such as for non-commuting measurements and batched parameters. + * ``batches``: The number of times :meth:`~.execute` is called. + * ``results``: The results of each call of :meth:`~.execute` + + + """ + + _device_options = ("rng", "prng_key") # tuple of string names for all the device options. + + @property + def name(self): + """The name of the device.""" + return "default.qutrit.mixed" + + @debug_logger_init + def __init__( + self, + wires=None, + shots=None, + seed="global", + ) -> None: + super().__init__(wires=wires, shots=shots) + seed = np.random.randint(0, high=10000000) if seed == "global" else seed + if qml.math.get_interface(seed) == "jax": + self._prng_key = seed + self._rng = np.random.default_rng(None) + else: + self._prng_key = None + self._rng = np.random.default_rng(seed) + self._debugger = None + + @debug_logger + def supports_derivatives( + self, + execution_config: Optional[ExecutionConfig] = None, + circuit: Optional[QuantumTape] = None, + ) -> bool: + """Check whether or not derivatives are available for a given configuration and circuit. + + ``DefaultQutritMixed`` supports backpropagation derivatives with analytic results. + + Args: + execution_config (ExecutionConfig): The configuration of the desired derivative calculation. + circuit (QuantumTape): An optional circuit to check derivatives support for. + + Returns: + bool: Whether or not a derivative can be calculated provided the given information. + + """ + if execution_config is None or execution_config.gradient_method in {"backprop", "best"}: + return circuit is None or not circuit.shots + return False + + def _setup_execution_config(self, execution_config: ExecutionConfig) -> ExecutionConfig: + """This is a private helper for ``preprocess`` that sets up the execution config. + + Args: + execution_config (ExecutionConfig): an unprocessed execution config. + + Returns: + ExecutionConfig: a preprocessed execution config. + """ + updated_values = {} + for option in execution_config.device_options: + if option not in self._device_options: + raise qml.DeviceError(f"device option {option} not present on {self}") + + if execution_config.gradient_method == "best": + updated_values["gradient_method"] = "backprop" + updated_values["use_device_gradient"] = False + updated_values["grad_on_execution"] = False + updated_values["device_options"] = dict(execution_config.device_options) # copy + + for option in self._device_options: + if option not in updated_values["device_options"]: + updated_values["device_options"][option] = getattr(self, f"_{option}") + return replace(execution_config, **updated_values) + + @debug_logger + def preprocess( + self, + execution_config: ExecutionConfig = DefaultExecutionConfig, + ) -> Tuple[TransformProgram, ExecutionConfig]: + """This function defines the device transform program to be applied and an updated device + configuration. + + Args: + execution_config (Union[ExecutionConfig, Sequence[ExecutionConfig]]): A data structure + describing the parameters needed to fully describe the execution. + + Returns: + TransformProgram, ExecutionConfig: A transform program that when called returns + ``QuantumTape`` objects that the device can natively execute, as well as a postprocessing + function to be called after execution, and a configuration with unset + specifications filled in. + + This device: + + * Supports any qutrit operations that provide a matrix + * Supports any qutrit channel that provides Kraus matrices + + """ + config = self._setup_execution_config(execution_config) + transform_program = TransformProgram() + + transform_program.add_transform(validate_device_wires, self.wires, name=self.name) + transform_program.add_transform( + decompose, + stopping_condition=stopping_condition, + stopping_condition_shots=stopping_condition_shots, + name=self.name, + ) + transform_program.add_transform( + validate_measurements, sample_measurements=accepted_sample_measurement, name=self.name + ) + transform_program.add_transform( + validate_observables, stopping_condition=observable_stopping_condition, name=self.name + ) + + if config.gradient_method == "backprop": + transform_program.add_transform(no_sampling, name="backprop + default.qutrit") + + return transform_program, config + + @debug_logger + def execute( + self, + circuits: QuantumTape_or_Batch, + execution_config: ExecutionConfig = DefaultExecutionConfig, + ) -> Result_or_ResultBatch: + + interface = ( + execution_config.interface + if execution_config.gradient_method in {"best", "backprop", None} + else None + ) + + return tuple( + simulate( + c, + rng=self._rng, + prng_key=self._prng_key, + debugger=self._debugger, + interface=interface, + ) + for c in circuits + ) diff --git a/pennylane/devices/qtcorgi_helper/apply_operations.py b/pennylane/devices/qtcorgi_helper/apply_operations.py index c115e4ba54a..faad66d384a 100644 --- a/pennylane/devices/qtcorgi_helper/apply_operations.py +++ b/pennylane/devices/qtcorgi_helper/apply_operations.py @@ -103,8 +103,6 @@ def apply_operation_einsum(kraus, wires, state): return jnp.einsum(einsum_indices, kraus, state, kraus_dagger) -def f(carry, x): - if x[0] == 0: - qml.TRX().matrix - -flag, wires, inputs \ No newline at end of file +def apply_operation(state, op_info, qudit_dim): + op = jax.lax.switch() + return state, None \ No newline at end of file diff --git a/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py b/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py index 237be5412bc..d26f8ce5ab5 100644 --- a/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py +++ b/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py @@ -1,58 +1,30 @@ +import functools +import jax +import jax.numpy as jnp +import pennylane as qml +from pennylane.tape import QuantumScript +from .apply_operations import apply_operation +from ..qutrit_mixed.simulate import measure_final_state +from ..qutrit_mixed.initialize_state import create_initial_state -def simulate( - circuit: qml.tape.QuantumScript, rng=None, prng_key=None, debugger=None, interface=None -) -> Result: - """TODO - - Args: - circuit (QuantumTape): The single circuit to simulate - rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]): A - seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``. - If no value is provided, a default RNG will be used. - prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is - the key to the JAX pseudo random number generator. If None, a random key will be - generated. Only for simulation using JAX. - debugger (_Debugger): The debugger to use - interface (str): The machine learning interface to create the initial state with - - Returns: - tuple(TensorLike): The results of the simulation - - Note that this function can return measurements for non-commuting observables simultaneously. - - This function assumes that all operations provide matrices. - - >>> qs = qml.tape.QuantumScript([qml.TRX(1.2, wires=0)], [qml.expval(qml.GellMann(0, 3)), qml.probs(wires=(0,1))]) - >>> simulate(qs) - (0.36235775447667357, - tensor([0.68117888, 0. , 0. , 0.31882112, 0. , 0. ], requires_grad=True)) +op_types = [] - """ - state, is_state_batched = get_final_state( - circuit, debugger=debugger, interface=interface, rng=rng, prng_key=prng_key - ) - return measure_final_state(circuit, state, is_state_batched, rng=rng, prng_key=prng_key) -def get_final_state(circuit, debugger=None, interface=None, **kwargs): +def get_final_state_from_initial(operations, initial_state, qudit_dim): """ TODO Args: circuit (.QuantumScript): The single circuit to simulate - debugger (._Debugger): The debugger to use - interface (str): The machine learning interface to create the initial state with Returns: Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and whether the state has a batch dimension. """ - circuit = circuit.map_to_standard_wires() - state = create_initial_state(sorted(circuit.op_wires), like="jax") - - op_array = [] - for op in circuit.operations: + ops_type_index, ops_wires, ops_params = [], [[], []], [[], [], []] + for op in operations: # op_index = None # for i, op_type in enumerate(op_types): @@ -61,26 +33,42 @@ def get_final_state(circuit, debugger=None, interface=None, **kwargs): # if op_index is None: # raise ValueError("This simulator only supports") - op_index = op_types.index(type(op)) + # op_index = op_types.index(type(op)) + # wires = op.wires() + # if len(wires) == 1: + # wires = [-1, wires[0]] + # if len(wires) == 2: + # wires = list(wires) + # params = op.parameters + ([0] * (3-op.num_params)) + # op_array.append([[op_index] + wires, params]) + ops_type_index.append(op_types.index(type(op))) wires = op.wires() if len(wires) == 1: wires = [-1, wires[0]] - if len(wires) == 2: - wires = list(wires) - params = op.parameters + ([0] * (3-op.num_params)) - op_array.append([[op_index] + wires, params]) + ops_wires[0].append(wires[0]) + ops_wires[1].append(wires[1]) + + params = op.parameters + ([0] * (3 - op.num_params)) + ops_params[0].append(params[0]) + ops_params[1].append(params[1]) + ops_params[2].append(params[2]) - return state + ops_info = { + "type_index": jnp.array(ops_type_index), + "wires": [jnp.array(ops_wires[0]), jnp.array(ops_wires[1])], + "params": [jnp.array(ops_params[0]), jnp.array(ops_params[1]), jnp.array(ops_params[2])] + } + return jax.lax.scan(functools.partial(apply_operation, qudit_dim), initial_state, ops_info)[0] -def get_final_state(circuit, debugger=None, interface=None, **kwargs): +def get_final_state(circuit): """ TODO Args: circuit (.QuantumScript): The single circuit to simulate - debugger (._Debugger): The debugger to use + qudit_dim (): TODO interface (str): The machine learning interface to create the initial state with Returns: @@ -88,38 +76,46 @@ def get_final_state(circuit, debugger=None, interface=None, **kwargs): whether the state has a batch dimension. """ + circuit = circuit.map_to_standard_wires() prep = None if len(circuit) > 0 and isinstance(circuit[0], qml.operation.StatePrepBase): prep = circuit[0] - state = create_initial_state(sorted(circuit.op_wires), prep, like=INTERFACE_TO_LIKE[interface]) - - # initial state is batched only if the state preparation (if it exists) is batched - is_state_batched = bool(prep and prep.batch_size is not None) - for op in circuit.operations[bool(prep) :]: - state = apply_operation( - op, - state, - is_state_batched=is_state_batched, - debugger=debugger, - tape_shots=circuit.shots, - **kwargs, - ) - - # new state is batched if i) the old state is batched, or ii) the new op adds a batch dim - is_state_batched = is_state_batched or op.batch_size is not None - - num_operated_wires = len(circuit.op_wires) - for i in range(len(circuit.wires) - num_operated_wires): - # If any measured wires are not operated on, we pad the density matrix with zeros. - # We know they belong at the end because the circuit is in standard wire-order - # Since it is a dm, we must pad it with 0s on the last row and last column - current_axis = num_operated_wires + i + is_state_batched - state = qml.math.stack( - ([state] + [qml.math.zeros_like(state)] * (QUDIT_DIM - 1)), axis=current_axis - ) - state = qml.math.stack(([state] + [qml.math.zeros_like(state)] * (QUDIT_DIM - 1)), axis=-1) - - return state, is_state_batched \ No newline at end of file + state = create_initial_state(sorted(circuit.op_wires), prep, like="jax") + get_final_state_from_initial(circuit.operations[bool(prep):], state, 3) + + + + + +def simulate(circuit: QuantumScript, rng=None, prng_key=None): + """TODO + + Args: + circuit (QuantumTape): The single circuit to simulate + rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]): A + seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``. + If no value is provided, a default RNG will be used. + prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is + the key to the JAX pseudo random number generator. If None, a random key will be + generated. Only for simulation using JAX. + debugger (_Debugger): The debugger to use + interface (str): The machine learning interface to create the initial state with + + Returns: + tuple(TensorLike): The results of the simulation + + Note that this function can return measurements for non-commuting observables simultaneously. + + This function assumes that all operations provide matrices. + + >>> qs = qml.tape.QuantumScript([qml.TRX(1.2, wires=0)], [qml.expval(qml.GellMann(0, 3)), qml.probs(wires=(0,1))]) + >>> simulate(qs) + (0.36235775447667357, + tensor([0.68117888, 0. , 0. , 0.31882112, 0. , 0. ], requires_grad=True)) + + """ + state, is_state_batched = get_final_state(circuit) + return measure_final_state(circuit, state, is_state_batched, rng=rng, prng_key=prng_key) From 91f2cf3d12fbb7db225770998946e7bfe32c872c Mon Sep 17 00:00:00 2001 From: Gabriel Bottrill Date: Tue, 2 Jul 2024 00:49:08 -0700 Subject: [PATCH 03/26] Working towards new devices --- .../qtcorgi_helper/apply_operations.py | 21 ++++++++++++++++--- .../qtcorgi_helper/qtcorgi_simulator.py | 2 +- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/pennylane/devices/qtcorgi_helper/apply_operations.py b/pennylane/devices/qtcorgi_helper/apply_operations.py index faad66d384a..0e5e70f0920 100644 --- a/pennylane/devices/qtcorgi_helper/apply_operations.py +++ b/pennylane/devices/qtcorgi_helper/apply_operations.py @@ -103,6 +103,21 @@ def apply_operation_einsum(kraus, wires, state): return jnp.einsum(einsum_indices, kraus, state, kraus_dagger) -def apply_operation(state, op_info, qudit_dim): - op = jax.lax.switch() - return state, None \ No newline at end of file +def apply_single_qudit_unitary(): + pass + + +def apply_two_qudit_unitary(): + pass + + +def apply_single_qudit_channel(): + pass + + +def apply_operation(state, qudit_dim, op_info): + # TODO may have to rewrite to return different functions for qubits and qutrits + op_i = op_info["type_index"] + op_class = op_i // first_index + op_i // second_index + op_i // third_index + state = jax.lax.switch(op_class, [], qudit_dim, op_info) + return state, None diff --git a/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py b/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py index d26f8ce5ab5..c1cec88b5fb 100644 --- a/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py +++ b/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py @@ -59,7 +59,7 @@ def get_final_state_from_initial(operations, initial_state, qudit_dim): "wires": [jnp.array(ops_wires[0]), jnp.array(ops_wires[1])], "params": [jnp.array(ops_params[0]), jnp.array(ops_params[1]), jnp.array(ops_params[2])] } - return jax.lax.scan(functools.partial(apply_operation, qudit_dim), initial_state, ops_info)[0] + return jax.lax.scan(apply_operation, initial_state, qudit_dim, ops_info)[0] def get_final_state(circuit): From 267076137fd1fab44a07aa1ca7f4e11aade38b29 Mon Sep 17 00:00:00 2001 From: Gabriel Bottrill Date: Tue, 2 Jul 2024 15:37:20 -0700 Subject: [PATCH 04/26] Added logic to apply_operations and working on splitting qubit and qutrit simulators --- .../qtcorgi_helper/apply_operations.py | 89 +++++++++++++++++-- .../qtcorgi_helper/qtcorgi_simulator.py | 81 +++++++++++++++-- 2 files changed, 154 insertions(+), 16 deletions(-) diff --git a/pennylane/devices/qtcorgi_helper/apply_operations.py b/pennylane/devices/qtcorgi_helper/apply_operations.py index 0e5e70f0920..aedcfd35847 100644 --- a/pennylane/devices/qtcorgi_helper/apply_operations.py +++ b/pennylane/devices/qtcorgi_helper/apply_operations.py @@ -103,21 +103,92 @@ def apply_operation_einsum(kraus, wires, state): return jnp.einsum(einsum_indices, kraus, state, kraus_dagger) -def apply_single_qudit_unitary(): +def get_two_qubit_unitary_matrix(): pass -def apply_two_qudit_unitary(): +def get_CNOT_matrix(params): + return jnp.array([[1,0,0,0], + [0,1,0,0], + [0,0,0,1], + [0,0,1,0]]) + + +single_qubit_ops = [qml.RX.compute_matrix, qml.RY.compute_matrix, qml.RZ.compute_matrix] +two_qubit_ops = [get_CNOT_matrix, get_two_qubit_unitary_matrix] +single_qubit_channels = [qml.DepolarizingChannel.compute_kraus_matrices, qml.AmplitudeDamping.compute_kraus_matrices, qml.BitFlip.compute_kraus_matrices] + +def apply_single_qubit_unitary(state, op_info): + wire, param = op_info["wires"][0], op_info["params"][0] + kraus_mat = jax.lax.switch(op_info["type_index"], single_qubit_ops, param) + pass + + +def apply_two_qubit_unitary(state, op_info): + wires, params = op_info["wires"], op_info["params"] + kraus_mat = jax.lax.switch(op_info["type_index"], two_qubit_ops, params) + pass + + +def apply_single_qubit_channel(state, op_info): + wire, param = op_info["wires"][0], op_info["params"][0] + kraus_mat = jax.lax.switch(op_info["type_index"], single_qubit_channels, param) pass -def apply_single_qudit_channel(): +single_qutrit_ops = [qml.TRX.compute_matrix, qml.TRY.compute_matrix, qml.TRZ.compute_matrix] +single_qutrit_channels = [ + lambda params: qml.QutritDepolarizingChannel.compute_kraus_matrices(params[0]), + lambda params: qml.QutritAmplitudeDamping.compute_kraus_matrices(*params), + lambda params: qml.TritFlip.compute_kraus_matrices(*params), +] + + +def apply_single_qutrit_unitary(state, op_info): + wire, param = op_info["wires"][0], op_info["params"][0] + kraus_mat = jax.lax.switch(op_info["type_index"], single_qutrit_ops, param) + pass + + +def apply_two_qutrit_unitary(state, op_info): + wires = op_info["wires"] + kraus_mat = jnp.array([[1,0,0,0,0,0,0,0,0], + [0,1,0,0,0,0,0,0,0], + [0,0,1,0,0,0,0,0,0], + [0,0,0,0,0,1,0,0,0], + [0,0,0,1,0,0,0,0,0], + [0,0,0,0,1,0,0,0,0], + [0,0,0,0,0,0,0,1,0], + [0,0,0,0,0,0,0,0,1], + [0,0,0,0,0,0,1,0,0]]) pass -def apply_operation(state, qudit_dim, op_info): - # TODO may have to rewrite to return different functions for qubits and qutrits - op_i = op_info["type_index"] - op_class = op_i // first_index + op_i // second_index + op_i // third_index - state = jax.lax.switch(op_class, [], qudit_dim, op_info) - return state, None +def apply_single_qutrit_channel(state, op_info): + wire, params = op_info["wires"][0], op_info["params"] # TODO qutrit channels take 3 params + kraus_mat = jax.lax.switch(op_info["type_index"], single_qutrit_channels, *params) + pass + + +def get_operation_applier(qudit_dim): + qubit_type_branches = [apply_single_qubit_unitary, apply_two_qubit_unitary, + apply_single_qubit_channel] + qutrit_type_branches = [apply_single_qutrit_unitary, apply_two_qutrit_unitary, + apply_single_qutrit_channel] + if qudit_dim == 2: + def operation_applier(state, op_info): + index_cutoffs = [0, 0, 0] + op_i = op_info["type_index"] + op_class = op_i // index_cutoffs[0] + op_i // index_cutoffs[1] + op_i // index_cutoffs[2] + return jax.lax.switch(op_class, qubit_type_branches, state, op_info), None + elif qudit_dim == 3: + def operation_applier(state, op_info): + index_cutoffs = [0, 0, 0] + op_i = op_info["type_index"] + op_class = op_i // index_cutoffs[0] + op_i // index_cutoffs[1] + op_i // index_cutoffs[2] + return jax.lax.switch(op_class, qutrit_type_branches, state, op_info), None + else: + raise ValueError("Only qubit and qutrit simulators are allowed") + + return operation_applier + diff --git a/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py b/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py index c1cec88b5fb..b889ca72d89 100644 --- a/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py +++ b/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py @@ -3,15 +3,74 @@ import jax.numpy as jnp import pennylane as qml from pennylane.tape import QuantumScript -from .apply_operations import apply_operation +from .apply_operations import get_operation_applier from ..qutrit_mixed.simulate import measure_final_state from ..qutrit_mixed.initialize_state import create_initial_state op_types = [] +def get_qubit_final_state_from_initial(operations, initial_state): + """ + TODO + + Args: + circuit (.QuantumScript): The single circuit to simulate + + Returns: + Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and + whether the state has a batch dimension. -def get_final_state_from_initial(operations, initial_state, qudit_dim): + """ + ops_type_indices, ops_wires, ops_param = [[], []], [[], []], [] + for op in operations: + + wires = op.wires() + + + if isinstance(op, qml.operation.Channel): + ops_type_indices[0].append(2) + ops_type_indices[1].append([].index(type(op))) + elif len(wires) == 1: + ops_type_indices[0].append(0) + ops_type_indices[1].append([].index(type(op))) + elif len(wires) == 1: + ops_type_indices[0].append(0) + ops_type_indices[1].append([].index(type(op))) + + + + if len(wires) == 1: + wires = [wires[0], -1] + params = op.parameters + ([0] * (3 - op.num_params)) + if len(wires) == 2: + ops_wires[0].append(wires[0]) + ops_wires[1].append(wires[1]) + + ops_param[0].append(params[0]) + + op_index = op_types.index(type(op)) + ops_type_index.append(op_index) + + + + if qudit_dim == 2 and op_index <= 2: + ops_subspace.append([(0,1), (0,2), (1,2)].index(op.subspace)) + else: + ops_subspace.append(0) + + ops_info = { + "type_index": jnp.array(ops_type_index), + "wires": [jnp.array(ops_wires[0]), jnp.array(ops_wires[1])], + "params": [jnp.array(ops_param)] + } + + + + return jax.lax.scan(get_operation_applier(qudit_dim), initial_state, ops_info)[0] + + +def get_qutrit_final_state_from_initial(operations, initial_state, qudit_dim): """ TODO @@ -23,7 +82,7 @@ def get_final_state_from_initial(operations, initial_state, qudit_dim): whether the state has a batch dimension. """ - ops_type_index, ops_wires, ops_params = [], [[], []], [[], [], []] + ops_type_index, ops_subspace, ops_wires, ops_params = [], [], [[], []], [[], [], []] for op in operations: # op_index = None @@ -42,24 +101,32 @@ def get_final_state_from_initial(operations, initial_state, qudit_dim): # params = op.parameters + ([0] * (3-op.num_params)) # op_array.append([[op_index] + wires, params]) - ops_type_index.append(op_types.index(type(op))) + op_index = op_types.index(type(op)) + ops_type_index.append(op_index) + wires = op.wires() if len(wires) == 1: - wires = [-1, wires[0]] + wires = [wires[0], -1] + params = op.parameters + ([0] * (3 - op.num_params)) + if len(wires) == 2: ops_wires[0].append(wires[0]) ops_wires[1].append(wires[1]) - params = op.parameters + ([0] * (3 - op.num_params)) ops_params[0].append(params[0]) ops_params[1].append(params[1]) ops_params[2].append(params[2]) + if qudit_dim == 2 and op_index <= 2: + ops_subspace.append([(0,1), (0,2), (1,2)].index(op.subspace)) + else: + ops_subspace.append(0) + ops_info = { "type_index": jnp.array(ops_type_index), "wires": [jnp.array(ops_wires[0]), jnp.array(ops_wires[1])], "params": [jnp.array(ops_params[0]), jnp.array(ops_params[1]), jnp.array(ops_params[2])] } - return jax.lax.scan(apply_operation, initial_state, qudit_dim, ops_info)[0] + return jax.lax.scan(get_operation_applier(qudit_dim), initial_state, ops_info)[0] def get_final_state(circuit): From f1679f4ee5d6ef506ad1f26b3caeece53e6e9e04 Mon Sep 17 00:00:00 2001 From: Gabriel Bottrill Date: Tue, 2 Jul 2024 20:08:11 -0700 Subject: [PATCH 05/26] Linked new jittable simulation to old device --- pennylane/devices/default_mixed.py | 15 ++- .../qtcorgi_helper/apply_operations.py | 75 ++++++------ .../qtcorgi_helper/qtcorgi_simulator.py | 109 +++++------------- pennylane/devices/qutrit_mixed/simulate.py | 58 +--------- 4 files changed, 78 insertions(+), 179 deletions(-) diff --git a/pennylane/devices/default_mixed.py b/pennylane/devices/default_mixed.py index 0e2d0fbfcf0..97c11febdf7 100644 --- a/pennylane/devices/default_mixed.py +++ b/pennylane/devices/default_mixed.py @@ -48,6 +48,7 @@ from pennylane.wires import Wires from .._version import __version__ +from .qtcorgi_helper.qtcorgi_simulator import get_qubit_final_state_from_initial logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) @@ -781,9 +782,17 @@ def apply(self, operations, rotations=None, **kwargs): f"on a {self.short_name} device." ) - for operation in operations: - self._apply_operation(operation) - + prep = False + if len(operations) > 0: + if ( + isinstance(operations[0], StatePrep) + or isinstance(operations[0], BasisState) + or isinstance(operations[0], QubitDensityMatrix) + ): + prep = True + self._apply_operation(operation) + + self._state = get_qubit_final_state_from_initial(operations[prep:], self._state) # store the pre-rotated state self._pre_rotated_state = self._state diff --git a/pennylane/devices/qtcorgi_helper/apply_operations.py b/pennylane/devices/qtcorgi_helper/apply_operations.py index aedcfd35847..3e9635fc04d 100644 --- a/pennylane/devices/qtcorgi_helper/apply_operations.py +++ b/pennylane/devices/qtcorgi_helper/apply_operations.py @@ -2,14 +2,17 @@ import jax import jax.numpy as jnp from jax.lax import scan + jax.config.update("jax_enable_x64", True) jax.config.update("jax_platforms", "cpu") import pennylane as qml from string import ascii_letters as alphabet import numpy as np + alphabet_array = np.array(list(alphabet)) + def get_einsum_mapping(wires, state): r"""Finds the indices for einsum to apply kraus operators to a mixed state @@ -52,9 +55,8 @@ def get_einsum_mapping(wires, state): state_indices=state_indices, ) # index mapping for einsum, e.g., '...iga,...abcdef,...idh->...gbchef' - return ( - f"...{op_1_indices},...{state_indices},...{op_2_indices}->...{new_state_indices}" - ) + return f"...{op_1_indices},...{state_indices},...{op_2_indices}->...{new_state_indices}" + def get_new_state_einsum_indices(old_indices, new_indices, state_indices): """Retrieves the einsum indices string for the new state @@ -73,7 +75,10 @@ def get_new_state_einsum_indices(old_indices, new_indices, state_indices): state_indices, ) + QUDIT_DIM = 3 + + def apply_operation_einsum(kraus, wires, state): r"""Apply a quantum channel specified by a list of Kraus operators to subsystems of the quantum state. For a unitary gate, there is a single Kraus operator. @@ -108,15 +113,17 @@ def get_two_qubit_unitary_matrix(): def get_CNOT_matrix(params): - return jnp.array([[1,0,0,0], - [0,1,0,0], - [0,0,0,1], - [0,0,1,0]]) + return jnp.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]) single_qubit_ops = [qml.RX.compute_matrix, qml.RY.compute_matrix, qml.RZ.compute_matrix] two_qubit_ops = [get_CNOT_matrix, get_two_qubit_unitary_matrix] -single_qubit_channels = [qml.DepolarizingChannel.compute_kraus_matrices, qml.AmplitudeDamping.compute_kraus_matrices, qml.BitFlip.compute_kraus_matrices] +single_qubit_channels = [ + qml.DepolarizingChannel.compute_kraus_matrices, + qml.AmplitudeDamping.compute_kraus_matrices, + qml.BitFlip.compute_kraus_matrices, +] + def apply_single_qubit_unitary(state, op_info): wire, param = op_info["wires"][0], op_info["params"][0] @@ -136,6 +143,9 @@ def apply_single_qubit_channel(state, op_info): pass +qubit_branches = [apply_single_qubit_unitary, apply_two_qubit_unitary, apply_single_qubit_channel] + + single_qutrit_ops = [qml.TRX.compute_matrix, qml.TRY.compute_matrix, qml.TRZ.compute_matrix] single_qutrit_channels = [ lambda params: qml.QutritDepolarizingChannel.compute_kraus_matrices(params[0]), @@ -152,15 +162,19 @@ def apply_single_qutrit_unitary(state, op_info): def apply_two_qutrit_unitary(state, op_info): wires = op_info["wires"] - kraus_mat = jnp.array([[1,0,0,0,0,0,0,0,0], - [0,1,0,0,0,0,0,0,0], - [0,0,1,0,0,0,0,0,0], - [0,0,0,0,0,1,0,0,0], - [0,0,0,1,0,0,0,0,0], - [0,0,0,0,1,0,0,0,0], - [0,0,0,0,0,0,0,1,0], - [0,0,0,0,0,0,0,0,1], - [0,0,0,0,0,0,1,0,0]]) + kraus_mat = jnp.array( + [ + [1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 1, 0, 0], + ] + ) pass @@ -170,25 +184,8 @@ def apply_single_qutrit_channel(state, op_info): pass -def get_operation_applier(qudit_dim): - qubit_type_branches = [apply_single_qubit_unitary, apply_two_qubit_unitary, - apply_single_qubit_channel] - qutrit_type_branches = [apply_single_qutrit_unitary, apply_two_qutrit_unitary, - apply_single_qutrit_channel] - if qudit_dim == 2: - def operation_applier(state, op_info): - index_cutoffs = [0, 0, 0] - op_i = op_info["type_index"] - op_class = op_i // index_cutoffs[0] + op_i // index_cutoffs[1] + op_i // index_cutoffs[2] - return jax.lax.switch(op_class, qubit_type_branches, state, op_info), None - elif qudit_dim == 3: - def operation_applier(state, op_info): - index_cutoffs = [0, 0, 0] - op_i = op_info["type_index"] - op_class = op_i // index_cutoffs[0] + op_i // index_cutoffs[1] + op_i // index_cutoffs[2] - return jax.lax.switch(op_class, qutrit_type_branches, state, op_info), None - else: - raise ValueError("Only qubit and qutrit simulators are allowed") - - return operation_applier - +qutrit_branches = [ + apply_single_qutrit_unitary, + apply_two_qutrit_unitary, + apply_single_qutrit_channel, +] diff --git a/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py b/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py index b889ca72d89..d17c2436712 100644 --- a/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py +++ b/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py @@ -1,10 +1,8 @@ -import functools import jax import jax.numpy as jnp import pennylane as qml from pennylane.tape import QuantumScript -from .apply_operations import get_operation_applier -from ..qutrit_mixed.simulate import measure_final_state +from .apply_operations import qubit_branches, qutrit_branches from ..qutrit_mixed.initialize_state import create_initial_state op_types = [] @@ -15,7 +13,7 @@ def get_qubit_final_state_from_initial(operations, initial_state): TODO Args: - circuit (.QuantumScript): The single circuit to simulate + TODO Returns: Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and @@ -27,7 +25,6 @@ def get_qubit_final_state_from_initial(operations, initial_state): wires = op.wires() - if isinstance(op, qml.operation.Channel): ops_type_indices[0].append(2) ops_type_indices[1].append([].index(type(op))) @@ -38,13 +35,10 @@ def get_qubit_final_state_from_initial(operations, initial_state): ops_type_indices[0].append(0) ops_type_indices[1].append([].index(type(op))) - - if len(wires) == 1: wires = [wires[0], -1] params = op.parameters + ([0] * (3 - op.num_params)) - if len(wires) == 2: - ops_wires[0].append(wires[0]) + ops_wires[0].append(wires[0]) ops_wires[1].append(wires[1]) ops_param[0].append(params[0]) @@ -52,30 +46,28 @@ def get_qubit_final_state_from_initial(operations, initial_state): op_index = op_types.index(type(op)) ops_type_index.append(op_index) - - - if qudit_dim == 2 and op_index <= 2: - ops_subspace.append([(0,1), (0,2), (1,2)].index(op.subspace)) - else: - ops_subspace.append(0) - ops_info = { "type_index": jnp.array(ops_type_index), "wires": [jnp.array(ops_wires[0]), jnp.array(ops_wires[1])], - "params": [jnp.array(ops_param)] + "params": [jnp.array(ops_param)], } + return jax.lax.scan( + lambda state, op_info: ( + jax.lax.switch(op_info["branch"], qubit_branches, state, op_info), + None, + ), + initial_state, + ops_info, + )[0] - return jax.lax.scan(get_operation_applier(qudit_dim), initial_state, ops_info)[0] - - -def get_qutrit_final_state_from_initial(operations, initial_state, qudit_dim): +def get_qutrit_final_state_from_initial(operations, initial_state): """ TODO Args: - circuit (.QuantumScript): The single circuit to simulate + TODO Returns: Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and @@ -84,23 +76,6 @@ def get_qutrit_final_state_from_initial(operations, initial_state, qudit_dim): """ ops_type_index, ops_subspace, ops_wires, ops_params = [], [], [[], []], [[], [], []] for op in operations: - - # op_index = None - # for i, op_type in enumerate(op_types): - # if isinstance(op, op_type): - # op_index = i - # if op_index is None: - # raise ValueError("This simulator only supports") - - # op_index = op_types.index(type(op)) - # wires = op.wires() - # if len(wires) == 1: - # wires = [-1, wires[0]] - # if len(wires) == 2: - # wires = list(wires) - # params = op.parameters + ([0] * (3-op.num_params)) - # op_array.append([[op_index] + wires, params]) - op_index = op_types.index(type(op)) ops_type_index.append(op_index) @@ -108,7 +83,7 @@ def get_qutrit_final_state_from_initial(operations, initial_state, qudit_dim): if len(wires) == 1: wires = [wires[0], -1] params = op.parameters + ([0] * (3 - op.num_params)) - if len(wires) == 2: + ops_wires[0].append(wires[0]) ops_wires[1].append(wires[1]) @@ -116,27 +91,31 @@ def get_qutrit_final_state_from_initial(operations, initial_state, qudit_dim): ops_params[1].append(params[1]) ops_params[2].append(params[2]) - if qudit_dim == 2 and op_index <= 2: - ops_subspace.append([(0,1), (0,2), (1,2)].index(op.subspace)) + if op_index <= 2: + ops_subspace.append([(0, 1), (0, 2), (1, 2)].index(op.subspace)) else: ops_subspace.append(0) ops_info = { "type_index": jnp.array(ops_type_index), "wires": [jnp.array(ops_wires[0]), jnp.array(ops_wires[1])], - "params": [jnp.array(ops_params[0]), jnp.array(ops_params[1]), jnp.array(ops_params[2])] + "params": [jnp.array(ops_params[0]), jnp.array(ops_params[1]), jnp.array(ops_params[2])], } - return jax.lax.scan(get_operation_applier(qudit_dim), initial_state, ops_info)[0] + op_branch = jnp.nan + return jax.lax.scan( + lambda state, op_info: (jax.lax.switch(op_info["branch"], qutrit_branches, state, x), None), + initial_state, + ops_info, + )[0] -def get_final_state(circuit): + +def get_final_state_qutrit(circuit): """ TODO Args: circuit (.QuantumScript): The single circuit to simulate - qudit_dim (): TODO - interface (str): The machine learning interface to create the initial state with Returns: Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and @@ -151,38 +130,4 @@ def get_final_state(circuit): prep = circuit[0] state = create_initial_state(sorted(circuit.op_wires), prep, like="jax") - get_final_state_from_initial(circuit.operations[bool(prep):], state, 3) - - - - - -def simulate(circuit: QuantumScript, rng=None, prng_key=None): - """TODO - - Args: - circuit (QuantumTape): The single circuit to simulate - rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]): A - seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``. - If no value is provided, a default RNG will be used. - prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is - the key to the JAX pseudo random number generator. If None, a random key will be - generated. Only for simulation using JAX. - debugger (_Debugger): The debugger to use - interface (str): The machine learning interface to create the initial state with - - Returns: - tuple(TensorLike): The results of the simulation - - Note that this function can return measurements for non-commuting observables simultaneously. - - This function assumes that all operations provide matrices. - - >>> qs = qml.tape.QuantumScript([qml.TRX(1.2, wires=0)], [qml.expval(qml.GellMann(0, 3)), qml.probs(wires=(0,1))]) - >>> simulate(qs) - (0.36235775447667357, - tensor([0.68117888, 0. , 0. , 0.31882112, 0. , 0. ], requires_grad=True)) - - """ - state, is_state_batched = get_final_state(circuit) - return measure_final_state(circuit, state, is_state_batched, rng=rng, prng_key=prng_key) + return get_qutrit_final_state_from_initial(circuit.operations[bool(prep) :], state), False diff --git a/pennylane/devices/qutrit_mixed/simulate.py b/pennylane/devices/qutrit_mixed/simulate.py index 54d4f35c28e..17a665c39b6 100644 --- a/pennylane/devices/qutrit_mixed/simulate.py +++ b/pennylane/devices/qutrit_mixed/simulate.py @@ -23,6 +23,7 @@ from .measure import measure from .sampling import measure_with_samples from .utils import QUDIT_DIM +from ..qtcorgi_helper.qtcorgi_simulator import get_qutrit_final_state INTERFACE_TO_LIKE = { # map interfaces known by autoray to themselves @@ -45,59 +46,6 @@ } -def get_final_state(circuit, debugger=None, interface=None, **kwargs): - """ - Get the final state that results from executing the given quantum script. - - This is an internal function that will be called by ``default.qutrit.mixed``. - - Args: - circuit (.QuantumScript): The single circuit to simulate - debugger (._Debugger): The debugger to use - interface (str): The machine learning interface to create the initial state with - - Returns: - Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and - whether the state has a batch dimension. - - """ - circuit = circuit.map_to_standard_wires() - - prep = None - if len(circuit) > 0 and isinstance(circuit[0], qml.operation.StatePrepBase): - prep = circuit[0] - - state = create_initial_state(sorted(circuit.op_wires), prep, like=INTERFACE_TO_LIKE[interface]) - - # initial state is batched only if the state preparation (if it exists) is batched - is_state_batched = bool(prep and prep.batch_size is not None) - for op in circuit.operations[bool(prep) :]: - state = apply_operation( - op, - state, - is_state_batched=is_state_batched, - debugger=debugger, - tape_shots=circuit.shots, - **kwargs, - ) - - # new state is batched if i) the old state is batched, or ii) the new op adds a batch dim - is_state_batched = is_state_batched or op.batch_size is not None - - num_operated_wires = len(circuit.op_wires) - for i in range(len(circuit.wires) - num_operated_wires): - # If any measured wires are not operated on, we pad the density matrix with zeros. - # We know they belong at the end because the circuit is in standard wire-order - # Since it is a dm, we must pad it with 0s on the last row and last column - current_axis = num_operated_wires + i + is_state_batched - state = qml.math.stack( - ([state] + [qml.math.zeros_like(state)] * (QUDIT_DIM - 1)), axis=current_axis - ) - state = qml.math.stack(([state] + [qml.math.zeros_like(state)] * (QUDIT_DIM - 1)), axis=-1) - - return state, is_state_batched - - def measure_final_state(circuit, state, is_state_batched, rng=None, prng_key=None) -> Result: """ Perform the measurements required by the circuit on the provided state. @@ -182,7 +130,7 @@ def simulate( tensor([0.68117888, 0. , 0. , 0.31882112, 0. , 0. ], requires_grad=True)) """ - state, is_state_batched = get_final_state( + state = get_qutrit_final_state( circuit, debugger=debugger, interface=interface, rng=rng, prng_key=prng_key ) - return measure_final_state(circuit, state, is_state_batched, rng=rng, prng_key=prng_key) + return measure_final_state(circuit, state, False, rng=rng, prng_key=prng_key) From a11206c5f3fc32f901abd828bd39a48358d148ec Mon Sep 17 00:00:00 2001 From: Gabriel Bottrill Date: Thu, 4 Jul 2024 18:04:59 -0700 Subject: [PATCH 06/26] Fixed indexing and linked apply operation. --- .../qtcorgi_helper/apply_operations.py | 73 +++++++++------ .../qtcorgi_helper/qtcorgi_simulator.py | 92 +++++++++---------- pennylane/devices/qutrit_mixed/__init__.py | 2 +- pennylane/devices/qutrit_mixed/simulate.py | 27 +++++- 4 files changed, 110 insertions(+), 84 deletions(-) diff --git a/pennylane/devices/qtcorgi_helper/apply_operations.py b/pennylane/devices/qtcorgi_helper/apply_operations.py index 3e9635fc04d..0f6fc3ce216 100644 --- a/pennylane/devices/qtcorgi_helper/apply_operations.py +++ b/pennylane/devices/qtcorgi_helper/apply_operations.py @@ -9,6 +9,7 @@ from string import ascii_letters as alphabet import numpy as np +from functools import partial, reduce alphabet_array = np.array(list(alphabet)) @@ -69,7 +70,7 @@ def get_new_state_einsum_indices(old_indices, new_indices, state_indices): Returns: str: The einsum indices of the new state """ - return functools.reduce( + return reduce( lambda old_string, idx_pair: old_string.replace(idx_pair[0], idx_pair[1]), zip(old_indices, new_indices), state_indices, @@ -85,6 +86,7 @@ def apply_operation_einsum(kraus, wires, state): Args: kraus (??): TODO + wires state (array[complex]): Input quantum state Returns: @@ -126,27 +128,36 @@ def get_CNOT_matrix(params): def apply_single_qubit_unitary(state, op_info): - wire, param = op_info["wires"][0], op_info["params"][0] - kraus_mat = jax.lax.switch(op_info["type_index"], single_qubit_ops, param) - pass + wires, param = op_info["wires"][:0], op_info["params"][0] + kraus_mat = jax.lax.switch(op_info["type_indices"][1], single_qubit_ops, param) + return apply_operation_einsum(kraus_mat, wires, state) def apply_two_qubit_unitary(state, op_info): wires, params = op_info["wires"], op_info["params"] - kraus_mat = jax.lax.switch(op_info["type_index"], two_qubit_ops, params) - pass + kraus_mats = [jax.lax.switch(op_info["type_indices"][1], two_qubit_ops, params)] + return apply_operation_einsum(kraus_mats, wires, state) def apply_single_qubit_channel(state, op_info): - wire, param = op_info["wires"][0], op_info["params"][0] - kraus_mat = jax.lax.switch(op_info["type_index"], single_qubit_channels, param) - pass + wires, param = op_info["wires"][:0], op_info["params"][0] + kraus_mats = [jax.lax.switch(op_info["type_indices"][1], single_qubit_channels, param)] + return apply_operation_einsum(kraus_mats, wires, state) qubit_branches = [apply_single_qubit_unitary, apply_two_qubit_unitary, apply_single_qubit_channel] -single_qutrit_ops = [qml.TRX.compute_matrix, qml.TRY.compute_matrix, qml.TRZ.compute_matrix] +single_qutrit_ops = [ + qml.TRX.compute_matrix, + qml.TRY.compute_matrix, + qml.TRZ.compute_matrix, + lambda params: ( + qml.THadamard.compute_matrix() + if params[1] != 0 + else qml.THadamard.compute_matrix(subspace=params[1:]) + ), +] single_qutrit_channels = [ lambda params: qml.QutritDepolarizingChannel.compute_kraus_matrices(params[0]), lambda params: qml.QutritAmplitudeDamping.compute_kraus_matrices(*params), @@ -155,33 +166,35 @@ def apply_single_qubit_channel(state, op_info): def apply_single_qutrit_unitary(state, op_info): - wire, param = op_info["wires"][0], op_info["params"][0] - kraus_mat = jax.lax.switch(op_info["type_index"], single_qutrit_ops, param) - pass + wires, param = op_info["wires"][:0], op_info["params"][0] + kraus_mats = [jax.lax.switch(op_info["type_indices"][1], single_qutrit_ops, param)] + return apply_operation_einsum(kraus_mats, wires, state) def apply_two_qutrit_unitary(state, op_info): wires = op_info["wires"] - kraus_mat = jnp.array( - [ - [1, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 1, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 0, 1, 0, 0], - ] - ) - pass + kraus_mat = [ + jnp.array( + [ + [1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 1, 0, 0], + ] + ) + ] + return apply_operation_einsum(kraus_mat, wires, state) def apply_single_qutrit_channel(state, op_info): - wire, params = op_info["wires"][0], op_info["params"] # TODO qutrit channels take 3 params - kraus_mat = jax.lax.switch(op_info["type_index"], single_qutrit_channels, *params) - pass + wires, params = op_info["wires"][:0], op_info["params"] # TODO qutrit channels take 3 params + kraus_mats = [jax.lax.switch(op_info["type_indices"][1], single_qutrit_channels, *params)] + return apply_operation_einsum(kraus_mats, wires, state) qutrit_branches = [ diff --git a/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py b/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py index d17c2436712..a4073d30e07 100644 --- a/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py +++ b/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py @@ -1,9 +1,8 @@ import jax import jax.numpy as jnp import pennylane as qml -from pennylane.tape import QuantumScript +from pennylane.operation import Channel from .apply_operations import qubit_branches, qutrit_branches -from ..qutrit_mixed.initialize_state import create_initial_state op_types = [] @@ -25,15 +24,17 @@ def get_qubit_final_state_from_initial(operations, initial_state): wires = op.wires() - if isinstance(op, qml.operation.Channel): + if isinstance(op, Channel): ops_type_indices[0].append(2) ops_type_indices[1].append([].index(type(op))) elif len(wires) == 1: ops_type_indices[0].append(0) - ops_type_indices[1].append([].index(type(op))) - elif len(wires) == 1: - ops_type_indices[0].append(0) - ops_type_indices[1].append([].index(type(op))) + ops_type_indices[1].append([qml.RX, qml.RY, qml.RZ, qml.Hadamard].index(type(op))) + elif len(wires) == 2: + ops_type_indices[0].append(1) + ops_type_indices[1].append(0) # Assume always CNOT + else: + raise ValueError("TODO") if len(wires) == 1: wires = [wires[0], -1] @@ -43,18 +44,15 @@ def get_qubit_final_state_from_initial(operations, initial_state): ops_param[0].append(params[0]) - op_index = op_types.index(type(op)) - ops_type_index.append(op_index) - ops_info = { - "type_index": jnp.array(ops_type_index), + "type_index": jnp.array(ops_type_indices), "wires": [jnp.array(ops_wires[0]), jnp.array(ops_wires[1])], "params": [jnp.array(ops_param)], } return jax.lax.scan( lambda state, op_info: ( - jax.lax.switch(op_info["branch"], qubit_branches, state, op_info), + jax.lax.switch(op_info["type_indices"][0], qubit_branches, state, op_info), None, ), initial_state, @@ -74,60 +72,52 @@ def get_qutrit_final_state_from_initial(operations, initial_state): whether the state has a batch dimension. """ - ops_type_index, ops_subspace, ops_wires, ops_params = [], [], [[], []], [[], [], []] + ops_type_indices, ops_subspace, ops_wires, ops_params = [[], []], [], [[], []], [[], [], []] for op in operations: - op_index = op_types.index(type(op)) - ops_type_index.append(op_index) wires = op.wires() - if len(wires) == 1: - wires = [wires[0], -1] - params = op.parameters + ([0] * (3 - op.num_params)) - - ops_wires[0].append(wires[0]) - ops_wires[1].append(wires[1]) + if isinstance(op, Channel): + ops_type_indices[0].append(2) + ops_type_indices[1].append( + [qml.QutritDepolarizingChannel, qml.QutritAmplitudeDamping, qml.TritFlip].index( + type(op) + ) + ) + params = op.parameters + ([0] * (3 - op.num_params)) + elif len(wires) == 1: + ops_type_indices[0].append(0) + ops_type_indices[1].append([qml.TRX, qml.TRY, qml.TRZ, qml.THadamard].index(type(op))) + if ops_type_indices[1][-1] == 3: + params = [0] + list(op.subspace) if op.subspace is not None else [0, 0] + else: + params = list(op.params) + list(op.subspace) + elif len(wires) == 2: + ops_type_indices[0].append(1) + ops_type_indices[1].append(0) # Assume always TAdd + params = [0, 0, 0] + else: + raise ValueError("TODO") ops_params[0].append(params[0]) ops_params[1].append(params[1]) ops_params[2].append(params[2]) - if op_index <= 2: - ops_subspace.append([(0, 1), (0, 2), (1, 2)].index(op.subspace)) - else: - ops_subspace.append(0) + if len(wires) == 1: + wires = [wires[0], -1] + ops_wires[0].append(wires[0]) + ops_wires[1].append(wires[1]) ops_info = { - "type_index": jnp.array(ops_type_index), + "type_indices": jnp.array(ops_type_indices), "wires": [jnp.array(ops_wires[0]), jnp.array(ops_wires[1])], "params": [jnp.array(ops_params[0]), jnp.array(ops_params[1]), jnp.array(ops_params[2])], } - op_branch = jnp.nan return jax.lax.scan( - lambda state, op_info: (jax.lax.switch(op_info["branch"], qutrit_branches, state, x), None), + lambda state, op_info: ( + jax.lax.switch(op_info["type_indices"][0], qutrit_branches, state, op_info), + None, + ), initial_state, ops_info, )[0] - - -def get_final_state_qutrit(circuit): - """ - TODO - - Args: - circuit (.QuantumScript): The single circuit to simulate - - Returns: - Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and - whether the state has a batch dimension. - - """ - - circuit = circuit.map_to_standard_wires() - - prep = None - if len(circuit) > 0 and isinstance(circuit[0], qml.operation.StatePrepBase): - prep = circuit[0] - - state = create_initial_state(sorted(circuit.op_wires), prep, like="jax") - return get_qutrit_final_state_from_initial(circuit.operations[bool(prep) :], state), False diff --git a/pennylane/devices/qutrit_mixed/__init__.py b/pennylane/devices/qutrit_mixed/__init__.py index 192b5a1b65a..e171789454e 100644 --- a/pennylane/devices/qutrit_mixed/__init__.py +++ b/pennylane/devices/qutrit_mixed/__init__.py @@ -34,4 +34,4 @@ from .initialize_state import create_initial_state from .measure import measure from .sampling import sample_state, measure_with_samples -from .simulate import simulate, get_final_state, measure_final_state +from .simulate import simulate, measure_final_state diff --git a/pennylane/devices/qutrit_mixed/simulate.py b/pennylane/devices/qutrit_mixed/simulate.py index 17a665c39b6..d820477dc8c 100644 --- a/pennylane/devices/qutrit_mixed/simulate.py +++ b/pennylane/devices/qutrit_mixed/simulate.py @@ -23,7 +23,7 @@ from .measure import measure from .sampling import measure_with_samples from .utils import QUDIT_DIM -from ..qtcorgi_helper.qtcorgi_simulator import get_qutrit_final_state +from ..qtcorgi_helper.qtcorgi_simulator import get_qutrit_final_state_from_initial INTERFACE_TO_LIKE = { # map interfaces known by autoray to themselves @@ -99,6 +99,29 @@ def measure_final_state(circuit, state, is_state_batched, rng=None, prng_key=Non return results +def get_final_state_qutrit(circuit, **kwargs): + """ + TODO + + Args: + circuit (.QuantumScript): The single circuit to simulate + + Returns: + Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and + whether the state has a batch dimension. + + """ + + circuit = circuit.map_to_standard_wires() + + prep = None + if len(circuit) > 0 and isinstance(circuit[0], qml.operation.StatePrepBase): + prep = circuit[0] + + state = create_initial_state(sorted(circuit.op_wires), prep, like="jax") + return get_qutrit_final_state_from_initial(circuit.operations[bool(prep) :], state), False + + def simulate( circuit: qml.tape.QuantumScript, rng=None, prng_key=None, debugger=None, interface=None ) -> Result: @@ -130,7 +153,7 @@ def simulate( tensor([0.68117888, 0. , 0. , 0.31882112, 0. , 0. ], requires_grad=True)) """ - state = get_qutrit_final_state( + state = get_final_state_qutrit( circuit, debugger=debugger, interface=interface, rng=rng, prng_key=prng_key ) return measure_final_state(circuit, state, False, rng=rng, prng_key=prng_key) From f1c0749424856cebc8983fe4689d56153a92ba4c Mon Sep 17 00:00:00 2001 From: gabrielLydian Date: Fri, 5 Jul 2024 14:59:31 -0700 Subject: [PATCH 07/26] Removed extra files --- doing_a_test.py | 117 --- my_testing.py | 97 --- pennylane/devices/jittable_mixed.py | 797 --------------------- pennylane/devices/jittable_qutrit_mixed.py | 311 -------- 4 files changed, 1322 deletions(-) delete mode 100644 doing_a_test.py delete mode 100644 my_testing.py delete mode 100644 pennylane/devices/jittable_mixed.py delete mode 100644 pennylane/devices/jittable_qutrit_mixed.py diff --git a/doing_a_test.py b/doing_a_test.py deleted file mode 100644 index ec60c54ce61..00000000000 --- a/doing_a_test.py +++ /dev/null @@ -1,117 +0,0 @@ -# import jax -import pennylane as qml -import numpy as np -# import scipy -# -# n_wires = 2 -# num_qscripts = 2 -# qscripts = [] -# for i in range(num_qscripts): -# unitary = scipy.stats.unitary_group(dim=3**n_wires, seed=(42 + i)).rvs() -# op = qml.QutritUnitary(unitary, wires=range(n_wires)) -# qs = qml.tape.QuantumScript([op], [qml.expval(qml.GellMann(0, 3))]) -# qscripts.append(qs) -# -# dev = qml.devices.DefaultQutritMixed() -# program, execution_config = dev.preprocess() -# new_batch, post_processing_fn = program(qscripts) -# results = dev.execute(new_batch, execution_config=execution_config) -# print(post_processing_fn(results)) -# -# @jax.jit -# def f(x): -# qs = qml.tape.QuantumScript([qml.TRX(x, 0)], [qml.expval(qml.GellMann(0, 3))]) -# program, execution_config = dev.preprocess() -# new_batch, post_processing_fn = program([qs]) -# results = dev.execute(new_batch, execution_config=execution_config) -# return post_processing_fn(results)[0] -# -# jax.config.update("jax_enable_x64", True) -# print(f(jax.numpy.array(1.2))) -# print(jax.grad(f)(jax.numpy.array(1.2))) - -# 1/2, 1/3, 1/6 - -# 1/2, 1/3, 1/6 -# -# 1/2+1/3-2/6 -# -# 2/6-3/6 -# -# -# gellMann_8_coeffs = np.array([1/np.sqrt(3), -1, 1/np.sqrt(3), 1, 1/np.sqrt(3), -1]) / 2 -# gellMann_8_obs = [qml.GellMann(0, i) for i in [1, 2, 4, 5, 6, 7]] -# H1 = qml.Hamiltonian(gellMann_8_coeffs, gellMann_8_obs).matrix() -# -# G8 = qml.GellMann.compute_matrix(8) -# Had = qml.THadamard.compute_matrix() -# aHad = np.conj(Had).T -# -# H2 = np.round(aHad@G8@Had, 5) -# -# -# # print(np.allclose(H1, H2)) -# # print(H1) -# # print(H2) -# # print(np.round(H2*np.sqrt(3), 5)) -# obs = aHad@G8@Had -# -# diag_gates = qml.THermitian(obs, 0).diagonalizing_gates()#[0].matrix() -# print(len(diag_gates)) -# diag_gates = diag_gates[0].matrix() -# -# print(np.round((np.conj(diag_gates).T)@obs@diag_gates, 5)) - - -def setup_state(nr_wires): - """Sets up a basic state used for testing.""" - setup_unitary = np.array( - [ - [1 / np.sqrt(2), 1 / np.sqrt(3), 1 / np.sqrt(6)], - [np.sqrt(2 / 29), np.sqrt(3 / 29), -2 * np.sqrt(6 / 29)], - [-5 / np.sqrt(58), 7 / np.sqrt(87), 1 / np.sqrt(174)], - ] - ).T - qml.QutritUnitary(setup_unitary, wires=0) - qml.QutritUnitary(setup_unitary, wires=1) - if nr_wires == 3: - qml.TAdd(wires=(0, 2)) - - -dev = qml.device( - "default.qutrit.mixed", - wires=2, - damping_measurement_gammas=(0.2, 0.1, 0.4), - trit_flip_measurement_probs=(0.1, 0.2, 0.5), - ) -# Create matricies for the observables with diagonalizing matrix :math:`THadamard^\dag` -inv_sqrt_3_i = 1j / np.sqrt(3) -non_commuting_obs_one = np.array( - [ - [0, -1 + inv_sqrt_3_i, -1 - inv_sqrt_3_i], - [-1 - inv_sqrt_3_i, 0, -1 + inv_sqrt_3_i], - [-1 + inv_sqrt_3_i, -1 - inv_sqrt_3_i, 0], - ] -) -non_commuting_obs_one /= 2 - -@qml.qnode(dev) -def circuit(): - setup_state(2) - - qml.THadamard(wires=0) - qml.THadamard(wires=1, subspace=(0, 1)) - - return qml.expval(qml.THermitian(non_commuting_obs_one, 0)) - - - - - -print(my_test()) - - - - - - diff --git a/my_testing.py b/my_testing.py deleted file mode 100644 index 611cf1e626e..00000000000 --- a/my_testing.py +++ /dev/null @@ -1,97 +0,0 @@ -import numpy as np -import pennylane as qml -# -# U = np.array([[1/np.sqrt(2), 1/np.sqrt(3), 1/np.sqrt(6)], [np.sqrt(2/29), np.sqrt(3/29), -2 * np.sqrt(6/29)], [-5/np.sqrt(58), 7/np.sqrt(87), 1/np.sqrt(174)]]) -# -# -# print(np.linalg.norm(U[1])) - -# inv_sqrt_3 = 1 / np.sqrt(3) -# gellMann_8_coeffs = np.array([inv_sqrt_3, -1, inv_sqrt_3, 1, inv_sqrt_3, -1]) / 2 -# gellMann_8_obs = [qml.GellMann(0, i) for i in [1, 2, 4, 5, 6, 7]] - -# obs = qml.expval(qml.Hamiltonian(gellMann_8_coeffs, gellMann_8_obs)) -# -# obs.matrix() -# -# ham = qml.Hamiltonian() - -# from scipy.stats import unitary_group -# X = qml.PauliX.compute_matrix() -# Y = qml.PauliX.compute_matrix() -# Z = qml.PauliX.compute_matrix() -# I = np.eye(2) - - - -#gellMann_8_obs = [qml.Hermitian(sum([np.random.rand()*m for m in [X, Y, Z, I]]), wires=0)] -# for obs in gellMann_8_obs: -# diagm = obs.diagonalizing_gates()[0].matrix() -# diagm_adj = np.conj(diagm).T -# obsm = obs.matrix() -# -# print(np.round(diagm@obsm@diagm_adj, 5), "\n") -# print(np.round(diagm_adj @ diagm, 5), "\n===============================================================\n") -had = qml.THadamard.compute_matrix() -ahad = np.conj(had).T -G8 = qml.GellMann.compute_matrix(8) -G3 = qml.GellMann.compute_matrix(3) - - -# print(np.round(had@G8@ahad, 5)) -# -# print() -# print(np.round(had@G3@ahad, 5)) - -inv_sqrt_3 = 1 / np.sqrt(3) -inv_sqrt_3_i = inv_sqrt_3 * 1j -# -gellMann_3_equivalent = ( - np.array( - [[0, 1+inv_sqrt_3_i, 1-inv_sqrt_3_i], - [1-inv_sqrt_3_i, 0, 1 + inv_sqrt_3_i], - [1+inv_sqrt_3_i, 1-inv_sqrt_3_i, 0]] - ) - / 2 - ) -gellMann_8_equivalent = ( - np.array( - [[0, (inv_sqrt_3 - 1j), (inv_sqrt_3 + 1j)], - [inv_sqrt_3 + 1j, 0, inv_sqrt_3 - 1j], - [inv_sqrt_3 - 1j, inv_sqrt_3 + 1j, 0]] - ) - / 2 - ) - -dg = qml.THermitian(gellMann_8_equivalent, 0).diagonalizing_gates()[0].matrix() -# dga = np.conj(dg).T -# print(np.round(dg@gellMann_8_equivalent@dga, 5)) -# -# print(np.abs((dg@had@(np.array([1/2,1/3,1/6])**(1/2))))**2) -#print(qml.GellMann(0, 1).diagonalizing_gates()[0].matrix()) - -# print(np.round(dg@had, 4)) -obs = np.diag([1, 2, 3]) -print(np.round(had@obs@ahad, 4)) - -obs = np.diag([-2, -1, 1]) - -non_commuting_obs_two = np.array( - [ - [-2/3, -2/3 + inv_sqrt_3_i, -2/3 - inv_sqrt_3_i], - [-2/3 - inv_sqrt_3_i, -2/3, -2/3 + inv_sqrt_3_i], - [-2/3 + inv_sqrt_3_i, -2/3 - inv_sqrt_3_i, -2/3], - ] - ) - -print(np.round(had@obs@ahad, 4)) - -print(np.allclose(had@obs@ahad, non_commuting_obs_two)) -# print(np.allclose(had@G8@ahad, gellMann_8_equivalent)) -# #print(had@G8@ahad) -# print(np.allclose(ahad, dg)) -import jax -print(jax.numpy.array([jax.numpy.nan, 1., 2.])) - - - diff --git a/pennylane/devices/jittable_mixed.py b/pennylane/devices/jittable_mixed.py deleted file mode 100644 index 0e2d0fbfcf0..00000000000 --- a/pennylane/devices/jittable_mixed.py +++ /dev/null @@ -1,797 +0,0 @@ -# Copyright 2018-2021 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -r""" -The default.mixed device is PennyLane's standard qubit simulator for mixed-state computations. - -It implements the necessary :class:`~pennylane.Device` methods as well as some built-in -qubit :doc:`operations `, providing a simple mixed-state simulation of -qubit-based quantum circuits. -""" - -import functools -import itertools -import logging -from collections import defaultdict -from string import ascii_letters as ABC - -import numpy as np - -import pennylane as qml -import pennylane.math as qnp -from pennylane import BasisState, DeviceError, QubitDensityMatrix, QubitDevice, Snapshot, StatePrep -from pennylane.logging import debug_logger, debug_logger_init -from pennylane.measurements import ( - CountsMP, - DensityMatrixMP, - ExpectationMP, - MutualInfoMP, - ProbabilityMP, - PurityMP, - SampleMP, - StateMP, - VarianceMP, - VnEntropyMP, -) -from pennylane.operation import Channel -from pennylane.ops.qubit.attributes import diagonal_in_z_basis -from pennylane.wires import Wires - -from .._version import __version__ - -logger = logging.getLogger(__name__) -logger.addHandler(logging.NullHandler()) - -ABC_ARRAY = np.array(list(ABC)) -tolerance = 1e-10 - - -class DefaultMixed(QubitDevice): - """Default qubit device for performing mixed-state computations in PennyLane. - - .. warning:: - - The API of ``DefaultMixed`` will be updated soon to follow a new device interface described - in :class:`pennylane.devices.Device`. - - This change will not alter device behaviour for most workflows, but may have implications for - plugin developers and users who directly interact with device methods. Please consult - :class:`pennylane.devices.Device` and the implementation in - :class:`pennylane.devices.DefaultQubit` for more information on what the new - interface will look like and be prepared to make updates in a coming release. If you have any - feedback on these changes, please create an - `issue `_ or post in our - `discussion forum `_. - - Args: - wires (int, Iterable[Number, str]): Number of subsystems represented by the device, - or iterable that contains unique labels for the subsystems as numbers - (i.e., ``[-1, 0, 2]``) or strings (``['ancilla', 'q1', 'q2']``). - shots (None, int): Number of times the circuit should be evaluated (or sampled) to estimate - the expectation values. Defaults to ``None`` if not specified, which means that - outputs are computed exactly. - readout_prob (None, int, float): Probability for adding readout error to the measurement - outcomes of observables. Defaults to ``None`` if not specified, which means that the outcomes are - without any readout error. - """ - - name = "Default mixed-state qubit PennyLane plugin" - short_name = "default.mixed" - pennylane_requires = __version__ - version = __version__ - author = "Xanadu Inc." - - operations = { - "Identity", - "Snapshot", - "BasisState", - "QubitStateVector", - "StatePrep", - "QubitDensityMatrix", - "QubitUnitary", - "ControlledQubitUnitary", - "BlockEncode", - "MultiControlledX", - "DiagonalQubitUnitary", - "SpecialUnitary", - "PauliX", - "PauliY", - "PauliZ", - "MultiRZ", - "Hadamard", - "S", - "T", - "SX", - "CNOT", - "SWAP", - "ISWAP", - "CSWAP", - "Toffoli", - "CCZ", - "CY", - "CZ", - "CH", - "PhaseShift", - "PCPhase", - "ControlledPhaseShift", - "CPhaseShift00", - "CPhaseShift01", - "CPhaseShift10", - "RX", - "RY", - "RZ", - "Rot", - "CRX", - "CRY", - "CRZ", - "CRot", - "AmplitudeDamping", - "GeneralizedAmplitudeDamping", - "PhaseDamping", - "DepolarizingChannel", - "BitFlip", - "PhaseFlip", - "PauliError", - "ResetError", - "QubitChannel", - "SingleExcitation", - "SingleExcitationPlus", - "SingleExcitationMinus", - "DoubleExcitation", - "DoubleExcitationPlus", - "DoubleExcitationMinus", - "QubitCarry", - "QubitSum", - "OrbitalRotation", - "FermionicSWAP", - "QFT", - "ThermalRelaxationError", - "ECR", - "ParametrizedEvolution", - "GlobalPhase", - } - - _reshape = staticmethod(qnp.reshape) - _flatten = staticmethod(qnp.flatten) - _transpose = staticmethod(qnp.transpose) - # Allow for the `axis` keyword argument for integration with broadcasting-enabling - # code in QubitDevice. However, it is not used as DefaultMixed does not support broadcasting - # pylint: disable=unnecessary-lambda - _gather = staticmethod(lambda *args, axis=0, **kwargs: qnp.gather(*args, **kwargs)) - _dot = staticmethod(qnp.dot) - - measurement_map = defaultdict(lambda: "") - measurement_map[PurityMP] = "purity" - - @staticmethod - def _reduce_sum(array, axes): - return qnp.sum(array, tuple(axes)) - - @staticmethod - def _asarray(array, dtype=None): - # Support float - if not hasattr(array, "__len__"): - return np.asarray(array, dtype=dtype) - - res = qnp.cast(array, dtype=dtype) - return res - - @debug_logger_init - def __init__( - self, - wires, - *, - r_dtype=np.float64, - c_dtype=np.complex128, - shots=None, - analytic=None, - readout_prob=None, - ): - if isinstance(wires, int) and wires > 23: - raise ValueError( - "This device does not currently support computations on more than 23 wires" - ) - - self.readout_err = readout_prob - # Check that the readout error probability, if entered, is either integer or float in [0,1] - if self.readout_err is not None: - if not isinstance(self.readout_err, float) and not isinstance(self.readout_err, int): - raise TypeError( - "The readout error probability should be an integer or a floating-point number in [0,1]." - ) - if self.readout_err < 0 or self.readout_err > 1: - raise ValueError("The readout error probability should be in the range [0,1].") - - # call QubitDevice init - super().__init__(wires, shots, r_dtype=r_dtype, c_dtype=c_dtype, analytic=analytic) - self._debugger = None - - # Create the initial state. - self._state = self._create_basis_state(0) - self._pre_rotated_state = self._state - self.measured_wires = [] - """List: during execution, stores the list of wires on which measurements are acted for - applying the readout error to them when readout_prob is non-zero.""" - - def _create_basis_state(self, index): - """Return the density matrix representing a computational basis state over all wires. - - Args: - index (int): integer representing the computational basis state. - - Returns: - array[complex]: complex array of shape ``[2] * (2 * num_wires)`` - representing the density matrix of the basis state. - """ - rho = qnp.zeros((2**self.num_wires, 2**self.num_wires), dtype=self.C_DTYPE) - rho[index, index] = 1 - return qnp.reshape(rho, [2] * (2 * self.num_wires)) - - @classmethod - def capabilities(cls): - capabilities = super().capabilities().copy() - capabilities.update( - returns_state=True, - passthru_devices={ - "autograd": "default.mixed", - "tf": "default.mixed", - "torch": "default.mixed", - "jax": "default.mixed", - }, - ) - return capabilities - - @property - def state(self): - """Returns the state density matrix of the circuit prior to measurement""" - dim = 2**self.num_wires - # User obtains state as a matrix - return qnp.reshape(self._pre_rotated_state, (dim, dim)) - - @debug_logger - def density_matrix(self, wires): - """Returns the reduced density matrix over the given wires. - - Args: - wires (Wires): wires of the reduced system - - Returns: - array[complex]: complex array of shape ``(2 ** len(wires), 2 ** len(wires))`` - representing the reduced density matrix of the state prior to measurement. - """ - state = getattr(self, "state", None) - wires = self.map_wires(wires) - return qml.math.reduce_dm(state, indices=wires, c_dtype=self.C_DTYPE) - - @debug_logger - def purity(self, mp, **kwargs): # pylint: disable=unused-argument - """Returns the purity of the final state""" - state = getattr(self, "state", None) - wires = self.map_wires(mp.wires) - return qml.math.purity(state, indices=wires, c_dtype=self.C_DTYPE) - - @debug_logger - def reset(self): - """Resets the device""" - super().reset() - - self._state = self._create_basis_state(0) - self._pre_rotated_state = self._state - - @debug_logger - def analytic_probability(self, wires=None): - if self._state is None: - return None - - # convert rho from tensor to matrix - rho = qnp.reshape(self._state, (2**self.num_wires, 2**self.num_wires)) - - # probs are diagonal elements - probs = self.marginal_prob(qnp.diagonal(rho), wires) - - # take the real part so probabilities are not shown as complex numbers - probs = qnp.real(probs) - return qnp.where(probs < 0, -probs, probs) - - def _get_kraus(self, operation): # pylint: disable=no-self-use - """Return the Kraus operators representing the operation. - - Args: - operation (.Operation): a PennyLane operation - - Returns: - list[array[complex]]: Returns a list of 2D matrices representing the Kraus operators. If - the operation is unitary, returns a single Kraus operator. In the case of a diagonal - unitary, returns a 1D array representing the matrix diagonal. - """ - if operation in diagonal_in_z_basis: - return operation.eigvals() - - if isinstance(operation, Channel): - return operation.kraus_matrices() - - return [operation.matrix()] - - def _apply_channel(self, kraus, wires): - r"""Apply a quantum channel specified by a list of Kraus operators to subsystems of the - quantum state. For a unitary gate, there is a single Kraus operator. - - Args: - kraus (list[array]): Kraus operators - wires (Wires): target wires - """ - channel_wires = self.map_wires(wires) - rho_dim = 2 * self.num_wires - num_ch_wires = len(channel_wires) - - # Computes K^\dagger, needed for the transformation K \rho K^\dagger - kraus_dagger = [qnp.conj(qnp.transpose(k)) for k in kraus] - - kraus = qnp.stack(kraus) - kraus_dagger = qnp.stack(kraus_dagger) - - # Shape kraus operators - kraus_shape = [len(kraus)] + [2] * num_ch_wires * 2 - kraus = qnp.cast(qnp.reshape(kraus, kraus_shape), dtype=self.C_DTYPE) - kraus_dagger = qnp.cast(qnp.reshape(kraus_dagger, kraus_shape), dtype=self.C_DTYPE) - - # Tensor indices of the state. For each qubit, need an index for rows *and* columns - state_indices = ABC[:rho_dim] - - # row indices of the quantum state affected by this operation - row_wires_list = channel_wires.tolist() - row_indices = "".join(ABC_ARRAY[row_wires_list].tolist()) - - # column indices are shifted by the number of wires - col_wires_list = [w + self.num_wires for w in row_wires_list] - col_indices = "".join(ABC_ARRAY[col_wires_list].tolist()) - - # indices in einsum must be replaced with new ones - new_row_indices = ABC[rho_dim : rho_dim + num_ch_wires] - new_col_indices = ABC[rho_dim + num_ch_wires : rho_dim + 2 * num_ch_wires] - - # index for summation over Kraus operators - kraus_index = ABC[rho_dim + 2 * num_ch_wires : rho_dim + 2 * num_ch_wires + 1] - - # new state indices replace row and column indices with new ones - new_state_indices = functools.reduce( - lambda old_string, idx_pair: old_string.replace(idx_pair[0], idx_pair[1]), - zip(col_indices + row_indices, new_col_indices + new_row_indices), - state_indices, - ) - - # index mapping for einsum, e.g., 'iga,abcdef,idh->gbchef' - einsum_indices = ( - f"{kraus_index}{new_row_indices}{row_indices}, {state_indices}," - f"{kraus_index}{col_indices}{new_col_indices}->{new_state_indices}" - ) - - self._state = qnp.einsum(einsum_indices, kraus, self._state, kraus_dagger) - - def _apply_channel_tensordot(self, kraus, wires): - r"""Apply a quantum channel specified by a list of Kraus operators to subsystems of the - quantum state. For a unitary gate, there is a single Kraus operator. - - Args: - kraus (list[array]): Kraus operators - wires (Wires): target wires - """ - channel_wires = self.map_wires(wires) - num_ch_wires = len(channel_wires) - - # Shape kraus operators and cast them to complex data type - kraus_shape = [2] * (num_ch_wires * 2) - kraus = [qnp.cast(qnp.reshape(k, kraus_shape), dtype=self.C_DTYPE) for k in kraus] - - # row indices of the quantum state affected by this operation - row_wires_list = channel_wires.tolist() - # column indices are shifted by the number of wires - col_wires_list = [w + self.num_wires for w in row_wires_list] - - channel_col_ids = list(range(num_ch_wires, 2 * num_ch_wires)) - axes_left = [channel_col_ids, row_wires_list] - # Use column indices instead or rows to incorporate transposition of K^\dagger - axes_right = [col_wires_list, channel_col_ids] - - # Apply the Kraus operators, and sum over all Kraus operators afterwards - def _conjugate_state_with(k): - """Perform the double tensor product k @ self._state @ k.conj(). - The `axes_left` and `axes_right` arguments are taken from the ambient variable space - and `axes_right` is assumed to incorporate the tensor product and the transposition - of k.conj() simultaneously.""" - return qnp.tensordot(qnp.tensordot(k, self._state, axes_left), qnp.conj(k), axes_right) - - if len(kraus) == 1: - _state = _conjugate_state_with(kraus[0]) - else: - _state = qnp.sum(qnp.stack([_conjugate_state_with(k) for k in kraus]), axis=0) - - # Permute the affected axes to their destination places. - # The row indices of the kraus operators are moved from the beginning to the original - # target row locations, the column indices from the end to the target column locations - source_left = list(range(num_ch_wires)) - dest_left = row_wires_list - source_right = list(range(-num_ch_wires, 0)) - dest_right = col_wires_list - self._state = qnp.moveaxis(_state, source_left + source_right, dest_left + dest_right) - - def _apply_diagonal_unitary(self, eigvals, wires): - r"""Apply a diagonal unitary gate specified by a list of eigenvalues. This method uses - the fact that the unitary is diagonal for a more efficient implementation. - - Args: - eigvals (array): eigenvalues (phases) of the diagonal unitary - wires (Wires): target wires - """ - - channel_wires = self.map_wires(wires) - - eigvals = qnp.stack(eigvals) - - # reshape vectors - eigvals = qnp.cast(qnp.reshape(eigvals, [2] * len(channel_wires)), dtype=self.C_DTYPE) - - # Tensor indices of the state. For each qubit, need an index for rows *and* columns - state_indices = ABC[: 2 * self.num_wires] - - # row indices of the quantum state affected by this operation - row_wires_list = channel_wires.tolist() - row_indices = "".join(ABC_ARRAY[row_wires_list].tolist()) - - # column indices are shifted by the number of wires - col_wires_list = [w + self.num_wires for w in row_wires_list] - col_indices = "".join(ABC_ARRAY[col_wires_list].tolist()) - - einsum_indices = f"{row_indices},{state_indices},{col_indices}->{state_indices}" - - self._state = qnp.einsum(einsum_indices, eigvals, self._state, qnp.conj(eigvals)) - - def _apply_basis_state(self, state, wires): - """Initialize the device in a specified computational basis state. - - Args: - state (array[int]): computational basis state of shape ``(wires,)`` - consisting of 0s and 1s. - wires (Wires): wires that the provided computational state should be initialized on - """ - # translate to wire labels used by device - device_wires = self.map_wires(wires) - - # length of basis state parameter - n_basis_state = len(state) - - if not set(state).issubset({0, 1}): - raise ValueError("BasisState parameter must consist of 0 or 1 integers.") - - if n_basis_state != len(device_wires): - raise ValueError("BasisState parameter and wires must be of equal length.") - - # get computational basis state number - basis_states = 2 ** (self.num_wires - 1 - device_wires.toarray()) - num = int(qnp.dot(state, basis_states)) - - self._state = self._create_basis_state(num) - - def _apply_state_vector(self, state, device_wires): - """Initialize the internal state in a specified pure state. - - Args: - state (array[complex]): normalized input state of length - ``2**len(wires)`` - device_wires (Wires): wires that get initialized in the state - """ - - # translate to wire labels used by device - device_wires = self.map_wires(device_wires) - - state = qnp.asarray(state, dtype=self.C_DTYPE) - n_state_vector = state.shape[0] - - if state.ndim != 1 or n_state_vector != 2 ** len(device_wires): - raise ValueError("State vector must be of length 2**wires.") - - if not qnp.allclose(qnp.linalg.norm(state, ord=2), 1.0, atol=tolerance): - raise ValueError("Sum of amplitudes-squared does not equal one.") - - if len(device_wires) == self.num_wires and sorted(device_wires.labels) == list( - device_wires.labels - ): - # Initialize the entire wires with the state - rho = qnp.outer(state, qnp.conj(state)) - self._state = qnp.reshape(rho, [2] * 2 * self.num_wires) - - else: - # generate basis states on subset of qubits via the cartesian product - basis_states = qnp.asarray( - list(itertools.product([0, 1], repeat=len(device_wires))), dtype=int - ) - - # get basis states to alter on full set of qubits - unravelled_indices = qnp.zeros((2 ** len(device_wires), self.num_wires), dtype=int) - unravelled_indices[:, device_wires] = basis_states - - # get indices for which the state is changed to input state vector elements - ravelled_indices = qnp.ravel_multi_index(unravelled_indices.T, [2] * self.num_wires) - - state = qnp.scatter(ravelled_indices, state, [2**self.num_wires]) - rho = qnp.outer(state, qnp.conj(state)) - rho = qnp.reshape(rho, [2] * 2 * self.num_wires) - self._state = qnp.asarray(rho, dtype=self.C_DTYPE) - - def _apply_density_matrix(self, state, device_wires): - r"""Initialize the internal state in a specified mixed state. - If not all the wires are specified in the full state :math:`\rho`, remaining subsystem is filled by - `\mathrm{tr}_in(\rho)`, which results in the full system state :math:`\mathrm{tr}_{in}(\rho) \otimes \rho_{in}`, - where :math:`\rho_{in}` is the argument `state` of this function and :math:`\mathrm{tr}_{in}` is a partial - trace over the subsystem to be replaced by this operation. - - Args: - state (array[complex]): density matrix of length - ``(2**len(wires), 2**len(wires))`` - device_wires (Wires): wires that get initialized in the state - """ - - # translate to wire labels used by device - device_wires = self.map_wires(device_wires) - - state = qnp.asarray(state, dtype=self.C_DTYPE) - state = qnp.reshape(state, (-1,)) - - state_dim = 2 ** len(device_wires) - dm_dim = state_dim**2 - if dm_dim != state.shape[0]: - raise ValueError("Density matrix must be of length (2**wires, 2**wires)") - - if not qml.math.is_abstract(state) and not qnp.allclose( - qnp.trace(qnp.reshape(state, (state_dim, state_dim))), 1.0, atol=tolerance - ): - raise ValueError("Trace of density matrix is not equal one.") - - if len(device_wires) == self.num_wires and sorted(device_wires.labels) == list( - device_wires.labels - ): - # Initialize the entire wires with the state - - self._state = qnp.reshape(state, [2] * 2 * self.num_wires) - self._pre_rotated_state = self._state - - else: - # Initialize tr_in(ρ) ⊗ ρ_in with transposed wires where ρ is the density matrix before this operation. - - complement_wires = list(sorted(list(set(range(self.num_wires)) - set(device_wires)))) - sigma = self.density_matrix(Wires(complement_wires)) - rho = qnp.kron(sigma, state.reshape(state_dim, state_dim)) - rho = rho.reshape([2] * 2 * self.num_wires) - - # Construct transposition axis to revert back to the original wire order - left_axes = [] - right_axes = [] - complement_wires_count = len(complement_wires) - for i in range(self.num_wires): - if i in device_wires: - index = device_wires.index(i) - left_axes.append(complement_wires_count + index) - right_axes.append(complement_wires_count + index + self.num_wires) - elif i in complement_wires: - index = complement_wires.index(i) - left_axes.append(index) - right_axes.append(index + self.num_wires) - transpose_axes = left_axes + right_axes - rho = qnp.transpose(rho, axes=transpose_axes) - assert qml.math.is_abstract(rho) or qnp.allclose( - qnp.trace(qnp.reshape(rho, (2**self.num_wires, 2**self.num_wires))), - 1.0, - atol=tolerance, - ) - - self._state = qnp.asarray(rho, dtype=self.C_DTYPE) - self._pre_rotated_state = self._state - - def _snapshot_measurements(self, density_matrix, measurement): - """Perform state-based snapshot measurement""" - meas_wires = self.wires if not measurement.wires else measurement.wires - - pre_rotated_state = self._state - if isinstance(measurement, (ProbabilityMP, ExpectationMP, VarianceMP)): - for diag_gate in measurement.diagonalizing_gates(): - self._apply_operation(diag_gate) - - if isinstance(measurement, (StateMP, DensityMatrixMP)): - map_wires = self.map_wires(meas_wires) - snap_result = qml.math.reduce_dm( - density_matrix, indices=map_wires, c_dtype=self.C_DTYPE - ) - - elif isinstance(measurement, PurityMP): - map_wires = self.map_wires(meas_wires) - snap_result = qml.math.purity(density_matrix, indices=map_wires, c_dtype=self.C_DTYPE) - - elif isinstance(measurement, ProbabilityMP): - snap_result = self.analytic_probability(wires=meas_wires) - - elif isinstance(measurement, ExpectationMP): - eigvals = self._asarray(measurement.obs.eigvals(), dtype=self.R_DTYPE) - probs = self.analytic_probability(wires=meas_wires) - snap_result = self._dot(probs, eigvals) - - elif isinstance(measurement, VarianceMP): - eigvals = self._asarray(measurement.obs.eigvals(), dtype=self.R_DTYPE) - probs = self.analytic_probability(wires=meas_wires) - snap_result = self._dot(probs, (eigvals**2)) - self._dot(probs, eigvals) ** 2 - - elif isinstance(measurement, VnEntropyMP): - base = measurement.log_base - map_wires = self.map_wires(meas_wires) - snap_result = qml.math.vn_entropy( - density_matrix, indices=map_wires, c_dtype=self.C_DTYPE, base=base - ) - - elif isinstance(measurement, MutualInfoMP): - base = measurement.log_base - wires0, wires1 = list(map(self.map_wires, measurement.raw_wires)) - snap_result = qml.math.mutual_info( - density_matrix, - indices0=wires0, - indices1=wires1, - c_dtype=self.C_DTYPE, - base=base, - ) - - else: - raise DeviceError( - f"Snapshots of {type(measurement)} are not yet supported on default.mixed" - ) - - self._state = pre_rotated_state - self._pre_rotated_state = self._state - - return snap_result - - def _apply_snapshot(self, operation): - """Applies the snapshot operation""" - measurement = operation.hyperparameters["measurement"] - - if self._debugger and self._debugger.active: - dim = 2**self.num_wires - density_matrix = qnp.reshape(self._state, (dim, dim)) - - snapshot_result = self._snapshot_measurements(density_matrix, measurement) - - if operation.tag: - self._debugger.snapshots[operation.tag] = snapshot_result - else: - self._debugger.snapshots[len(self._debugger.snapshots)] = snapshot_result - - def _apply_operation(self, operation): - """Applies operations to the internal device state. - - Args: - operation (.Operation): operation to apply on the device - """ - wires = operation.wires - if operation.name == "Identity": - return - - if isinstance(operation, StatePrep): - self._apply_state_vector(operation.parameters[0], wires) - return - - if isinstance(operation, BasisState): - self._apply_basis_state(operation.parameters[0], wires) - return - - if isinstance(operation, QubitDensityMatrix): - self._apply_density_matrix(operation.parameters[0], wires) - return - - if isinstance(operation, Snapshot): - self._apply_snapshot(operation) - return - - matrices = self._get_kraus(operation) - - if operation in diagonal_in_z_basis: - self._apply_diagonal_unitary(matrices, wires) - else: - num_op_wires = len(wires) - interface = qml.math.get_interface(self._state, *matrices) - # Use tensordot for Autograd and Numpy if there are more than 2 wires - # Use tensordot in any case for more than 7 wires, as einsum does not support this case - if (num_op_wires > 2 and interface in {"autograd", "numpy"}) or num_op_wires > 7: - self._apply_channel_tensordot(matrices, wires) - else: - self._apply_channel(matrices, wires) - - # pylint: disable=arguments-differ - - @debug_logger - def execute(self, circuit, **kwargs): - """Execute a queue of quantum operations on the device and then - measure the given observables. - - Applies a readout error to the measurement outcomes of any observable if - readout_prob is non-zero. This is done by finding the list of measured wires on which - BitFlip channels are applied in the :meth:`apply`. - - For plugin developers: instead of overwriting this, consider - implementing a suitable subset of - - * :meth:`apply` - - * :meth:`~.generate_samples` - - * :meth:`~.probability` - - Additional keyword arguments may be passed to this method - that can be utilised by :meth:`apply`. An example would be passing - the ``QNode`` hash that can be used later for parametric compilation. - - Args: - circuit (QuantumTape): circuit to execute on the device - - Raises: - QuantumFunctionError: if the value of :attr:`~.Observable.return_type` is not supported - - Returns: - array[float]: measured value(s) - """ - if self.readout_err: - wires_list = [] - for m in circuit.measurements: - if isinstance(m, StateMP): - # State: This returns pre-rotated state, so no readout error. - # Assumed to only be allowed if it's the only measurement. - self.measured_wires = [] - return super().execute(circuit, **kwargs) - if isinstance(m, (SampleMP, CountsMP)) and m.wires in ( - qml.wires.Wires([]), - self.wires, - ): - # Sample, Counts: Readout error applied to all device wires when wires - # not specified or all wires specified. - self.measured_wires = self.wires - return super().execute(circuit, **kwargs) - if isinstance(m, (VnEntropyMP, MutualInfoMP)): - # VnEntropy, MutualInfo: Computed for the state prior to measurement. So, readout - # error need not be applied on the corresponding device wires. - continue - wires_list.append(m.wires) - self.measured_wires = qml.wires.Wires.all_wires(wires_list) - return super().execute(circuit, **kwargs) - - @debug_logger - def apply(self, operations, rotations=None, **kwargs): - rotations = rotations or [] - - # apply the circuit operations - for i, operation in enumerate(operations): - if i > 0 and isinstance(operation, (StatePrep, BasisState)): - raise DeviceError( - f"Operation {operation.name} cannot be used after other Operations have already been applied " - f"on a {self.short_name} device." - ) - - for operation in operations: - self._apply_operation(operation) - - # store the pre-rotated state - self._pre_rotated_state = self._state - - # apply the circuit rotations - for operation in rotations: - self._apply_operation(operation) - - if self.readout_err: - for k in self.measured_wires: - bit_flip = qml.BitFlip(self.readout_err, wires=k) - self._apply_operation(bit_flip) diff --git a/pennylane/devices/jittable_qutrit_mixed.py b/pennylane/devices/jittable_qutrit_mixed.py deleted file mode 100644 index 0ade57c0c73..00000000000 --- a/pennylane/devices/jittable_qutrit_mixed.py +++ /dev/null @@ -1,311 +0,0 @@ -# Copyright 2018-2024 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""The default.qutrit.mixed device is PennyLane's standard qutrit simulator for mixed-state -computations.""" -import logging -from dataclasses import replace -from typing import Callable, Optional, Sequence, Tuple, Union - -import numpy as np - -import pennylane as qml -from pennylane.logging import debug_logger, debug_logger_init -from pennylane.ops import _qutrit__channel__ops__ as channels -from pennylane.tape import QuantumTape -from pennylane.transforms.core import TransformProgram -from pennylane.typing import Result, ResultBatch - -from . import Device -from .default_qutrit import DefaultQutrit -from .execution_config import DefaultExecutionConfig, ExecutionConfig -from .modifiers import simulator_tracking, single_tape_support -from .preprocess import ( - decompose, - no_sampling, - validate_device_wires, - validate_measurements, - validate_observables, -) -from .qutrit_mixed.simulate import simulate - -logger = logging.getLogger(__name__) -logger.addHandler(logging.NullHandler()) - -Result_or_ResultBatch = Union[Result, ResultBatch] -QuantumTapeBatch = Sequence[QuantumTape] -QuantumTape_or_Batch = Union[QuantumTape, QuantumTapeBatch] - -# always a function from a resultbatch to either a result or a result batch -PostprocessingFn = Callable[[ResultBatch], Result_or_ResultBatch] - -observables = { - "THermitian", - "GellMann", -} - - -def observable_stopping_condition(obs: qml.operation.Operator) -> bool: - """Specifies whether an observable is accepted by DefaultQutritMixed.""" - if isinstance(obs, qml.operation.Tensor): - return all(observable_stopping_condition(observable) for observable in obs.obs) - if obs.name in {"Prod", "Sum"}: - return all(observable_stopping_condition(observable) for observable in obs.operands) - if obs.name in {"LinearCombination", "Hamiltonian"}: - return all(observable_stopping_condition(observable) for observable in obs.terms()[1]) - if obs.name == "SProd": - return observable_stopping_condition(obs.base) - - return obs.name in observables - - -def stopping_condition(op: qml.operation.Operator) -> bool: - """Specify whether an Operator object is supported by the device.""" - expected_set = DefaultQutrit.operations | {"Snapshot"} | channels - return op.name in expected_set - - -def stopping_condition_shots(op: qml.operation.Operator) -> bool: - """Specify whether an Operator object is supported by the device with shots.""" - return stopping_condition(op) - - -def accepted_sample_measurement(m: qml.measurements.MeasurementProcess) -> bool: - """Specifies whether a measurement is accepted when sampling.""" - return isinstance(m, qml.measurements.SampleMeasurement) - - -@simulator_tracking -@single_tape_support -class DefaultQutritMixed(Device): - """A PennyLane Python-based device for mixed-state qutrit simulation. - - Args: - wires (int, Iterable[Number, str]): Number of wires present on the device, or iterable that - contains unique labels for the wires as numbers (i.e., ``[-1, 0, 2]``) or strings - (``['ancilla', 'q1', 'q2']``). Default ``None`` if not specified. - shots (int, Sequence[int], Sequence[Union[int, Sequence[int]]]): The default number of shots - to use in executions involving this device. - seed (Union[str, None, int, array_like[int], SeedSequence, BitGenerator, Generator, jax.random.PRNGKey]): A - seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``, or - a request to seed from numpy's global random number generator. - The default, ``seed="global"`` pulls a seed from NumPy's global generator. ``seed=None`` - will pull a seed from the OS entropy. - If a ``jax.random.PRNGKey`` is passed as the seed, a JAX-specific sampling function using - ``jax.random.choice`` and the ``PRNGKey`` will be used for sampling rather than - ``numpy.random.default_rng``. - - **Example:** - - .. code-block:: python - - n_wires = 5 - num_qscripts = 5 - qscripts = [] - for i in range(num_qscripts): - unitary = scipy.stats.unitary_group(dim=3**n_wires, seed=(42 + i)).rvs() - op = qml.QutritUnitary(unitary, wires=range(n_wires)) - qs = qml.tape.QuantumScript([op], [qml.expval(qml.GellMann(0, 3))]) - qscripts.append(qs) - - >>> dev = DefaultQutritMixed() - >>> program, execution_config = dev.preprocess() - >>> new_batch, post_processing_fn = program(qscripts) - >>> results = dev.execute(new_batch, execution_config=execution_config) - >>> post_processing_fn(results) - [0.08015701503959313, - 0.04521414211599359, - -0.0215232130089687, - 0.062120285032425865, - -0.0635052317625] - - This device currently supports backpropagation derivatives: - - >>> from pennylane.devices import ExecutionConfig - >>> dev.supports_derivatives(ExecutionConfig(gradient_method="backprop")) - True - - For example, we can use jax to jit computing the derivative: - - .. code-block:: python - - import jax - - @jax.jit - def f(x): - qs = qml.tape.QuantumScript([qml.TRX(x, 0)], [qml.expval(qml.GellMann(0, 3))]) - program, execution_config = dev.preprocess() - new_batch, post_processing_fn = program([qs]) - results = dev.execute(new_batch, execution_config=execution_config) - return post_processing_fn(results)[0] - - >>> f(jax.numpy.array(1.2)) - DeviceArray(0.36235774, dtype=float32) - >>> jax.grad(f)(jax.numpy.array(1.2)) - DeviceArray(-0.93203914, dtype=float32, weak_type=True) - - .. details:: - :title: Tracking - - ``DefaultQutritMixed`` tracks: - - * ``executions``: the number of unique circuits that would be required on quantum hardware - * ``shots``: the number of shots - * ``resources``: the :class:`~.resource.Resources` for the executed circuit. - * ``simulations``: the number of simulations performed. One simulation can cover multiple QPU executions, such as for non-commuting measurements and batched parameters. - * ``batches``: The number of times :meth:`~.execute` is called. - * ``results``: The results of each call of :meth:`~.execute` - - - """ - - _device_options = ("rng", "prng_key") # tuple of string names for all the device options. - - @property - def name(self): - """The name of the device.""" - return "default.qutrit.mixed" - - @debug_logger_init - def __init__( - self, - wires=None, - shots=None, - seed="global", - ) -> None: - super().__init__(wires=wires, shots=shots) - seed = np.random.randint(0, high=10000000) if seed == "global" else seed - if qml.math.get_interface(seed) == "jax": - self._prng_key = seed - self._rng = np.random.default_rng(None) - else: - self._prng_key = None - self._rng = np.random.default_rng(seed) - self._debugger = None - - @debug_logger - def supports_derivatives( - self, - execution_config: Optional[ExecutionConfig] = None, - circuit: Optional[QuantumTape] = None, - ) -> bool: - """Check whether or not derivatives are available for a given configuration and circuit. - - ``DefaultQutritMixed`` supports backpropagation derivatives with analytic results. - - Args: - execution_config (ExecutionConfig): The configuration of the desired derivative calculation. - circuit (QuantumTape): An optional circuit to check derivatives support for. - - Returns: - bool: Whether or not a derivative can be calculated provided the given information. - - """ - if execution_config is None or execution_config.gradient_method in {"backprop", "best"}: - return circuit is None or not circuit.shots - return False - - def _setup_execution_config(self, execution_config: ExecutionConfig) -> ExecutionConfig: - """This is a private helper for ``preprocess`` that sets up the execution config. - - Args: - execution_config (ExecutionConfig): an unprocessed execution config. - - Returns: - ExecutionConfig: a preprocessed execution config. - """ - updated_values = {} - for option in execution_config.device_options: - if option not in self._device_options: - raise qml.DeviceError(f"device option {option} not present on {self}") - - if execution_config.gradient_method == "best": - updated_values["gradient_method"] = "backprop" - updated_values["use_device_gradient"] = False - updated_values["grad_on_execution"] = False - updated_values["device_options"] = dict(execution_config.device_options) # copy - - for option in self._device_options: - if option not in updated_values["device_options"]: - updated_values["device_options"][option] = getattr(self, f"_{option}") - return replace(execution_config, **updated_values) - - @debug_logger - def preprocess( - self, - execution_config: ExecutionConfig = DefaultExecutionConfig, - ) -> Tuple[TransformProgram, ExecutionConfig]: - """This function defines the device transform program to be applied and an updated device - configuration. - - Args: - execution_config (Union[ExecutionConfig, Sequence[ExecutionConfig]]): A data structure - describing the parameters needed to fully describe the execution. - - Returns: - TransformProgram, ExecutionConfig: A transform program that when called returns - ``QuantumTape`` objects that the device can natively execute, as well as a postprocessing - function to be called after execution, and a configuration with unset - specifications filled in. - - This device: - - * Supports any qutrit operations that provide a matrix - * Supports any qutrit channel that provides Kraus matrices - - """ - config = self._setup_execution_config(execution_config) - transform_program = TransformProgram() - - transform_program.add_transform(validate_device_wires, self.wires, name=self.name) - transform_program.add_transform( - decompose, - stopping_condition=stopping_condition, - stopping_condition_shots=stopping_condition_shots, - name=self.name, - ) - transform_program.add_transform( - validate_measurements, sample_measurements=accepted_sample_measurement, name=self.name - ) - transform_program.add_transform( - validate_observables, stopping_condition=observable_stopping_condition, name=self.name - ) - - if config.gradient_method == "backprop": - transform_program.add_transform(no_sampling, name="backprop + default.qutrit") - - return transform_program, config - - @debug_logger - def execute( - self, - circuits: QuantumTape_or_Batch, - execution_config: ExecutionConfig = DefaultExecutionConfig, - ) -> Result_or_ResultBatch: - - interface = ( - execution_config.interface - if execution_config.gradient_method in {"best", "backprop", None} - else None - ) - - return tuple( - simulate( - c, - rng=self._rng, - prng_key=self._prng_key, - debugger=self._debugger, - interface=interface, - ) - for c in circuits - ) From 432dc780238ba0290d705c804c213cbe02091d35 Mon Sep 17 00:00:00 2001 From: gabrielLydian Date: Fri, 5 Jul 2024 15:03:24 -0700 Subject: [PATCH 08/26] removed makefile to stop CI --- Makefile | 172 +++++++++++++++++++++++++++---------------------------- 1 file changed, 86 insertions(+), 86 deletions(-) diff --git a/Makefile b/Makefile index 5f7186e4549..4e56417edba 100644 --- a/Makefile +++ b/Makefile @@ -1,86 +1,86 @@ -PYTHON3 := $(shell which python3 2>/dev/null) - -PYTHON := python3 -COVERAGE := --cov=pennylane --cov-report term-missing --cov-report=html:coverage_html_report -TESTRUNNER := -m pytest tests --tb=native --no-flaky-report -PLUGIN_TESTRUNNER := -m pytest pennylane/devices/tests --tb=native --no-flaky-report - -.PHONY: help -help: - @echo "Please use \`make ' where is one of" - @echo " install to install PennyLane" - @echo " wheel to build the PennyLane wheel" - @echo " dist to package the source distribution" - @echo " clean to delete all temporary, cache, and build files" - @echo " docs to build the PennyLane documentation" - @echo " clean-docs to delete all built documentation" - @echo " test to run the test suite" - @echo " coverage to generate a coverage report" - @echo " format [check=1] to apply black formatter; use with 'check=1' to check instead of modify (requires black)" - @echo " lint to run pylint on source files" - @echo " lint-test to run pylint on test files" - -.PHONY: install -install: -ifndef PYTHON3 - @echo "To install PennyLane you need to have Python 3 installed" -endif - $(PYTHON) setup.py install - -.PHONY: wheel -wheel: - $(PYTHON) setup.py bdist_wheel - -.PHONY: dist -dist: - $(PYTHON) setup.py sdist - -.PHONY : clean -clean: - rm -rf pennylane/__pycache__ - rm -rf pennylane/optimize/__pycache__ - rm -rf pennylane/expectation/__pycache__ - rm -rf pennylane/ops/__pycache__ - rm -rf pennylane/devices/__pycache__ - rm -rf tests/__pycache__ - rm -rf tests/new_qnode/__pycache__ - rm -rf dist - rm -rf build - rm -rf .coverage coverage_html_report/ - rm -rf tmp - rm -rf *.dat - -docs: - make -C doc html - -.PHONY : clean-docs -clean-docs: - rm -rf doc/code/api - make -C doc clean - -test: - $(PYTHON) $(TESTRUNNER) - $(PYTHON) $(PLUGIN_TESTRUNNER) --device=default.qubit.autograd - -coverage: - @echo "Generating coverage report..." - $(PYTHON) $(TESTRUNNER) $(COVERAGE) - $(PYTHON) $(PLUGIN_TESTRUNNER) --device=default.qubit.autograd $(COVERAGE) --cov-append - -.PHONY:format -format: -ifdef check - isort --py 311 --profile black -l 100 -o autoray -p ./pennylane --skip __init__.py --filter-files ./pennylane ./tests --check - black -t py39 -t py310 -t py311 -l 100 ./pennylane ./tests --check -else - isort --py 311 --profile black -l 100 -o autoray -p ./pennylane --skip __init__.py --filter-files ./pennylane ./tests - black -t py39 -t py310 -t py311 -l 100 ./pennylane ./tests -endif - -.PHONY: lint -lint: - pylint pennylane --rcfile .pylintrc - -.PHONY: lint-test -lint-test: - pylint tests pennylane/devices/tests --rcfile tests/.pylintrc +# PYTHON3 := $(shell which python3 2>/dev/null) +# +# PYTHON := python3 +# COVERAGE := --cov=pennylane --cov-report term-missing --cov-report=html:coverage_html_report +# TESTRUNNER := -m pytest tests --tb=native --no-flaky-report +# PLUGIN_TESTRUNNER := -m pytest pennylane/devices/tests --tb=native --no-flaky-report +# +# .PHONY: help +# help: +# @echo "Please use \`make ' where is one of" +# @echo " install to install PennyLane" +# @echo " wheel to build the PennyLane wheel" +# @echo " dist to package the source distribution" +# @echo " clean to delete all temporary, cache, and build files" +# @echo " docs to build the PennyLane documentation" +# @echo " clean-docs to delete all built documentation" +# @echo " test to run the test suite" +# @echo " coverage to generate a coverage report" +# @echo " format [check=1] to apply black formatter; use with 'check=1' to check instead of modify (requires black)" +# @echo " lint to run pylint on source files" +# @echo " lint-test to run pylint on test files" +# +# .PHONY: install +# install: +# ifndef PYTHON3 +# @echo "To install PennyLane you need to have Python 3 installed" +# endif +# $(PYTHON) setup.py install +# +# .PHONY: wheel +# wheel: +# $(PYTHON) setup.py bdist_wheel +# +# .PHONY: dist +# dist: +# $(PYTHON) setup.py sdist +# +# .PHONY : clean +# clean: +# rm -rf pennylane/__pycache__ +# rm -rf pennylane/optimize/__pycache__ +# rm -rf pennylane/expectation/__pycache__ +# rm -rf pennylane/ops/__pycache__ +# rm -rf pennylane/devices/__pycache__ +# rm -rf tests/__pycache__ +# rm -rf tests/new_qnode/__pycache__ +# rm -rf dist +# rm -rf build +# rm -rf .coverage coverage_html_report/ +# rm -rf tmp +# rm -rf *.dat +# +# docs: +# make -C doc html +# +# .PHONY : clean-docs +# clean-docs: +# rm -rf doc/code/api +# make -C doc clean +# +# test: +# $(PYTHON) $(TESTRUNNER) +# $(PYTHON) $(PLUGIN_TESTRUNNER) --device=default.qubit.autograd +# +# coverage: +# @echo "Generating coverage report..." +# $(PYTHON) $(TESTRUNNER) $(COVERAGE) +# $(PYTHON) $(PLUGIN_TESTRUNNER) --device=default.qubit.autograd $(COVERAGE) --cov-append +# +# .PHONY:format +# format: +# ifdef check +# isort --py 311 --profile black -l 100 -o autoray -p ./pennylane --skip __init__.py --filter-files ./pennylane ./tests --check +# black -t py39 -t py310 -t py311 -l 100 ./pennylane ./tests --check +# else +# isort --py 311 --profile black -l 100 -o autoray -p ./pennylane --skip __init__.py --filter-files ./pennylane ./tests +# black -t py39 -t py310 -t py311 -l 100 ./pennylane ./tests +# endif +# +# .PHONY: lint +# lint: +# pylint pennylane --rcfile .pylintrc +# +# .PHONY: lint-test +# lint-test: +# pylint tests pennylane/devices/tests --rcfile tests/.pylintrc From da2e63d82b9e617960990847223e65148239559e Mon Sep 17 00:00:00 2001 From: gabrielLydian Date: Fri, 5 Jul 2024 15:04:51 -0700 Subject: [PATCH 09/26] Removed commenting out Makefile --- Makefile | 172 +++++++++++++++++++++++++++---------------------------- 1 file changed, 86 insertions(+), 86 deletions(-) diff --git a/Makefile b/Makefile index 4e56417edba..5f7186e4549 100644 --- a/Makefile +++ b/Makefile @@ -1,86 +1,86 @@ -# PYTHON3 := $(shell which python3 2>/dev/null) -# -# PYTHON := python3 -# COVERAGE := --cov=pennylane --cov-report term-missing --cov-report=html:coverage_html_report -# TESTRUNNER := -m pytest tests --tb=native --no-flaky-report -# PLUGIN_TESTRUNNER := -m pytest pennylane/devices/tests --tb=native --no-flaky-report -# -# .PHONY: help -# help: -# @echo "Please use \`make ' where is one of" -# @echo " install to install PennyLane" -# @echo " wheel to build the PennyLane wheel" -# @echo " dist to package the source distribution" -# @echo " clean to delete all temporary, cache, and build files" -# @echo " docs to build the PennyLane documentation" -# @echo " clean-docs to delete all built documentation" -# @echo " test to run the test suite" -# @echo " coverage to generate a coverage report" -# @echo " format [check=1] to apply black formatter; use with 'check=1' to check instead of modify (requires black)" -# @echo " lint to run pylint on source files" -# @echo " lint-test to run pylint on test files" -# -# .PHONY: install -# install: -# ifndef PYTHON3 -# @echo "To install PennyLane you need to have Python 3 installed" -# endif -# $(PYTHON) setup.py install -# -# .PHONY: wheel -# wheel: -# $(PYTHON) setup.py bdist_wheel -# -# .PHONY: dist -# dist: -# $(PYTHON) setup.py sdist -# -# .PHONY : clean -# clean: -# rm -rf pennylane/__pycache__ -# rm -rf pennylane/optimize/__pycache__ -# rm -rf pennylane/expectation/__pycache__ -# rm -rf pennylane/ops/__pycache__ -# rm -rf pennylane/devices/__pycache__ -# rm -rf tests/__pycache__ -# rm -rf tests/new_qnode/__pycache__ -# rm -rf dist -# rm -rf build -# rm -rf .coverage coverage_html_report/ -# rm -rf tmp -# rm -rf *.dat -# -# docs: -# make -C doc html -# -# .PHONY : clean-docs -# clean-docs: -# rm -rf doc/code/api -# make -C doc clean -# -# test: -# $(PYTHON) $(TESTRUNNER) -# $(PYTHON) $(PLUGIN_TESTRUNNER) --device=default.qubit.autograd -# -# coverage: -# @echo "Generating coverage report..." -# $(PYTHON) $(TESTRUNNER) $(COVERAGE) -# $(PYTHON) $(PLUGIN_TESTRUNNER) --device=default.qubit.autograd $(COVERAGE) --cov-append -# -# .PHONY:format -# format: -# ifdef check -# isort --py 311 --profile black -l 100 -o autoray -p ./pennylane --skip __init__.py --filter-files ./pennylane ./tests --check -# black -t py39 -t py310 -t py311 -l 100 ./pennylane ./tests --check -# else -# isort --py 311 --profile black -l 100 -o autoray -p ./pennylane --skip __init__.py --filter-files ./pennylane ./tests -# black -t py39 -t py310 -t py311 -l 100 ./pennylane ./tests -# endif -# -# .PHONY: lint -# lint: -# pylint pennylane --rcfile .pylintrc -# -# .PHONY: lint-test -# lint-test: -# pylint tests pennylane/devices/tests --rcfile tests/.pylintrc +PYTHON3 := $(shell which python3 2>/dev/null) + +PYTHON := python3 +COVERAGE := --cov=pennylane --cov-report term-missing --cov-report=html:coverage_html_report +TESTRUNNER := -m pytest tests --tb=native --no-flaky-report +PLUGIN_TESTRUNNER := -m pytest pennylane/devices/tests --tb=native --no-flaky-report + +.PHONY: help +help: + @echo "Please use \`make ' where is one of" + @echo " install to install PennyLane" + @echo " wheel to build the PennyLane wheel" + @echo " dist to package the source distribution" + @echo " clean to delete all temporary, cache, and build files" + @echo " docs to build the PennyLane documentation" + @echo " clean-docs to delete all built documentation" + @echo " test to run the test suite" + @echo " coverage to generate a coverage report" + @echo " format [check=1] to apply black formatter; use with 'check=1' to check instead of modify (requires black)" + @echo " lint to run pylint on source files" + @echo " lint-test to run pylint on test files" + +.PHONY: install +install: +ifndef PYTHON3 + @echo "To install PennyLane you need to have Python 3 installed" +endif + $(PYTHON) setup.py install + +.PHONY: wheel +wheel: + $(PYTHON) setup.py bdist_wheel + +.PHONY: dist +dist: + $(PYTHON) setup.py sdist + +.PHONY : clean +clean: + rm -rf pennylane/__pycache__ + rm -rf pennylane/optimize/__pycache__ + rm -rf pennylane/expectation/__pycache__ + rm -rf pennylane/ops/__pycache__ + rm -rf pennylane/devices/__pycache__ + rm -rf tests/__pycache__ + rm -rf tests/new_qnode/__pycache__ + rm -rf dist + rm -rf build + rm -rf .coverage coverage_html_report/ + rm -rf tmp + rm -rf *.dat + +docs: + make -C doc html + +.PHONY : clean-docs +clean-docs: + rm -rf doc/code/api + make -C doc clean + +test: + $(PYTHON) $(TESTRUNNER) + $(PYTHON) $(PLUGIN_TESTRUNNER) --device=default.qubit.autograd + +coverage: + @echo "Generating coverage report..." + $(PYTHON) $(TESTRUNNER) $(COVERAGE) + $(PYTHON) $(PLUGIN_TESTRUNNER) --device=default.qubit.autograd $(COVERAGE) --cov-append + +.PHONY:format +format: +ifdef check + isort --py 311 --profile black -l 100 -o autoray -p ./pennylane --skip __init__.py --filter-files ./pennylane ./tests --check + black -t py39 -t py310 -t py311 -l 100 ./pennylane ./tests --check +else + isort --py 311 --profile black -l 100 -o autoray -p ./pennylane --skip __init__.py --filter-files ./pennylane ./tests + black -t py39 -t py310 -t py311 -l 100 ./pennylane ./tests +endif + +.PHONY: lint +lint: + pylint pennylane --rcfile .pylintrc + +.PHONY: lint-test +lint-test: + pylint tests pennylane/devices/tests --rcfile tests/.pylintrc From 9044145e7f59c1e39629054d370546d85d22db0b Mon Sep 17 00:00:00 2001 From: gabrielLydian Date: Fri, 5 Jul 2024 16:21:47 -0700 Subject: [PATCH 10/26] Removed simulate file and added simulator to devices for more one to one --- pennylane/devices/default_mixed.py | 114 ++++++++-------- .../qtcorgi_helper/apply_operations.py | 5 +- .../qtcorgi_helper/qtcorgi_simulator.py | 123 ------------------ pennylane/devices/qutrit_mixed/simulate.py | 70 +++++++++- 4 files changed, 128 insertions(+), 184 deletions(-) delete mode 100644 pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py diff --git a/pennylane/devices/default_mixed.py b/pennylane/devices/default_mixed.py index 97c11febdf7..805ab28d9fc 100644 --- a/pennylane/devices/default_mixed.py +++ b/pennylane/devices/default_mixed.py @@ -48,7 +48,10 @@ from pennylane.wires import Wires from .._version import __version__ -from .qtcorgi_helper.qtcorgi_simulator import get_qubit_final_state_from_initial +import jax +import jax.numpy as jnp +from .qtcorgi_helper.apply_operations import qubit_branches + logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) @@ -57,6 +60,60 @@ tolerance = 1e-10 +def get_qubit_final_state_from_initial(operations, initial_state): + """ + TODO + + Args: + operations ():TODO + initial_state ():TODO + + Returns: + Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and + whether the state has a batch dimension. + + """ + ops_type_indices, ops_wires, ops_param = [[], []], [[], []], [] + for op in operations: + + wires = op.wires() + + if isinstance(op, Channel): + ops_type_indices[0].append(2) + ops_type_indices[1].append([].index(type(op))) + elif len(wires) == 1: + ops_type_indices[0].append(0) + ops_type_indices[1].append([qml.RX, qml.RY, qml.RZ, qml.Hadamard].index(type(op))) + elif len(wires) == 2: + ops_type_indices[0].append(1) + ops_type_indices[1].append(0) # Assume always CNOT + else: + raise ValueError("TODO") + + if len(wires) == 1: + wires = [wires[0], -1] + params = op.parameters + ([0] * (3 - op.num_params)) + ops_wires[0].append(wires[0]) + ops_wires[1].append(wires[1]) + + ops_param[0].append(params[0]) + + ops_info = { + "type_index": jnp.array(ops_type_indices), + "wires": [jnp.array(ops_wires[0]), jnp.array(ops_wires[1])], + "params": [jnp.array(ops_param)], + } + + return jax.lax.scan( + lambda state, op_info: ( + jax.lax.switch(op_info["type_indices"][0], qubit_branches, state, op_info), + None, + ), + initial_state, + ops_info, + )[0] + + class DefaultMixed(QubitDevice): """Default qubit device for performing mixed-state computations in PennyLane. @@ -93,73 +150,18 @@ class DefaultMixed(QubitDevice): author = "Xanadu Inc." operations = { - "Identity", - "Snapshot", "BasisState", "QubitStateVector", "StatePrep", "QubitDensityMatrix", - "QubitUnitary", - "ControlledQubitUnitary", - "BlockEncode", - "MultiControlledX", - "DiagonalQubitUnitary", - "SpecialUnitary", - "PauliX", - "PauliY", - "PauliZ", - "MultiRZ", "Hadamard", - "S", - "T", - "SX", "CNOT", - "SWAP", - "ISWAP", - "CSWAP", - "Toffoli", - "CCZ", - "CY", - "CZ", - "CH", - "PhaseShift", - "PCPhase", - "ControlledPhaseShift", - "CPhaseShift00", - "CPhaseShift01", - "CPhaseShift10", "RX", "RY", "RZ", - "Rot", - "CRX", - "CRY", - "CRZ", - "CRot", "AmplitudeDamping", - "GeneralizedAmplitudeDamping", - "PhaseDamping", "DepolarizingChannel", "BitFlip", - "PhaseFlip", - "PauliError", - "ResetError", - "QubitChannel", - "SingleExcitation", - "SingleExcitationPlus", - "SingleExcitationMinus", - "DoubleExcitation", - "DoubleExcitationPlus", - "DoubleExcitationMinus", - "QubitCarry", - "QubitSum", - "OrbitalRotation", - "FermionicSWAP", - "QFT", - "ThermalRelaxationError", - "ECR", - "ParametrizedEvolution", - "GlobalPhase", } _reshape = staticmethod(qnp.reshape) diff --git a/pennylane/devices/qtcorgi_helper/apply_operations.py b/pennylane/devices/qtcorgi_helper/apply_operations.py index 0f6fc3ce216..60787c30f38 100644 --- a/pennylane/devices/qtcorgi_helper/apply_operations.py +++ b/pennylane/devices/qtcorgi_helper/apply_operations.py @@ -110,11 +110,12 @@ def apply_operation_einsum(kraus, wires, state): return jnp.einsum(einsum_indices, kraus, state, kraus_dagger) -def get_two_qubit_unitary_matrix(): +def get_two_qubit_unitary_matrix(param): + # TODO pass -def get_CNOT_matrix(params): +def get_CNOT_matrix(param): return jnp.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]) diff --git a/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py b/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py deleted file mode 100644 index a4073d30e07..00000000000 --- a/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py +++ /dev/null @@ -1,123 +0,0 @@ -import jax -import jax.numpy as jnp -import pennylane as qml -from pennylane.operation import Channel -from .apply_operations import qubit_branches, qutrit_branches - -op_types = [] - - -def get_qubit_final_state_from_initial(operations, initial_state): - """ - TODO - - Args: - TODO - - Returns: - Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and - whether the state has a batch dimension. - - """ - ops_type_indices, ops_wires, ops_param = [[], []], [[], []], [] - for op in operations: - - wires = op.wires() - - if isinstance(op, Channel): - ops_type_indices[0].append(2) - ops_type_indices[1].append([].index(type(op))) - elif len(wires) == 1: - ops_type_indices[0].append(0) - ops_type_indices[1].append([qml.RX, qml.RY, qml.RZ, qml.Hadamard].index(type(op))) - elif len(wires) == 2: - ops_type_indices[0].append(1) - ops_type_indices[1].append(0) # Assume always CNOT - else: - raise ValueError("TODO") - - if len(wires) == 1: - wires = [wires[0], -1] - params = op.parameters + ([0] * (3 - op.num_params)) - ops_wires[0].append(wires[0]) - ops_wires[1].append(wires[1]) - - ops_param[0].append(params[0]) - - ops_info = { - "type_index": jnp.array(ops_type_indices), - "wires": [jnp.array(ops_wires[0]), jnp.array(ops_wires[1])], - "params": [jnp.array(ops_param)], - } - - return jax.lax.scan( - lambda state, op_info: ( - jax.lax.switch(op_info["type_indices"][0], qubit_branches, state, op_info), - None, - ), - initial_state, - ops_info, - )[0] - - -def get_qutrit_final_state_from_initial(operations, initial_state): - """ - TODO - - Args: - TODO - - Returns: - Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and - whether the state has a batch dimension. - - """ - ops_type_indices, ops_subspace, ops_wires, ops_params = [[], []], [], [[], []], [[], [], []] - for op in operations: - - wires = op.wires() - - if isinstance(op, Channel): - ops_type_indices[0].append(2) - ops_type_indices[1].append( - [qml.QutritDepolarizingChannel, qml.QutritAmplitudeDamping, qml.TritFlip].index( - type(op) - ) - ) - params = op.parameters + ([0] * (3 - op.num_params)) - elif len(wires) == 1: - ops_type_indices[0].append(0) - ops_type_indices[1].append([qml.TRX, qml.TRY, qml.TRZ, qml.THadamard].index(type(op))) - if ops_type_indices[1][-1] == 3: - params = [0] + list(op.subspace) if op.subspace is not None else [0, 0] - else: - params = list(op.params) + list(op.subspace) - elif len(wires) == 2: - ops_type_indices[0].append(1) - ops_type_indices[1].append(0) # Assume always TAdd - params = [0, 0, 0] - else: - raise ValueError("TODO") - ops_params[0].append(params[0]) - ops_params[1].append(params[1]) - ops_params[2].append(params[2]) - - if len(wires) == 1: - wires = [wires[0], -1] - ops_wires[0].append(wires[0]) - ops_wires[1].append(wires[1]) - - ops_info = { - "type_indices": jnp.array(ops_type_indices), - "wires": [jnp.array(ops_wires[0]), jnp.array(ops_wires[1])], - "params": [jnp.array(ops_params[0]), jnp.array(ops_params[1]), jnp.array(ops_params[2])], - } - - return jax.lax.scan( - lambda state, op_info: ( - jax.lax.switch(op_info["type_indices"][0], qutrit_branches, state, op_info), - None, - ), - initial_state, - ops_info, - )[0] diff --git a/pennylane/devices/qutrit_mixed/simulate.py b/pennylane/devices/qutrit_mixed/simulate.py index d820477dc8c..13811cff468 100644 --- a/pennylane/devices/qutrit_mixed/simulate.py +++ b/pennylane/devices/qutrit_mixed/simulate.py @@ -18,12 +18,12 @@ import pennylane as qml from pennylane.typing import Result -from .apply_operation import apply_operation from .initialize_state import create_initial_state from .measure import measure from .sampling import measure_with_samples -from .utils import QUDIT_DIM -from ..qtcorgi_helper.qtcorgi_simulator import get_qutrit_final_state_from_initial +from ..qtcorgi_helper.apply_operations import qutrit_branches +import jax +import jax.numpy as jnp INTERFACE_TO_LIKE = { # map interfaces known by autoray to themselves @@ -46,6 +46,70 @@ } +def get_qutrit_final_state_from_initial(operations, initial_state): + """ + TODO + + Args: + operations ():TODO + initial_state ():TODO + + Returns: + Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and + whether the state has a batch dimension. + + """ + ops_type_indices, ops_subspace, ops_wires, ops_params = [[], []], [], [[], []], [[], [], []] + for op in operations: + + wires = op.wires() + + if isinstance(op, qml.operation.Channel): + ops_type_indices[0].append(2) + ops_type_indices[1].append( + [qml.QutritDepolarizingChannel, qml.QutritAmplitudeDamping, qml.TritFlip].index( + type(op) + ) + ) + params = op.parameters + ([0] * (3 - op.num_params)) + elif len(wires) == 1: + ops_type_indices[0].append(0) + ops_type_indices[1].append([qml.TRX, qml.TRY, qml.TRZ, qml.THadamard].index(type(op))) + if ops_type_indices[1][-1] == 3: + params = [0] + list(op.subspace) if op.subspace is not None else [0, 0] + else: + params = list(op.params) + list(op.subspace) + elif len(wires) == 2: + ops_type_indices[0].append(1) + ops_type_indices[1].append(0) # Assume always TAdd + params = [0, 0, 0] + else: + raise ValueError("TODO") + ops_params[0].append(params[0]) + ops_params[1].append(params[1]) + ops_params[2].append(params[2]) + + if len(wires) == 1: + wires = [wires[0], -1] + ops_wires[0].append(wires[0]) + ops_wires[1].append(wires[1]) + + ops_info = { + "type_indices": jnp.array(ops_type_indices), + "wires": [jnp.array(ops_wires[0]), jnp.array(ops_wires[1])], + "params": [jnp.array(ops_params[0]), jnp.array(ops_params[1]), jnp.array(ops_params[2])], + } + + return jax.lax.scan( + lambda state, op_info: ( + jax.lax.switch(op_info["type_indices"][0], qutrit_branches, state, op_info), + None, + ), + initial_state, + ops_info, + )[0] + + def measure_final_state(circuit, state, is_state_batched, rng=None, prng_key=None) -> Result: """ Perform the measurements required by the circuit on the provided state. From 223ed07159eb3d4142f82a9266f1b0b7cf27dfed Mon Sep 17 00:00:00 2001 From: gabrielLydian Date: Mon, 8 Jul 2024 16:17:02 -0700 Subject: [PATCH 11/26] [ci skip] Added support for Adjoint of TAdd --- pennylane/devices/default_qutrit_mixed.py | 11 ++++++++-- .../qtcorgi_helper/apply_operations.py | 21 +++++-------------- .../qtcorgi_helper/qtcorgi_simulator.py | 4 +--- 3 files changed, 15 insertions(+), 21 deletions(-) diff --git a/pennylane/devices/default_qutrit_mixed.py b/pennylane/devices/default_qutrit_mixed.py index 0ade57c0c73..7c6563df336 100644 --- a/pennylane/devices/default_qutrit_mixed.py +++ b/pennylane/devices/default_qutrit_mixed.py @@ -71,7 +71,15 @@ def observable_stopping_condition(obs: qml.operation.Operator) -> bool: def stopping_condition(op: qml.operation.Operator) -> bool: """Specify whether an Operator object is supported by the device.""" - expected_set = DefaultQutrit.operations | {"Snapshot"} | channels + operations = { + "TAdd", + "Adjoint(TAdd)", + "THadamard", + "TRX", + "TRY", + "TRZ", + } + expected_set = operations | channels return op.name in expected_set @@ -292,7 +300,6 @@ def execute( circuits: QuantumTape_or_Batch, execution_config: ExecutionConfig = DefaultExecutionConfig, ) -> Result_or_ResultBatch: - interface = ( execution_config.interface if execution_config.gradient_method in {"best", "backprop", None} diff --git a/pennylane/devices/qtcorgi_helper/apply_operations.py b/pennylane/devices/qtcorgi_helper/apply_operations.py index 0f6fc3ce216..b89fe5081f7 100644 --- a/pennylane/devices/qtcorgi_helper/apply_operations.py +++ b/pennylane/devices/qtcorgi_helper/apply_operations.py @@ -158,6 +158,9 @@ def apply_single_qubit_channel(state, op_info): else qml.THadamard.compute_matrix(subspace=params[1:]) ), ] + +two_qutrits_ops = [qml.TAdd.compute_matrix, qml.adjoint(qml.TAdd).compute_matrix] + single_qutrit_channels = [ lambda params: qml.QutritDepolarizingChannel.compute_kraus_matrices(params[0]), lambda params: qml.QutritAmplitudeDamping.compute_kraus_matrices(*params), @@ -173,22 +176,8 @@ def apply_single_qutrit_unitary(state, op_info): def apply_two_qutrit_unitary(state, op_info): wires = op_info["wires"] - kraus_mat = [ - jnp.array( - [ - [1, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 1, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 0, 1, 0, 0], - ] - ) - ] - return apply_operation_einsum(kraus_mat, wires, state) + kraus_mats = [jax.lax.switch(op_info["type_indices"][1], two_qutrits_ops)] + return apply_operation_einsum(kraus_mats, wires, state) def apply_single_qutrit_channel(state, op_info): diff --git a/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py b/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py index a4073d30e07..9f2ed93d5a7 100644 --- a/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py +++ b/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py @@ -21,7 +21,6 @@ def get_qubit_final_state_from_initial(operations, initial_state): """ ops_type_indices, ops_wires, ops_param = [[], []], [[], []], [] for op in operations: - wires = op.wires() if isinstance(op, Channel): @@ -74,7 +73,6 @@ def get_qutrit_final_state_from_initial(operations, initial_state): """ ops_type_indices, ops_subspace, ops_wires, ops_params = [[], []], [], [[], []], [[], [], []] for op in operations: - wires = op.wires() if isinstance(op, Channel): @@ -94,7 +92,7 @@ def get_qutrit_final_state_from_initial(operations, initial_state): params = list(op.params) + list(op.subspace) elif len(wires) == 2: ops_type_indices[0].append(1) - ops_type_indices[1].append(0) # Assume always TAdd + ops_type_indices[1].append(["TAdd", "Adjoint(TAdd)"].index(op.name)) params = [0, 0, 0] else: raise ValueError("TODO") From c0f5a58cb7f92cb83ecd59fefd964f69f1fb9ae0 Mon Sep 17 00:00:00 2001 From: gabrielLydian Date: Mon, 15 Jul 2024 16:44:43 -0700 Subject: [PATCH 12/26] Worked on apply operations einsum, I think that generating indexing needs to be done differently --- .../qtcorgi_helper/apply_operations.py | 54 ++++++------ .../qtcorgi_helper/einsum_mapping_old.py | 83 +++++++++++++++++++ 2 files changed, 110 insertions(+), 27 deletions(-) create mode 100644 pennylane/devices/qtcorgi_helper/einsum_mapping_old.py diff --git a/pennylane/devices/qtcorgi_helper/apply_operations.py b/pennylane/devices/qtcorgi_helper/apply_operations.py index 1e8e65989fc..d645d32680b 100644 --- a/pennylane/devices/qtcorgi_helper/apply_operations.py +++ b/pennylane/devices/qtcorgi_helper/apply_operations.py @@ -14,7 +14,7 @@ alphabet_array = np.array(list(alphabet)) -def get_einsum_mapping(wires, state): +def get_einsum_mapping_one_wire(wires, state, ): r"""Finds the indices for einsum to apply kraus operators to a mixed state Args: @@ -22,56 +22,57 @@ def get_einsum_mapping(wires, state): state (array[complex]): Input quantum state Returns: - str: Indices mapping that defines the einsum + tuple(tuple(int)): Indices mapping that defines the einsum """ num_ch_wires = len(wires) num_wires = int(len(qml.math.shape(state)) / 2) rho_dim = 2 * num_wires # Tensor indices of the state. For each qutrit, need an index for rows *and* columns - state_indices = alphabet[:rho_dim] + state_indices = (..., ) + tuple(range(rho_dim)) # TODO this may need to be an input from above # row indices of the quantum state affected by this operation - row_wires_list = wires.tolist() - row_indices = "".join(alphabet_array[row_wires_list].tolist()) + row_indices = tuple(wires) # column indices are shifted by the number of wires - col_wires_list = [w + num_wires for w in row_wires_list] - col_indices = "".join(alphabet_array[col_wires_list].tolist()) + col_indices = tuple(w + num_wires for w in wires) # TODO replace # indices in einsum must be replaced with new ones - new_row_indices = alphabet[rho_dim : rho_dim + num_ch_wires] - new_col_indices = alphabet[rho_dim + num_ch_wires : rho_dim + 2 * num_ch_wires] + new_row_indices = tuple(range(rho_dim, rho_dim + num_ch_wires)) + new_col_indices = tuple(range(rho_dim + num_ch_wires, rho_dim + 2 * num_ch_wires)) # index for summation over Kraus operators - kraus_index = alphabet[rho_dim + 2 * num_ch_wires : rho_dim + 2 * num_ch_wires + 1] + kraus_index = (rho_dim + 2 * num_ch_wires,) # apply mapping function - op_1_indices = f"{kraus_index}{new_row_indices}{row_indices}" - op_2_indices = f"{kraus_index}{col_indices}{new_col_indices}" + op_1_indices = (...,) + kraus_index + new_row_indices + row_indices + op_2_indices = (...,) + kraus_index + col_indices + new_col_indices - new_state_indices = get_new_state_einsum_indices( + new_state_indices = (...,) + get_new_state_einsum_indices( old_indices=col_indices + row_indices, new_indices=new_col_indices + new_row_indices, state_indices=state_indices, ) - # index mapping for einsum, e.g., '...iga,...abcdef,...idh->...gbchef' - return f"...{op_1_indices},...{state_indices},...{op_2_indices}->...{new_state_indices}" + # index mapping for einsum, e.g., (...0,1,2,3), (...0,1,2,3), (...0,1,2,3), (...0,1,2,3) + return op_1_indices, state_indices, op_2_indices, new_state_indices def get_new_state_einsum_indices(old_indices, new_indices, state_indices): """Retrieves the einsum indices string for the new state Args: - old_indices (str): indices that are summed - new_indices (str): indices that must be replaced with sums - state_indices (str): indices of the original state + old_indices tuple(int): indices that are summed + new_indices tuple(int): indices that must be replaced with sums + state_indices tuple(int): indices of the original state Returns: - str: The einsum indices of the new state + tuple(int): The einsum indices of the new state """ - return reduce( - lambda old_string, idx_pair: old_string.replace(idx_pair[0], idx_pair[1]), + # for old, new in zip(old_indices, new_indices): + # for i in old_indices: + # old_indices[i] = jax.lax.cond() + return reduce( # TODO, redo + lambda old_indices, idx_pair: old_indices[idx_pair[0]], zip(old_indices, new_indices), state_indices, ) @@ -80,7 +81,7 @@ def get_new_state_einsum_indices(old_indices, new_indices, state_indices): QUDIT_DIM = 3 -def apply_operation_einsum(kraus, wires, state): +def apply_operation_einsum(kraus, wires, state, mapping_indices): r"""Apply a quantum channel specified by a list of Kraus operators to subsystems of the quantum state. For a unitary gate, there is a single Kraus operator. @@ -88,11 +89,12 @@ def apply_operation_einsum(kraus, wires, state): kraus (??): TODO wires state (array[complex]): Input quantum state + mapping_indices Returns: array[complex]: output_state """ - einsum_indices = get_einsum_mapping(wires, state) + op_1_indices, state_indices, op_2_indices, new_state_indices = get_einsum_mapping(wires, state) num_ch_wires = len(wires) @@ -100,14 +102,12 @@ def apply_operation_einsum(kraus, wires, state): kraus_shape = [len(kraus)] + [QUDIT_DIM] * num_ch_wires * 2 kraus = jnp.stack(kraus) - kraus_transpose = jnp.stack(jnp.moveaxis(kraus, source=-1, destination=-2)) - # Torch throws error if math.conj is used before stack - kraus_dagger = jnp.conj(kraus_transpose) + kraus_dagger = jnp.conj(jnp.stack(jnp.moveaxis(kraus, source=-1, destination=-2))) kraus = jnp.reshape(kraus, kraus_shape) kraus_dagger = jnp.reshape(kraus_dagger, kraus_shape) - return jnp.einsum(einsum_indices, kraus, state, kraus_dagger) + return jnp.einsum(kraus, op_1_indices, state, state_indices, kraus_dagger, op_2_indices, new_state_indices) def get_two_qubit_unitary_matrix(param): diff --git a/pennylane/devices/qtcorgi_helper/einsum_mapping_old.py b/pennylane/devices/qtcorgi_helper/einsum_mapping_old.py new file mode 100644 index 00000000000..704ca1b9405 --- /dev/null +++ b/pennylane/devices/qtcorgi_helper/einsum_mapping_old.py @@ -0,0 +1,83 @@ +import time +import jax +import jax.numpy as jnp +from jax.lax import scan + +jax.config.update("jax_enable_x64", True) +jax.config.update("jax_platforms", "cpu") +import pennylane as qml +from string import ascii_letters as alphabet + +import numpy as np +from functools import partial, reduce + +alphabet_array = np.array(list(alphabet)) + + +def get_einsum_mapping(wires, state): + r"""Finds the indices for einsum to apply kraus operators to a mixed state + + Args: + wires + state (array[complex]): Input quantum state + + Returns: + str: Indices mapping that defines the einsum + """ + num_ch_wires = len(wires) + num_wires = int(len(qml.math.shape(state)) / 2) + rho_dim = 2 * num_wires + + # Tensor indices of the state. For each qutrit, need an index for rows *and* columns + state_indices = alphabet[:rho_dim] + + # row indices of the quantum state affected by this operation + row_wires_list = wires + row_indices = "".join(alphabet_array[row_wires_list].tolist()) + + # column indices are shifted by the number of wires + col_wires_list = [w + num_wires for w in row_wires_list] + col_indices = "".join(alphabet_array[col_wires_list].tolist()) + + # indices in einsum must be replaced with new ones + new_row_indices = alphabet[rho_dim : rho_dim + num_ch_wires] + new_col_indices = alphabet[rho_dim + num_ch_wires : rho_dim + 2 * num_ch_wires] + + # index for summation over Kraus operators + kraus_index = alphabet[rho_dim + 2 * num_ch_wires : rho_dim + 2 * num_ch_wires + 1] + print(kraus_index) + + # apply mapping function + op_1_indices = f"{kraus_index}{new_row_indices}{row_indices}" + op_2_indices = f"{kraus_index}{col_indices}{new_col_indices}" + + new_state_indices = get_new_state_einsum_indices( + old_indices=col_indices + row_indices, + new_indices=new_col_indices + new_row_indices, + state_indices=state_indices, + ) + # index mapping for einsum, e.g., '...iga,...abcdef,...idh->...gbchef' + return f"...{op_1_indices},...{state_indices},...{op_2_indices}->...{new_state_indices}" + + +def get_new_state_einsum_indices(old_indices, new_indices, state_indices): + """Retrieves the einsum indices string for the new state + + Args: + old_indices (str): indices that are summed + new_indices (str): indices that must be replaced with sums + state_indices (str): indices of the original state + + Returns: + str: The einsum indices of the new state + """ + return reduce( + lambda old_string, idx_pair: old_string.replace(idx_pair[0], idx_pair[1]), + zip(old_indices, new_indices), + state_indices, + ) + + +QUDIT_DIM = 3 + +print(get_einsum_mapping([0,1], np.zeros((3,3,3,3,3,3,3,3)))) \ No newline at end of file From d969cd61be53f2e4f52cf42a172b2535415fdda4 Mon Sep 17 00:00:00 2001 From: gabrielLydian Date: Mon, 15 Jul 2024 17:22:19 -0700 Subject: [PATCH 13/26] Changed einsum indices to not index via scan abstract values --- .../qtcorgi_helper/apply_operations.py | 39 +++++++++++-------- .../qtcorgi_helper/einsum_mapping_old.py | 2 +- .../qtcorgi_helper/qtcorgi_simulator.py | 2 - 3 files changed, 23 insertions(+), 20 deletions(-) diff --git a/pennylane/devices/qtcorgi_helper/apply_operations.py b/pennylane/devices/qtcorgi_helper/apply_operations.py index d645d32680b..47874dd4976 100644 --- a/pennylane/devices/qtcorgi_helper/apply_operations.py +++ b/pennylane/devices/qtcorgi_helper/apply_operations.py @@ -14,7 +14,7 @@ alphabet_array = np.array(list(alphabet)) -def get_einsum_mapping_one_wire(wires, state, ): +def get_einsum_mapping(wires, state): r"""Finds the indices for einsum to apply kraus operators to a mixed state Args: @@ -29,13 +29,16 @@ def get_einsum_mapping_one_wire(wires, state, ): rho_dim = 2 * num_wires # Tensor indices of the state. For each qutrit, need an index for rows *and* columns - state_indices = (..., ) + tuple(range(rho_dim)) # TODO this may need to be an input from above + # TODO this may need to be an input from above + state_indices_list = list(range(rho_dim)) + state_indices = (...,) + tuple(state_indices_list) # row indices of the quantum state affected by this operation row_indices = tuple(wires) # column indices are shifted by the number of wires - col_indices = tuple(w + num_wires for w in wires) # TODO replace + # TODO replace + col_indices = tuple(w + num_wires for w in wires) # TODO: Should I do an array? # indices in einsum must be replaced with new ones new_row_indices = tuple(range(rho_dim, rho_dim + num_ch_wires)) @@ -48,10 +51,10 @@ def get_einsum_mapping_one_wire(wires, state, ): op_1_indices = (...,) + kraus_index + new_row_indices + row_indices op_2_indices = (...,) + kraus_index + col_indices + new_col_indices - new_state_indices = (...,) + get_new_state_einsum_indices( + new_state_indices = get_new_state_einsum_indices( old_indices=col_indices + row_indices, new_indices=new_col_indices + new_row_indices, - state_indices=state_indices, + state_indices=state_indices_list, ) # index mapping for einsum, e.g., (...0,1,2,3), (...0,1,2,3), (...0,1,2,3), (...0,1,2,3) return op_1_indices, state_indices, op_2_indices, new_state_indices @@ -68,20 +71,21 @@ def get_new_state_einsum_indices(old_indices, new_indices, state_indices): Returns: tuple(int): The einsum indices of the new state """ - # for old, new in zip(old_indices, new_indices): - # for i in old_indices: - # old_indices[i] = jax.lax.cond() - return reduce( # TODO, redo - lambda old_indices, idx_pair: old_indices[idx_pair[0]], - zip(old_indices, new_indices), - state_indices, - ) + for old, new in zip(old_indices, new_indices): + for i, state_index in enumerate(state_indices): # TODO replace with jax.lax.scan + state_indices[i] = jax.lax.cond(old == state_index, lambda: new, lambda: state_index) + return (...,) + tuple(state_indices) + # return reduce( # TODO, redo + # lambda old_indices, idx_pair: old_indices[idx_pair[0]], + # zip(old_indices, new_indices), + # state_indices, + # ) QUDIT_DIM = 3 -def apply_operation_einsum(kraus, wires, state, mapping_indices): +def apply_operation_einsum(kraus, wires, state): r"""Apply a quantum channel specified by a list of Kraus operators to subsystems of the quantum state. For a unitary gate, there is a single Kraus operator. @@ -89,7 +93,6 @@ def apply_operation_einsum(kraus, wires, state, mapping_indices): kraus (??): TODO wires state (array[complex]): Input quantum state - mapping_indices Returns: array[complex]: output_state @@ -107,7 +110,9 @@ def apply_operation_einsum(kraus, wires, state, mapping_indices): kraus = jnp.reshape(kraus, kraus_shape) kraus_dagger = jnp.reshape(kraus_dagger, kraus_shape) - return jnp.einsum(kraus, op_1_indices, state, state_indices, kraus_dagger, op_2_indices, new_state_indices) + return jnp.einsum( + kraus, op_1_indices, state, state_indices, kraus_dagger, op_2_indices, new_state_indices + ) def get_two_qubit_unitary_matrix(param): @@ -115,7 +120,7 @@ def get_two_qubit_unitary_matrix(param): pass -def get_CNOT_matrix(param): +def get_CNOT_matrix(_param): return jnp.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]) diff --git a/pennylane/devices/qtcorgi_helper/einsum_mapping_old.py b/pennylane/devices/qtcorgi_helper/einsum_mapping_old.py index 704ca1b9405..965499c8478 100644 --- a/pennylane/devices/qtcorgi_helper/einsum_mapping_old.py +++ b/pennylane/devices/qtcorgi_helper/einsum_mapping_old.py @@ -80,4 +80,4 @@ def get_new_state_einsum_indices(old_indices, new_indices, state_indices): QUDIT_DIM = 3 -print(get_einsum_mapping([0,1], np.zeros((3,3,3,3,3,3,3,3)))) \ No newline at end of file +print(get_einsum_mapping([0, 1], np.zeros((3, 3, 3, 3, 3, 3, 3, 3)))) diff --git a/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py b/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py index a4073d30e07..f301edf63df 100644 --- a/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py +++ b/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py @@ -21,7 +21,6 @@ def get_qubit_final_state_from_initial(operations, initial_state): """ ops_type_indices, ops_wires, ops_param = [[], []], [[], []], [] for op in operations: - wires = op.wires() if isinstance(op, Channel): @@ -74,7 +73,6 @@ def get_qutrit_final_state_from_initial(operations, initial_state): """ ops_type_indices, ops_subspace, ops_wires, ops_params = [[], []], [], [[], []], [[], [], []] for op in operations: - wires = op.wires() if isinstance(op, Channel): From 2cb598581c6f7b65892175f99aafb8232cc2c642 Mon Sep 17 00:00:00 2001 From: gabrielLydian Date: Mon, 15 Jul 2024 18:16:18 -0700 Subject: [PATCH 14/26] Changed how different matrices are made --- .../qtcorgi_helper/apply_operations.py | 33 ++++++++++++++----- pennylane/devices/qutrit_mixed/simulate.py | 8 ++--- 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/pennylane/devices/qtcorgi_helper/apply_operations.py b/pennylane/devices/qtcorgi_helper/apply_operations.py index 47874dd4976..71c9a42edf5 100644 --- a/pennylane/devices/qtcorgi_helper/apply_operations.py +++ b/pennylane/devices/qtcorgi_helper/apply_operations.py @@ -153,16 +153,30 @@ def apply_single_qubit_channel(state, op_info): qubit_branches = [apply_single_qubit_unitary, apply_two_qubit_unitary, apply_single_qubit_channel] - -single_qutrit_ops = [ +single_qutrit_ops_subspace_0 = [None, None, None, lambda _param: qml.THadamard.compute_matrix()] +single_qutrit_ops_subspace_1 = [ qml.TRX.compute_matrix, qml.TRY.compute_matrix, qml.TRZ.compute_matrix, - lambda params: ( - qml.THadamard.compute_matrix() - if params[1] != 0 - else qml.THadamard.compute_matrix(subspace=params[1:]) - ), + partial(qml.THadamard.compute_matrix, subspace=[0, 1]), +] +single_qutrit_ops_subspace_2 = [ + partial(qml.TRX.compute_matrix, subspace=[0, 2]), + partial(qml.TRY.compute_matrix, subspace=[0, 2]), + partial(qml.TRZ.compute_matrix, subspace=[0, 2]), + partial(qml.THadamard.compute_matrix, subspace=[0, 2]), +] +single_qutrit_ops_subspace_3 = [ + partial(qml.TRX.compute_matrix, subspace=[1, 2]), + partial(qml.TRY.compute_matrix, subspace=[1, 2]), + partial(qml.TRZ.compute_matrix, subspace=[1, 2]), + partial(qml.THadamard.compute_matrix, subspace=[0, 2]), +] +single_qutrit_ops = [ + single_qutrit_ops_subspace_0, + single_qutrit_ops_subspace_1, + single_qutrit_ops_subspace_2, + single_qutrit_ops_subspace_3, ] two_qutrits_ops = [qml.TAdd.compute_matrix, qml.adjoint(qml.TAdd).compute_matrix] @@ -175,8 +189,9 @@ def apply_single_qubit_channel(state, op_info): def apply_single_qutrit_unitary(state, op_info): - wires, param = op_info["wires"][:0], op_info["params"][0] - kraus_mats = [jax.lax.switch(op_info["type_indices"][1], single_qutrit_ops, param)] + wires, param, subspace_index = op_info["wires"][:0], op_info["params"][0], op_info["params"][1] + mat_funcs = jax.lax.switch(subspace_index, single_qutrit_ops, param) + kraus_mats = [jax.lax.switch(op_info["type_indices"][1], mat_funcs, param)] return apply_operation_einsum(kraus_mats, wires, state) diff --git a/pennylane/devices/qutrit_mixed/simulate.py b/pennylane/devices/qutrit_mixed/simulate.py index 13811cff468..edbf74fb6df 100644 --- a/pennylane/devices/qutrit_mixed/simulate.py +++ b/pennylane/devices/qutrit_mixed/simulate.py @@ -61,8 +61,7 @@ def get_qutrit_final_state_from_initial(operations, initial_state): """ ops_type_indices, ops_subspace, ops_wires, ops_params = [[], []], [], [[], []], [[], [], []] for op in operations: - - wires = op.wires() + wires = op.wires if isinstance(op, qml.operation.Channel): ops_type_indices[0].append(2) @@ -75,10 +74,11 @@ def get_qutrit_final_state_from_initial(operations, initial_state): elif len(wires) == 1: ops_type_indices[0].append(0) ops_type_indices[1].append([qml.TRX, qml.TRY, qml.TRZ, qml.THadamard].index(type(op))) + subspace_index = op.subspace.index([None, (0, 1), (0, 2), (1, 2)]) if ops_type_indices[1][-1] == 3: - params = [0] + list(op.subspace) if op.subspace is not None else [0, 0] + params = [0, subspace_index, 0] else: - params = list(op.params) + list(op.subspace) + params = list(op.parameters) + [subspace_index, 0] elif len(wires) == 2: ops_type_indices[0].append(1) ops_type_indices[1].append(0) # Assume always TAdd From aac0ed8f6fbbb3684516c827e4175aec614f53f9 Mon Sep 17 00:00:00 2001 From: gabrielLydian Date: Mon, 15 Jul 2024 19:18:43 -0700 Subject: [PATCH 15/26] Got to einsum, not working... --- .../qtcorgi_helper/apply_operations.py | 33 +++++++++---------- pennylane/devices/qutrit_mixed/simulate.py | 2 +- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/pennylane/devices/qtcorgi_helper/apply_operations.py b/pennylane/devices/qtcorgi_helper/apply_operations.py index 71c9a42edf5..7a16f054efa 100644 --- a/pennylane/devices/qtcorgi_helper/apply_operations.py +++ b/pennylane/devices/qtcorgi_helper/apply_operations.py @@ -134,7 +134,7 @@ def get_CNOT_matrix(_param): def apply_single_qubit_unitary(state, op_info): - wires, param = op_info["wires"][:0], op_info["params"][0] + wires, param = op_info["wires"][:1], op_info["params"][0] kraus_mat = jax.lax.switch(op_info["type_indices"][1], single_qubit_ops, param) return apply_operation_einsum(kraus_mat, wires, state) @@ -146,37 +146,36 @@ def apply_two_qubit_unitary(state, op_info): def apply_single_qubit_channel(state, op_info): - wires, param = op_info["wires"][:0], op_info["params"][0] + wires, param = op_info["wires"][:1], op_info["params"][0] kraus_mats = [jax.lax.switch(op_info["type_indices"][1], single_qubit_channels, param)] return apply_operation_einsum(kraus_mats, wires, state) qubit_branches = [apply_single_qubit_unitary, apply_two_qubit_unitary, apply_single_qubit_channel] -single_qutrit_ops_subspace_0 = [None, None, None, lambda _param: qml.THadamard.compute_matrix()] -single_qutrit_ops_subspace_1 = [ +single_qutrit_ops_subspace_01 = [ qml.TRX.compute_matrix, qml.TRY.compute_matrix, qml.TRZ.compute_matrix, - partial(qml.THadamard.compute_matrix, subspace=[0, 1]), + lambda _param: qml.THadamard.compute_matrix(subspace=[0, 1]), ] -single_qutrit_ops_subspace_2 = [ +single_qutrit_ops_subspace_02 = [ partial(qml.TRX.compute_matrix, subspace=[0, 2]), partial(qml.TRY.compute_matrix, subspace=[0, 2]), partial(qml.TRZ.compute_matrix, subspace=[0, 2]), - partial(qml.THadamard.compute_matrix, subspace=[0, 2]), + lambda _param: qml.THadamard.compute_matrix(subspace=[0, 2]), ] -single_qutrit_ops_subspace_3 = [ +single_qutrit_ops_subspace_12 = [ partial(qml.TRX.compute_matrix, subspace=[1, 2]), partial(qml.TRY.compute_matrix, subspace=[1, 2]), partial(qml.TRZ.compute_matrix, subspace=[1, 2]), - partial(qml.THadamard.compute_matrix, subspace=[0, 2]), + lambda _param: qml.THadamard.compute_matrix(subspace=[0, 2]), ] single_qutrit_ops = [ - single_qutrit_ops_subspace_0, - single_qutrit_ops_subspace_1, - single_qutrit_ops_subspace_2, - single_qutrit_ops_subspace_3, + lambda _op_type, _param: qml.THadamard.compute_matrix(), + lambda op_type, param: jax.lax.switch(op_type, single_qutrit_ops_subspace_01, param), + lambda op_type, param: jax.lax.switch(op_type, single_qutrit_ops_subspace_02, param), + lambda op_type, param: jax.lax.switch(op_type, single_qutrit_ops_subspace_12, param), ] two_qutrits_ops = [qml.TAdd.compute_matrix, qml.adjoint(qml.TAdd).compute_matrix] @@ -189,9 +188,9 @@ def apply_single_qubit_channel(state, op_info): def apply_single_qutrit_unitary(state, op_info): - wires, param, subspace_index = op_info["wires"][:0], op_info["params"][0], op_info["params"][1] - mat_funcs = jax.lax.switch(subspace_index, single_qutrit_ops, param) - kraus_mats = [jax.lax.switch(op_info["type_indices"][1], mat_funcs, param)] + wires, param = op_info["wires"][:1], op_info["params"][0] + subspace_index, op_type = op_info["params"][1], op_info["type_indices"][1] + kraus_mats = [jax.lax.switch(subspace_index, single_qutrit_ops, op_type, param)] return apply_operation_einsum(kraus_mats, wires, state) @@ -202,7 +201,7 @@ def apply_two_qutrit_unitary(state, op_info): def apply_single_qutrit_channel(state, op_info): - wires, params = op_info["wires"][:0], op_info["params"] # TODO qutrit channels take 3 params + wires, params = op_info["wires"][:1], op_info["params"] # TODO qutrit channels take 3 params kraus_mats = [jax.lax.switch(op_info["type_indices"][1], single_qutrit_channels, *params)] return apply_operation_einsum(kraus_mats, wires, state) diff --git a/pennylane/devices/qutrit_mixed/simulate.py b/pennylane/devices/qutrit_mixed/simulate.py index edbf74fb6df..3fcabca670d 100644 --- a/pennylane/devices/qutrit_mixed/simulate.py +++ b/pennylane/devices/qutrit_mixed/simulate.py @@ -74,7 +74,7 @@ def get_qutrit_final_state_from_initial(operations, initial_state): elif len(wires) == 1: ops_type_indices[0].append(0) ops_type_indices[1].append([qml.TRX, qml.TRY, qml.TRZ, qml.THadamard].index(type(op))) - subspace_index = op.subspace.index([None, (0, 1), (0, 2), (1, 2)]) + subspace_index = [None, (0, 1), (0, 2), (1, 2)].index(op.subspace) if ops_type_indices[1][-1] == 3: params = [0, subspace_index, 0] else: From ebcaf57b69902118d15fc2c6cce9eef520b3a265 Mon Sep 17 00:00:00 2001 From: gabri Date: Wed, 24 Jul 2024 16:10:52 -0400 Subject: [PATCH 16/26] Fixed jittability speed for qutrit need to finish qubit --- .../qtcorgi_helper/apply_operations.py | 96 +++++++++++++++---- pennylane/devices/qutrit_mixed/simulate.py | 35 +++---- 2 files changed, 96 insertions(+), 35 deletions(-) diff --git a/pennylane/devices/qtcorgi_helper/apply_operations.py b/pennylane/devices/qtcorgi_helper/apply_operations.py index 7a16f054efa..63037621e90 100644 --- a/pennylane/devices/qtcorgi_helper/apply_operations.py +++ b/pennylane/devices/qtcorgi_helper/apply_operations.py @@ -14,6 +14,16 @@ alphabet_array = np.array(list(alphabet)) +def swap_axes(op, start, fin): + axes = jnp.arange(op.ndim) + for s,f in zip(start, fin): + axes = axes.at[f].set(s) + axes = axes.at[s].set(f) + indices = jnp.mgrid[tuple(slice(s) for s in op.shape)] #TODO can do + indices = indices[axes] + return op[tuple(indices[i] for i in range(indices.shape[0]))] + + def get_einsum_mapping(wires, state): r"""Finds the indices for einsum to apply kraus operators to a mixed state @@ -72,8 +82,7 @@ def get_new_state_einsum_indices(old_indices, new_indices, state_indices): tuple(int): The einsum indices of the new state """ for old, new in zip(old_indices, new_indices): - for i, state_index in enumerate(state_indices): # TODO replace with jax.lax.scan - state_indices[i] = jax.lax.cond(old == state_index, lambda: new, lambda: state_index) + state_indices[old] = new return (...,) + tuple(state_indices) # return reduce( # TODO, redo # lambda old_indices, idx_pair: old_indices[idx_pair[0]], @@ -85,6 +94,52 @@ def get_new_state_einsum_indices(old_indices, new_indices, state_indices): QUDIT_DIM = 3 +def apply_single_qudit_operation(kraus, wire, state): + num_wires = state.ndim // 2 + start, fin = (wire, wire+num_wires), (0, num_wires) + state = swap_axes(state, start, fin) + + # Shape kraus operators + kraus_shape = [len(kraus)] + [QUDIT_DIM] * 2 + + kraus = jnp.stack(kraus) + kraus_dagger = jnp.conj(jnp.stack(jnp.moveaxis(kraus, source=-1, destination=-2))) + + kraus = jnp.reshape(kraus, kraus_shape) + kraus_dagger = jnp.reshape(kraus_dagger, kraus_shape) + op_1_indices, state_indices, op_2_indices, new_state_indices = get_einsum_mapping([0], state) # TODO fix + state = jnp.einsum(kraus, op_1_indices, state, state_indices, kraus_dagger, op_2_indices, new_state_indices) + return swap_axes(state, fin, start) + + +def get_swap_indices(num_wires): + return (0, num_wires, 1, 1 + num_wires) + + +def get_swap_indices_opposite(num_wires): + return (1, 1 + num_wires, 0, num_wires) + + +def apply_two_qudit_operation(kraus, wires, state): + num_wires = state.ndim//2 + start = (wires[0], wires[0]+num_wires, wires[1], wires[1]+num_wires) + fin = jax.lax.cond(wires[0] > wires[1], get_swap_indices, get_swap_indices_opposite, num_wires) + state = swap_axes(state, start, fin) + state = swap_axes(state, start, fin) + + # Shape kraus operators + kraus_shape = [len(kraus)] + [QUDIT_DIM] * 4 # 2 * num_wires = 4 + + kraus = jnp.stack(kraus) + kraus_dagger = jnp.conj(jnp.stack(jnp.moveaxis(kraus, source=-1, destination=-2))) + + kraus = jnp.reshape(kraus, kraus_shape) + kraus_dagger = jnp.reshape(kraus_dagger, kraus_shape) + op_1_indices, state_indices, op_2_indices, new_state_indices = get_einsum_mapping([0, 1], state) + state = jnp.einsum(kraus, op_1_indices, state, state_indices, kraus_dagger, op_2_indices, new_state_indices) + return swap_axes(state, fin, start) + + def apply_operation_einsum(kraus, wires, state): r"""Apply a quantum channel specified by a list of Kraus operators to subsystems of the quantum state. For a unitary gate, there is a single Kraus operator. @@ -178,36 +233,41 @@ def apply_single_qubit_channel(state, op_info): lambda op_type, param: jax.lax.switch(op_type, single_qutrit_ops_subspace_12, param), ] -two_qutrits_ops = [qml.TAdd.compute_matrix, qml.adjoint(qml.TAdd).compute_matrix] - -single_qutrit_channels = [ - lambda params: qml.QutritDepolarizingChannel.compute_kraus_matrices(params[0]), - lambda params: qml.QutritAmplitudeDamping.compute_kraus_matrices(*params), - lambda params: qml.TritFlip.compute_kraus_matrices(*params), -] +two_qutrits_ops = [qml.TAdd.compute_matrix, lambda: jnp.conj(qml.TAdd.compute_matrix().T)] def apply_single_qutrit_unitary(state, op_info): - wires, param = op_info["wires"][:1], op_info["params"][0] - subspace_index, op_type = op_info["params"][1], op_info["type_indices"][1] + wire, param = op_info["wires"][0], op_info["params"][0] + subspace_index, op_type = op_info["wires"][1], op_info["type_indices"][1] kraus_mats = [jax.lax.switch(subspace_index, single_qutrit_ops, op_type, param)] - return apply_operation_einsum(kraus_mats, wires, state) + return apply_single_qudit_operation(kraus_mats, wire, state) def apply_two_qutrit_unitary(state, op_info): wires = op_info["wires"] kraus_mats = [jax.lax.switch(op_info["type_indices"][1], two_qutrits_ops)] - return apply_operation_einsum(kraus_mats, wires, state) + return apply_two_qudit_operation(kraus_mats, wires, state) -def apply_single_qutrit_channel(state, op_info): - wires, params = op_info["wires"][:1], op_info["params"] # TODO qutrit channels take 3 params - kraus_mats = [jax.lax.switch(op_info["type_indices"][1], single_qutrit_channels, *params)] - return apply_operation_einsum(kraus_mats, wires, state) +def apply_qutrit_depolarizing_channel(state, op_info): + wire, param = op_info["wires"][0], op_info["params"][0] + kraus_mats = qml.QutritDepolarizingChannel.compute_kraus_matrices(param) + return apply_single_qudit_operation(kraus_mats, wire, state) + + +def apply_qutrit_subspace_channel(state, op_info): + wire, params = op_info["wires"][0], op_info["params"] + print(params) + kraus_mats = jax.lax.cond(op_info["type_indices"][1] == 1, qml.QutritAmplitudeDamping.compute_kraus_matrices, qml.TritFlip.compute_kraus_matrices, *params) + return apply_single_qudit_operation(kraus_mats, wire, state) +def apply_single_qutrit_channel(state, op_info): + return jax.lax.cond(op_info["type_indices"][1] == 0, apply_qutrit_depolarizing_channel, + apply_qutrit_subspace_channel, state, op_info) + qutrit_branches = [ apply_single_qutrit_unitary, - apply_two_qutrit_unitary, apply_single_qutrit_channel, + apply_two_qutrit_unitary, ] diff --git a/pennylane/devices/qutrit_mixed/simulate.py b/pennylane/devices/qutrit_mixed/simulate.py index 3fcabca670d..68f5a0e19a7 100644 --- a/pennylane/devices/qutrit_mixed/simulate.py +++ b/pennylane/devices/qutrit_mixed/simulate.py @@ -60,29 +60,33 @@ def get_qutrit_final_state_from_initial(operations, initial_state): """ ops_type_indices, ops_subspace, ops_wires, ops_params = [[], []], [], [[], []], [[], [], []] + two_qutrit_ops = False for op in operations: wires = op.wires if isinstance(op, qml.operation.Channel): - ops_type_indices[0].append(2) + ops_type_indices[0].append(1) ops_type_indices[1].append( [qml.QutritDepolarizingChannel, qml.QutritAmplitudeDamping, qml.TritFlip].index( type(op) ) ) params = op.parameters + ([0] * (3 - op.num_params)) + wires = [wires[0], -1] elif len(wires) == 1: ops_type_indices[0].append(0) ops_type_indices[1].append([qml.TRX, qml.TRY, qml.TRZ, qml.THadamard].index(type(op))) subspace_index = [None, (0, 1), (0, 2), (1, 2)].index(op.subspace) if ops_type_indices[1][-1] == 3: - params = [0, subspace_index, 0] + params = [0., 0., 0.] else: - params = list(op.parameters) + [subspace_index, 0] + params = list(op.parameters) + [0., 0.] + wires = [wires[0], subspace_index] elif len(wires) == 2: - ops_type_indices[0].append(1) - ops_type_indices[1].append(0) # Assume always TAdd - params = [0, 0, 0] + ops_type_indices[0].append(2) + ops_type_indices[1].append(0 if isinstance(op, qml.TAdd) else 1) # Always TAdd or adjoint + params = [0, 0., 0.] + two_qutrit_ops = True else: raise ValueError("TODO") ops_params[0].append(params[0]) @@ -95,19 +99,16 @@ def get_qutrit_final_state_from_initial(operations, initial_state): ops_wires[1].append(wires[1]) ops_info = { - "type_indices": jnp.array(ops_type_indices), + "type_indices": jnp.array(ops_type_indices).T, "wires": [jnp.array(ops_wires[0]), jnp.array(ops_wires[1])], "params": [jnp.array(ops_params[0]), jnp.array(ops_params[1]), jnp.array(ops_params[2])], } + branches = qutrit_branches[: 2 + two_qutrit_ops] + + def switch_function(state, op_info): + return jax.lax.switch(op_info["type_indices"][0], branches, state, op_info), None - return jax.lax.scan( - lambda state, op_info: ( - jax.lax.switch(op_info["type_indices"][0], qutrit_branches, state, op_info), - None, - ), - initial_state, - ops_info, - )[0] + return jax.lax.scan(switch_function, initial_state, ops_info)[0] def measure_final_state(circuit, state, is_state_batched, rng=None, prng_key=None) -> Result: @@ -182,8 +183,8 @@ def get_final_state_qutrit(circuit, **kwargs): if len(circuit) > 0 and isinstance(circuit[0], qml.operation.StatePrepBase): prep = circuit[0] - state = create_initial_state(sorted(circuit.op_wires), prep, like="jax") - return get_qutrit_final_state_from_initial(circuit.operations[bool(prep) :], state), False + state = jnp.complex128(create_initial_state(sorted(circuit.op_wires), prep, like="jax")) + return get_qutrit_final_state_from_initial(circuit.operations[bool(prep) :], state) def simulate( From 1cf95ddedd815dffb3e25669438edf9fb9997f9e Mon Sep 17 00:00:00 2001 From: gabri Date: Thu, 25 Jul 2024 17:15:40 -0400 Subject: [PATCH 17/26] Added logic for most qubit gates --- pennylane/devices/default_mixed.py | 45 +++- .../qtcorgi_helper/apply_operations.py | 211 +++++++++--------- pennylane/devices/qutrit_mixed/simulate.py | 10 +- 3 files changed, 147 insertions(+), 119 deletions(-) diff --git a/pennylane/devices/default_mixed.py b/pennylane/devices/default_mixed.py index 805ab28d9fc..f604423d29e 100644 --- a/pennylane/devices/default_mixed.py +++ b/pennylane/devices/default_mixed.py @@ -74,34 +74,58 @@ def get_qubit_final_state_from_initial(operations, initial_state): """ ops_type_indices, ops_wires, ops_param = [[], []], [[], []], [] + + two_gates = [ + "THadamard", + "TRX_01", + "TRY_01", + "TRZ_01", + "TRX_02", + "TRY_02", + "TRZ_02", + "TRX_12", + "TRY_12", + "TRZ_12", + ] for op in operations: - wires = op.wires() + wires = op.wires if isinstance(op, Channel): - ops_type_indices[0].append(2) - ops_type_indices[1].append([].index(type(op))) + ops_type_indices[0].append(1) + ops_type_indices[1].append( + [qml.DepolarizingChannel, qml.AmplitudeDamping, qml.BitFlip].index(type(op)) + ) elif len(wires) == 1: ops_type_indices[0].append(0) ops_type_indices[1].append([qml.RX, qml.RY, qml.RZ, qml.Hadamard].index(type(op))) elif len(wires) == 2: - ops_type_indices[0].append(1) - ops_type_indices[1].append(0) # Assume always CNOT + ops_type_indices[0].append(2) + if isinstance(op, qml.CNOT): + op_index = 0 + else: + op_index = two_gates.index(op.id) + 1 + if op_index < 2: + ops_param.append(0) + elif op_index < 8: + ops_param.append(jnp.acos(op.matrix()[0, 0])) + else: + ops_param.append(jnp.acos(op.matrix()[1, 1])) + ops_type_indices[1].append(op_index) + else: raise ValueError("TODO") if len(wires) == 1: wires = [wires[0], -1] - params = op.parameters + ([0] * (3 - op.num_params)) + ops_param.append(op.parameters[0]) ops_wires[0].append(wires[0]) ops_wires[1].append(wires[1]) - ops_param[0].append(params[0]) - ops_info = { - "type_index": jnp.array(ops_type_indices), + "type_indices": jnp.array(ops_type_indices).T, "wires": [jnp.array(ops_wires[0]), jnp.array(ops_wires[1])], - "params": [jnp.array(ops_param)], + "param": [jnp.array(ops_param)], } return jax.lax.scan( @@ -794,6 +818,7 @@ def apply(self, operations, rotations=None, **kwargs): prep = True self._apply_operation(operation) + self._state = jnp.complex128(self._state) self._state = get_qubit_final_state_from_initial(operations[prep:], self._state) # store the pre-rotated state self._pre_rotated_state = self._state diff --git a/pennylane/devices/qtcorgi_helper/apply_operations.py b/pennylane/devices/qtcorgi_helper/apply_operations.py index 63037621e90..3b5af3aec5b 100644 --- a/pennylane/devices/qtcorgi_helper/apply_operations.py +++ b/pennylane/devices/qtcorgi_helper/apply_operations.py @@ -16,10 +16,10 @@ def swap_axes(op, start, fin): axes = jnp.arange(op.ndim) - for s,f in zip(start, fin): + for s, f in zip(start, fin): axes = axes.at[f].set(s) axes = axes.at[s].set(f) - indices = jnp.mgrid[tuple(slice(s) for s in op.shape)] #TODO can do + indices = jnp.mgrid[tuple(slice(s) for s in op.shape)] # TODO can do indices = indices[axes] return op[tuple(indices[i] for i in range(indices.shape[0]))] @@ -84,147 +84,138 @@ def get_new_state_einsum_indices(old_indices, new_indices, state_indices): for old, new in zip(old_indices, new_indices): state_indices[old] = new return (...,) + tuple(state_indices) - # return reduce( # TODO, redo - # lambda old_indices, idx_pair: old_indices[idx_pair[0]], - # zip(old_indices, new_indices), - # state_indices, - # ) -QUDIT_DIM = 3 - - -def apply_single_qudit_operation(kraus, wire, state): - num_wires = state.ndim // 2 - start, fin = (wire, wire+num_wires), (0, num_wires) - state = swap_axes(state, start, fin) - - # Shape kraus operators - kraus_shape = [len(kraus)] + [QUDIT_DIM] * 2 - - kraus = jnp.stack(kraus) - kraus_dagger = jnp.conj(jnp.stack(jnp.moveaxis(kraus, source=-1, destination=-2))) - - kraus = jnp.reshape(kraus, kraus_shape) - kraus_dagger = jnp.reshape(kraus_dagger, kraus_shape) - op_1_indices, state_indices, op_2_indices, new_state_indices = get_einsum_mapping([0], state) # TODO fix - state = jnp.einsum(kraus, op_1_indices, state, state_indices, kraus_dagger, op_2_indices, new_state_indices) - return swap_axes(state, fin, start) - - -def get_swap_indices(num_wires): - return (0, num_wires, 1, 1 + num_wires) - - -def get_swap_indices_opposite(num_wires): - return (1, 1 + num_wires, 0, num_wires) - - -def apply_two_qudit_operation(kraus, wires, state): - num_wires = state.ndim//2 - start = (wires[0], wires[0]+num_wires, wires[1], wires[1]+num_wires) - fin = jax.lax.cond(wires[0] > wires[1], get_swap_indices, get_swap_indices_opposite, num_wires) - state = swap_axes(state, start, fin) - state = swap_axes(state, start, fin) +def apply_operation_einsum(kraus, swap_inds, state, qudit_dim, num_wires): + state = swap_axes(state, *swap_inds) # Shape kraus operators - kraus_shape = [len(kraus)] + [QUDIT_DIM] * 4 # 2 * num_wires = 4 + kraus_shape = [len(kraus)] + ([qudit_dim] * num_wires * 2) kraus = jnp.stack(kraus) kraus_dagger = jnp.conj(jnp.stack(jnp.moveaxis(kraus, source=-1, destination=-2))) kraus = jnp.reshape(kraus, kraus_shape) kraus_dagger = jnp.reshape(kraus_dagger, kraus_shape) - op_1_indices, state_indices, op_2_indices, new_state_indices = get_einsum_mapping([0, 1], state) - state = jnp.einsum(kraus, op_1_indices, state, state_indices, kraus_dagger, op_2_indices, new_state_indices) - return swap_axes(state, fin, start) - - -def apply_operation_einsum(kraus, wires, state): - r"""Apply a quantum channel specified by a list of Kraus operators to subsystems of the - quantum state. For a unitary gate, there is a single Kraus operator. - - Args: - kraus (??): TODO - wires - state (array[complex]): Input quantum state - - Returns: - array[complex]: output_state - """ - op_1_indices, state_indices, op_2_indices, new_state_indices = get_einsum_mapping(wires, state) - - num_ch_wires = len(wires) - - # Shape kraus operators - kraus_shape = [len(kraus)] + [QUDIT_DIM] * num_ch_wires * 2 - - kraus = jnp.stack(kraus) - kraus_dagger = jnp.conj(jnp.stack(jnp.moveaxis(kraus, source=-1, destination=-2))) - - kraus = jnp.reshape(kraus, kraus_shape) - kraus_dagger = jnp.reshape(kraus_dagger, kraus_shape) - - return jnp.einsum( + op_1_indices, state_indices, op_2_indices, new_state_indices = get_einsum_mapping( + list(range(num_wires)), state + ) + state = jnp.einsum( kraus, op_1_indices, state, state_indices, kraus_dagger, op_2_indices, new_state_indices ) + return swap_axes(state, *swap_inds) -def get_two_qubit_unitary_matrix(param): - # TODO - pass +def apply_single_qudit_operation(kraus, wire, state, qudit_dim): + num_wires = state.ndim // 2 + swap_inds = (wire, wire + num_wires), (0, num_wires) + return apply_operation_einsum(kraus, swap_inds, state, qudit_dim, 1) -def get_CNOT_matrix(_param): - return jnp.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]) +def apply_two_qudit_operation(kraus, wires, state, qudit_dim): + num_wires = state.ndim // 2 + start = (wires[0], wires[0] + num_wires, wires[1], wires[1] + num_wires) + fin = jax.lax.cond( + wires[0] > wires[1], + lambda: (0, num_wires, 1, 1 + num_wires), + lambda: (1, 1 + num_wires, 0, num_wires), + ) + return apply_operation_einsum(kraus, (start, fin), state, qudit_dim, 2) -single_qubit_ops = [qml.RX.compute_matrix, qml.RY.compute_matrix, qml.RZ.compute_matrix] -two_qubit_ops = [get_CNOT_matrix, get_two_qubit_unitary_matrix] -single_qubit_channels = [ - qml.DepolarizingChannel.compute_kraus_matrices, - qml.AmplitudeDamping.compute_kraus_matrices, - qml.BitFlip.compute_kraus_matrices, +single_qubit_ops = [ + qml.RX.compute_matrix, + qml.RY.compute_matrix, + qml.RZ.compute_matrix, + lambda _param: qml.Hadamard.compute_matrix(), ] +qubits_qutrit_ops = [ + lambda _param: qml.THadamard.compute_matrix(), + qml.TRX.compute_matrix, + qml.TRY.compute_matrix, + qml.TRZ.compute_matrix, + partial(qml.TRX.compute_matrix, subspace=[0, 2]), + partial(qml.TRY.compute_matrix, subspace=[0, 2]), + partial(qml.TRZ.compute_matrix, subspace=[0, 2]), + partial(qml.TRX.compute_matrix, subspace=[1, 2]), + partial(qml.TRY.compute_matrix, subspace=[1, 2]), + partial(qml.TRZ.compute_matrix, subspace=[1, 2]), +] + + +def get_qutrit_op_as_qubits(param, op_type): + new_mat = jnp.eye(4) + new_mat[:3, :3] = jax.lax.switch(op_type - 1, qubits_qutrit_ops, param) + return new_mat def apply_single_qubit_unitary(state, op_info): - wires, param = op_info["wires"][:1], op_info["params"][0] - kraus_mat = jax.lax.switch(op_info["type_indices"][1], single_qubit_ops, param) - return apply_operation_einsum(kraus_mat, wires, state) + wires, param = op_info["wires"][:1], op_info["param"][0] + kraus_mat = [jax.lax.switch(op_info["type_indices"][1], single_qubit_ops, param)] + return apply_single_qudit_operation(kraus_mat, wires, state, 2) def apply_two_qubit_unitary(state, op_info): - wires, params = op_info["wires"], op_info["params"] - kraus_mats = [jax.lax.switch(op_info["type_indices"][1], two_qubit_ops, params)] - return apply_operation_einsum(kraus_mats, wires, state) + wires, param = op_info["wires"], op_info["param"] + op_type = op_info["type_indices"][1] + kraus_mats = [ + jax.lax.cond( + op_type == 0, + lambda *_args: qml.CNOT.compute_matrix, + get_qutrit_op_as_qubits, + param, + op_type, + ) + ] + return apply_two_qudit_operation(kraus_mats, wires, state, 2) + + +def apply_qubit_depolarizing_channel(state, op_info): + wire, param = op_info["wires"][0], op_info["param"][0] + kraus_mats = qml.DepolarizingChannel.compute_kraus_matrices(param) + return apply_single_qudit_operation(kraus_mats, wire, state, 2) + + +def apply_qubit_flipping_channel(state, op_info): + wire, param = op_info["wires"][0], op_info["param"] + kraus_mats = jax.lax.cond( + op_info["type_indices"][1] == 1, + qml.AmplitudeDamping.compute_kraus_matrices, + qml.BitFlip.compute_kraus_matrices, + param, + ) + return apply_single_qudit_operation(kraus_mats, wire, state, 2) def apply_single_qubit_channel(state, op_info): - wires, param = op_info["wires"][:1], op_info["params"][0] - kraus_mats = [jax.lax.switch(op_info["type_indices"][1], single_qubit_channels, param)] - return apply_operation_einsum(kraus_mats, wires, state) + return jax.lax.cond( + op_info["type_indices"][1] == 0, + apply_qubit_depolarizing_channel, + apply_qubit_flipping_channel, + state, + op_info, + ) qubit_branches = [apply_single_qubit_unitary, apply_two_qubit_unitary, apply_single_qubit_channel] single_qutrit_ops_subspace_01 = [ + lambda _param: qml.THadamard.compute_matrix(subspace=[0, 1]), qml.TRX.compute_matrix, qml.TRY.compute_matrix, qml.TRZ.compute_matrix, - lambda _param: qml.THadamard.compute_matrix(subspace=[0, 1]), ] single_qutrit_ops_subspace_02 = [ + lambda _param: qml.THadamard.compute_matrix(subspace=[0, 2]), partial(qml.TRX.compute_matrix, subspace=[0, 2]), partial(qml.TRY.compute_matrix, subspace=[0, 2]), partial(qml.TRZ.compute_matrix, subspace=[0, 2]), - lambda _param: qml.THadamard.compute_matrix(subspace=[0, 2]), ] single_qutrit_ops_subspace_12 = [ + lambda _param: qml.THadamard.compute_matrix(subspace=[1, 2]), partial(qml.TRX.compute_matrix, subspace=[1, 2]), partial(qml.TRY.compute_matrix, subspace=[1, 2]), partial(qml.TRZ.compute_matrix, subspace=[1, 2]), - lambda _param: qml.THadamard.compute_matrix(subspace=[0, 2]), ] single_qutrit_ops = [ lambda _op_type, _param: qml.THadamard.compute_matrix(), @@ -240,31 +231,41 @@ def apply_single_qutrit_unitary(state, op_info): wire, param = op_info["wires"][0], op_info["params"][0] subspace_index, op_type = op_info["wires"][1], op_info["type_indices"][1] kraus_mats = [jax.lax.switch(subspace_index, single_qutrit_ops, op_type, param)] - return apply_single_qudit_operation(kraus_mats, wire, state) + return apply_single_qudit_operation(kraus_mats, wire, state, 3) def apply_two_qutrit_unitary(state, op_info): wires = op_info["wires"] kraus_mats = [jax.lax.switch(op_info["type_indices"][1], two_qutrits_ops)] - return apply_two_qudit_operation(kraus_mats, wires, state) + return apply_two_qudit_operation(kraus_mats, wires, state, 3) def apply_qutrit_depolarizing_channel(state, op_info): wire, param = op_info["wires"][0], op_info["params"][0] kraus_mats = qml.QutritDepolarizingChannel.compute_kraus_matrices(param) - return apply_single_qudit_operation(kraus_mats, wire, state) + return apply_single_qudit_operation(kraus_mats, wire, state, 3) def apply_qutrit_subspace_channel(state, op_info): wire, params = op_info["wires"][0], op_info["params"] - print(params) - kraus_mats = jax.lax.cond(op_info["type_indices"][1] == 1, qml.QutritAmplitudeDamping.compute_kraus_matrices, qml.TritFlip.compute_kraus_matrices, *params) - return apply_single_qudit_operation(kraus_mats, wire, state) + kraus_mats = jax.lax.cond( + op_info["type_indices"][1] == 1, + qml.QutritAmplitudeDamping.compute_kraus_matrices, + qml.TritFlip.compute_kraus_matrices, + *params + ) + return apply_single_qudit_operation(kraus_mats, wire, state, 3) def apply_single_qutrit_channel(state, op_info): - return jax.lax.cond(op_info["type_indices"][1] == 0, apply_qutrit_depolarizing_channel, - apply_qutrit_subspace_channel, state, op_info) + return jax.lax.cond( + op_info["type_indices"][1] == 0, + apply_qutrit_depolarizing_channel, + apply_qutrit_subspace_channel, + state, + op_info, + ) + qutrit_branches = [ apply_single_qutrit_unitary, diff --git a/pennylane/devices/qutrit_mixed/simulate.py b/pennylane/devices/qutrit_mixed/simulate.py index 68f5a0e19a7..4311e21ed77 100644 --- a/pennylane/devices/qutrit_mixed/simulate.py +++ b/pennylane/devices/qutrit_mixed/simulate.py @@ -78,14 +78,16 @@ def get_qutrit_final_state_from_initial(operations, initial_state): ops_type_indices[1].append([qml.TRX, qml.TRY, qml.TRZ, qml.THadamard].index(type(op))) subspace_index = [None, (0, 1), (0, 2), (1, 2)].index(op.subspace) if ops_type_indices[1][-1] == 3: - params = [0., 0., 0.] + params = [0.0, 0.0, 0.0] else: - params = list(op.parameters) + [0., 0.] + params = list(op.parameters) + [0.0, 0.0] wires = [wires[0], subspace_index] elif len(wires) == 2: ops_type_indices[0].append(2) - ops_type_indices[1].append(0 if isinstance(op, qml.TAdd) else 1) # Always TAdd or adjoint - params = [0, 0., 0.] + ops_type_indices[1].append( + 0 if isinstance(op, qml.TAdd) else 1 + ) # Always TAdd or adjoint + params = [0, 0.0, 0.0] two_qutrit_ops = True else: raise ValueError("TODO") From f25acc4f57b3aa0ecead6b10641eba6ceb5f2c02 Mon Sep 17 00:00:00 2001 From: gabri Date: Thu, 25 Jul 2024 23:46:32 -0400 Subject: [PATCH 18/26] Got everything working --- pennylane/devices/default_mixed.py | 17 +++-- .../qtcorgi_helper/apply_operations.py | 62 +++++++++++++------ pennylane/devices/qutrit_mixed/simulate.py | 7 +-- 3 files changed, 54 insertions(+), 32 deletions(-) diff --git a/pennylane/devices/default_mixed.py b/pennylane/devices/default_mixed.py index f604423d29e..bc90b4f4372 100644 --- a/pennylane/devices/default_mixed.py +++ b/pennylane/devices/default_mixed.py @@ -87,6 +87,7 @@ def get_qubit_final_state_from_initial(operations, initial_state): "TRY_12", "TRZ_12", ] + two_qubit_ops = False for op in operations: wires = op.wires @@ -111,6 +112,7 @@ def get_qubit_final_state_from_initial(operations, initial_state): ops_param.append(jnp.acos(op.matrix()[0, 0])) else: ops_param.append(jnp.acos(op.matrix()[1, 1])) + two_qubit_ops = True ops_type_indices[1].append(op_index) else: @@ -125,17 +127,14 @@ def get_qubit_final_state_from_initial(operations, initial_state): ops_info = { "type_indices": jnp.array(ops_type_indices).T, "wires": [jnp.array(ops_wires[0]), jnp.array(ops_wires[1])], - "param": [jnp.array(ops_param)], + "param": jnp.array(ops_param), } + branches = qubit_branches[: 2 + two_qubit_ops] + + def switch_function(state, op_info): + return jax.lax.switch(op_info["type_indices"][0], branches, state, op_info), None + return jax.lax.scan(switch_function, initial_state, ops_info)[0] - return jax.lax.scan( - lambda state, op_info: ( - jax.lax.switch(op_info["type_indices"][0], qubit_branches, state, op_info), - None, - ), - initial_state, - ops_info, - )[0] class DefaultMixed(QubitDevice): diff --git a/pennylane/devices/qtcorgi_helper/apply_operations.py b/pennylane/devices/qtcorgi_helper/apply_operations.py index 3b5af3aec5b..1dbeebc9a20 100644 --- a/pennylane/devices/qtcorgi_helper/apply_operations.py +++ b/pennylane/devices/qtcorgi_helper/apply_operations.py @@ -14,9 +14,9 @@ alphabet_array = np.array(list(alphabet)) -def swap_axes(op, start, fin): +def swap_axes(op, start, fin, reverse=False): axes = jnp.arange(op.ndim) - for s, f in zip(start, fin): + for s, f in reversed(list(zip(start, fin))) if reverse else zip(start, fin): axes = axes.at[f].set(s) axes = axes.at[s].set(f) indices = jnp.mgrid[tuple(slice(s) for s in op.shape)] # TODO can do @@ -103,7 +103,7 @@ def apply_operation_einsum(kraus, swap_inds, state, qudit_dim, num_wires): state = jnp.einsum( kraus, op_1_indices, state, state_indices, kraus_dagger, op_2_indices, new_state_indices ) - return swap_axes(state, *swap_inds) + return swap_axes(state, *swap_inds, reverse=True) def apply_single_qudit_operation(kraus, wire, state, qudit_dim): @@ -112,22 +112,46 @@ def apply_single_qudit_operation(kraus, wire, state, qudit_dim): return apply_operation_einsum(kraus, swap_inds, state, qudit_dim, 1) + + def apply_two_qudit_operation(kraus, wires, state, qudit_dim): num_wires = state.ndim // 2 - start = (wires[0], wires[0] + num_wires, wires[1], wires[1] + num_wires) - fin = jax.lax.cond( - wires[0] > wires[1], - lambda: (0, num_wires, 1, 1 + num_wires), - lambda: (1, 1 + num_wires, 0, num_wires), - ) - return apply_operation_einsum(kraus, (start, fin), state, qudit_dim, 2) + # wire_choice = (wires[0] == 1 * wires[1] == 0) + 2 * (wires[0] == 1 * wires[1] != 0) + 3 * (wires[0] != 1 * wires[1] == 0) + def apply_two_qudit_regular(): + start = (wires[0], wires[0] + num_wires, wires[1], wires[1] + num_wires) + fin = jax.lax.cond( + wires[0] < wires[1], + lambda: (0, num_wires, 1, 1 + num_wires), + lambda: (1, 1 + num_wires, 0, num_wires), + ) + return apply_operation_einsum(kraus, (start, fin), state, qudit_dim, 2) + + def apply_two_qudit_10(): + start = (1, 1 + num_wires) + fin = (0, num_wires) + return apply_operation_einsum(kraus, (start, fin), state, qudit_dim, 2) + + def apply_two_qudit_1x(): + start = (1, 1 + num_wires, wires[1], wires[1] + num_wires) + fin = (0, num_wires, 1, 1 + num_wires) + return apply_operation_einsum(kraus, (start, fin), state, qudit_dim, 2) + + def apply_two_qudit_x0(): + start = (0, num_wires, wires[0], wires[0] + num_wires) + fin = (1, 1 + num_wires, 0, num_wires) + return apply_operation_einsum(kraus, (start, fin), state, qudit_dim, 2) + + return jax.lax.cond(wires[0] == 1, + lambda w1: jax.lax.cond(w1==0, apply_two_qudit_10, apply_two_qudit_1x), + lambda w1: jax.lax.cond(w1==0, apply_two_qudit_x0, apply_two_qudit_regular), + wires[1]) single_qubit_ops = [ qml.RX.compute_matrix, qml.RY.compute_matrix, qml.RZ.compute_matrix, - lambda _param: qml.Hadamard.compute_matrix(), + lambda _param: jnp.complex128(qml.Hadamard.compute_matrix()), ] qubits_qutrit_ops = [ lambda _param: qml.THadamard.compute_matrix(), @@ -144,15 +168,14 @@ def apply_two_qudit_operation(kraus, wires, state, qudit_dim): def get_qutrit_op_as_qubits(param, op_type): - new_mat = jnp.eye(4) - new_mat[:3, :3] = jax.lax.switch(op_type - 1, qubits_qutrit_ops, param) - return new_mat + new_mat = jnp.eye(4, dtype=jnp.complex128) + return new_mat.at[:3, :3].set(jax.lax.switch(op_type - 1, qubits_qutrit_ops, param)) def apply_single_qubit_unitary(state, op_info): - wires, param = op_info["wires"][:1], op_info["param"][0] + wire, param = op_info["wires"][0], op_info["param"] kraus_mat = [jax.lax.switch(op_info["type_indices"][1], single_qubit_ops, param)] - return apply_single_qudit_operation(kraus_mat, wires, state, 2) + return apply_single_qudit_operation(kraus_mat, wire, state, 2) def apply_two_qubit_unitary(state, op_info): @@ -161,17 +184,18 @@ def apply_two_qubit_unitary(state, op_info): kraus_mats = [ jax.lax.cond( op_type == 0, - lambda *_args: qml.CNOT.compute_matrix, + lambda *_args: jnp.complex128(qml.CNOT.compute_matrix()), get_qutrit_op_as_qubits, param, op_type, ) ] + print(kraus_mats) return apply_two_qudit_operation(kraus_mats, wires, state, 2) def apply_qubit_depolarizing_channel(state, op_info): - wire, param = op_info["wires"][0], op_info["param"][0] + wire, param = op_info["wires"][0], op_info["param"] kraus_mats = qml.DepolarizingChannel.compute_kraus_matrices(param) return apply_single_qudit_operation(kraus_mats, wire, state, 2) @@ -197,7 +221,7 @@ def apply_single_qubit_channel(state, op_info): ) -qubit_branches = [apply_single_qubit_unitary, apply_two_qubit_unitary, apply_single_qubit_channel] +qubit_branches = [apply_single_qubit_unitary, apply_single_qubit_channel, apply_two_qubit_unitary] single_qutrit_ops_subspace_01 = [ lambda _param: qml.THadamard.compute_matrix(subspace=[0, 1]), diff --git a/pennylane/devices/qutrit_mixed/simulate.py b/pennylane/devices/qutrit_mixed/simulate.py index 4311e21ed77..7c694997e33 100644 --- a/pennylane/devices/qutrit_mixed/simulate.py +++ b/pennylane/devices/qutrit_mixed/simulate.py @@ -75,7 +75,7 @@ def get_qutrit_final_state_from_initial(operations, initial_state): wires = [wires[0], -1] elif len(wires) == 1: ops_type_indices[0].append(0) - ops_type_indices[1].append([qml.TRX, qml.TRY, qml.TRZ, qml.THadamard].index(type(op))) + ops_type_indices[1].append([qml.THadamard, qml.TRX, qml.TRY, qml.TRZ].index(type(op))) subspace_index = [None, (0, 1), (0, 2), (1, 2)].index(op.subspace) if ops_type_indices[1][-1] == 3: params = [0.0, 0.0, 0.0] @@ -86,8 +86,8 @@ def get_qutrit_final_state_from_initial(operations, initial_state): ops_type_indices[0].append(2) ops_type_indices[1].append( 0 if isinstance(op, qml.TAdd) else 1 - ) # Always TAdd or adjoint - params = [0, 0.0, 0.0] + ) + params = [0., 0., 0.] two_qutrit_ops = True else: raise ValueError("TODO") @@ -109,7 +109,6 @@ def get_qutrit_final_state_from_initial(operations, initial_state): def switch_function(state, op_info): return jax.lax.switch(op_info["type_indices"][0], branches, state, op_info), None - return jax.lax.scan(switch_function, initial_state, ops_info)[0] From 9166ea1d0f97753334571bbbb6493a327b3c2a38 Mon Sep 17 00:00:00 2001 From: gabri Date: Thu, 25 Jul 2024 23:56:16 -0400 Subject: [PATCH 19/26] Reformatted --- pennylane/devices/default_mixed.py | 2 +- .../devices/qtcorgi_helper/apply_operations.py | 13 +++++++------ pennylane/devices/qutrit_mixed/simulate.py | 7 +++---- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pennylane/devices/default_mixed.py b/pennylane/devices/default_mixed.py index bc90b4f4372..04fbb7ca85a 100644 --- a/pennylane/devices/default_mixed.py +++ b/pennylane/devices/default_mixed.py @@ -133,8 +133,8 @@ def get_qubit_final_state_from_initial(operations, initial_state): def switch_function(state, op_info): return jax.lax.switch(op_info["type_indices"][0], branches, state, op_info), None - return jax.lax.scan(switch_function, initial_state, ops_info)[0] + return jax.lax.scan(switch_function, initial_state, ops_info)[0] class DefaultMixed(QubitDevice): diff --git a/pennylane/devices/qtcorgi_helper/apply_operations.py b/pennylane/devices/qtcorgi_helper/apply_operations.py index 1dbeebc9a20..f273f4989ee 100644 --- a/pennylane/devices/qtcorgi_helper/apply_operations.py +++ b/pennylane/devices/qtcorgi_helper/apply_operations.py @@ -112,10 +112,9 @@ def apply_single_qudit_operation(kraus, wire, state, qudit_dim): return apply_operation_einsum(kraus, swap_inds, state, qudit_dim, 1) - - def apply_two_qudit_operation(kraus, wires, state, qudit_dim): num_wires = state.ndim // 2 + # wire_choice = (wires[0] == 1 * wires[1] == 0) + 2 * (wires[0] == 1 * wires[1] != 0) + 3 * (wires[0] != 1 * wires[1] == 0) def apply_two_qudit_regular(): start = (wires[0], wires[0] + num_wires, wires[1], wires[1] + num_wires) @@ -141,10 +140,12 @@ def apply_two_qudit_x0(): fin = (1, 1 + num_wires, 0, num_wires) return apply_operation_einsum(kraus, (start, fin), state, qudit_dim, 2) - return jax.lax.cond(wires[0] == 1, - lambda w1: jax.lax.cond(w1==0, apply_two_qudit_10, apply_two_qudit_1x), - lambda w1: jax.lax.cond(w1==0, apply_two_qudit_x0, apply_two_qudit_regular), - wires[1]) + return jax.lax.cond( + wires[0] == 1, + lambda w1: jax.lax.cond(w1 == 0, apply_two_qudit_10, apply_two_qudit_1x), + lambda w1: jax.lax.cond(w1 == 0, apply_two_qudit_x0, apply_two_qudit_regular), + wires[1], + ) single_qubit_ops = [ diff --git a/pennylane/devices/qutrit_mixed/simulate.py b/pennylane/devices/qutrit_mixed/simulate.py index 7c694997e33..0d00875c382 100644 --- a/pennylane/devices/qutrit_mixed/simulate.py +++ b/pennylane/devices/qutrit_mixed/simulate.py @@ -84,10 +84,8 @@ def get_qutrit_final_state_from_initial(operations, initial_state): wires = [wires[0], subspace_index] elif len(wires) == 2: ops_type_indices[0].append(2) - ops_type_indices[1].append( - 0 if isinstance(op, qml.TAdd) else 1 - ) - params = [0., 0., 0.] + ops_type_indices[1].append(0 if isinstance(op, qml.TAdd) else 1) + params = [0.0, 0.0, 0.0] two_qutrit_ops = True else: raise ValueError("TODO") @@ -109,6 +107,7 @@ def get_qutrit_final_state_from_initial(operations, initial_state): def switch_function(state, op_info): return jax.lax.switch(op_info["type_indices"][0], branches, state, op_info), None + return jax.lax.scan(switch_function, initial_state, ops_info)[0] From 5c761cac69382ed529b7baa75d801cc82fcbe79a Mon Sep 17 00:00:00 2001 From: gabri Date: Fri, 26 Jul 2024 15:21:53 -0400 Subject: [PATCH 20/26] Added qubitunitary to operators --- pennylane/devices/default_mixed.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/pennylane/devices/default_mixed.py b/pennylane/devices/default_mixed.py index 04fbb7ca85a..b1d72071561 100644 --- a/pennylane/devices/default_mixed.py +++ b/pennylane/devices/default_mixed.py @@ -59,6 +59,18 @@ ABC_ARRAY = np.array(list(ABC)) tolerance = 1e-10 +two_gates = [ + "THadamard", + "TRX_01", + "TRY_01", + "TRZ_01", + "TRX_02", + "TRY_02", + "TRZ_02", + "TRX_12", + "TRY_12", + "TRZ_12", +] def get_qubit_final_state_from_initial(operations, initial_state): """ @@ -75,18 +87,7 @@ def get_qubit_final_state_from_initial(operations, initial_state): """ ops_type_indices, ops_wires, ops_param = [[], []], [[], []], [] - two_gates = [ - "THadamard", - "TRX_01", - "TRY_01", - "TRZ_01", - "TRX_02", - "TRY_02", - "TRZ_02", - "TRX_12", - "TRY_12", - "TRZ_12", - ] + two_qubit_ops = False for op in operations: @@ -179,6 +180,7 @@ class DefaultMixed(QubitDevice): "QubitDensityMatrix", "Hadamard", "CNOT", + "QubitUnitary", "RX", "RY", "RZ", From cd332c0977a4a15d945efce434713c01255b064b Mon Sep 17 00:00:00 2001 From: gabri Date: Fri, 26 Jul 2024 15:38:21 -0400 Subject: [PATCH 21/26] Removed print --- pennylane/devices/qtcorgi_helper/apply_operations.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pennylane/devices/qtcorgi_helper/apply_operations.py b/pennylane/devices/qtcorgi_helper/apply_operations.py index f273f4989ee..2ec6536b2b2 100644 --- a/pennylane/devices/qtcorgi_helper/apply_operations.py +++ b/pennylane/devices/qtcorgi_helper/apply_operations.py @@ -191,7 +191,6 @@ def apply_two_qubit_unitary(state, op_info): op_type, ) ] - print(kraus_mats) return apply_two_qudit_operation(kraus_mats, wires, state, 2) From 73160f470f20ff60d4243ae7d8f3a17f4c49b8a5 Mon Sep 17 00:00:00 2001 From: Gabriel Bottrill Date: Mon, 29 Jul 2024 15:25:18 -0700 Subject: [PATCH 22/26] Removed unused function from simulate --- .../qtcorgi_helper/qtcorgi_simulator.py | 52 ------------------- 1 file changed, 52 deletions(-) diff --git a/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py b/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py index f301edf63df..a0756a0ea01 100644 --- a/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py +++ b/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py @@ -7,58 +7,6 @@ op_types = [] -def get_qubit_final_state_from_initial(operations, initial_state): - """ - TODO - - Args: - TODO - - Returns: - Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and - whether the state has a batch dimension. - - """ - ops_type_indices, ops_wires, ops_param = [[], []], [[], []], [] - for op in operations: - wires = op.wires() - - if isinstance(op, Channel): - ops_type_indices[0].append(2) - ops_type_indices[1].append([].index(type(op))) - elif len(wires) == 1: - ops_type_indices[0].append(0) - ops_type_indices[1].append([qml.RX, qml.RY, qml.RZ, qml.Hadamard].index(type(op))) - elif len(wires) == 2: - ops_type_indices[0].append(1) - ops_type_indices[1].append(0) # Assume always CNOT - else: - raise ValueError("TODO") - - if len(wires) == 1: - wires = [wires[0], -1] - params = op.parameters + ([0] * (3 - op.num_params)) - ops_wires[0].append(wires[0]) - ops_wires[1].append(wires[1]) - - ops_param[0].append(params[0]) - - ops_info = { - "type_index": jnp.array(ops_type_indices), - "wires": [jnp.array(ops_wires[0]), jnp.array(ops_wires[1])], - "params": [jnp.array(ops_param)], - } - - return jax.lax.scan( - lambda state, op_info: ( - jax.lax.switch(op_info["type_indices"][0], qubit_branches, state, op_info), - None, - ), - initial_state, - ops_info, - )[0] - - def get_qutrit_final_state_from_initial(operations, initial_state): """ TODO From a16997f28e6b0c0a3c4af34129a9737158472dc1 Mon Sep 17 00:00:00 2001 From: Gabriel Bottrill Date: Mon, 29 Jul 2024 15:27:04 -0700 Subject: [PATCH 23/26] Removed unused module qtcorgi simulator --- .../qtcorgi_helper/qtcorgi_simulator.py | 69 ------------------- 1 file changed, 69 deletions(-) delete mode 100644 pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py diff --git a/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py b/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py deleted file mode 100644 index a0756a0ea01..00000000000 --- a/pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py +++ /dev/null @@ -1,69 +0,0 @@ -import jax -import jax.numpy as jnp -import pennylane as qml -from pennylane.operation import Channel -from .apply_operations import qubit_branches, qutrit_branches - -op_types = [] - - -def get_qutrit_final_state_from_initial(operations, initial_state): - """ - TODO - - Args: - TODO - - Returns: - Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and - whether the state has a batch dimension. - - """ - ops_type_indices, ops_subspace, ops_wires, ops_params = [[], []], [], [[], []], [[], [], []] - for op in operations: - wires = op.wires() - - if isinstance(op, Channel): - ops_type_indices[0].append(2) - ops_type_indices[1].append( - [qml.QutritDepolarizingChannel, qml.QutritAmplitudeDamping, qml.TritFlip].index( - type(op) - ) - ) - params = op.parameters + ([0] * (3 - op.num_params)) - elif len(wires) == 1: - ops_type_indices[0].append(0) - ops_type_indices[1].append([qml.TRX, qml.TRY, qml.TRZ, qml.THadamard].index(type(op))) - if ops_type_indices[1][-1] == 3: - params = [0] + list(op.subspace) if op.subspace is not None else [0, 0] - else: - params = list(op.params) + list(op.subspace) - elif len(wires) == 2: - ops_type_indices[0].append(1) - ops_type_indices[1].append(0) # Assume always TAdd - params = [0, 0, 0] - else: - raise ValueError("TODO") - ops_params[0].append(params[0]) - ops_params[1].append(params[1]) - ops_params[2].append(params[2]) - - if len(wires) == 1: - wires = [wires[0], -1] - ops_wires[0].append(wires[0]) - ops_wires[1].append(wires[1]) - - ops_info = { - "type_indices": jnp.array(ops_type_indices), - "wires": [jnp.array(ops_wires[0]), jnp.array(ops_wires[1])], - "params": [jnp.array(ops_params[0]), jnp.array(ops_params[1]), jnp.array(ops_params[2])], - } - - return jax.lax.scan( - lambda state, op_info: ( - jax.lax.switch(op_info["type_indices"][0], qutrit_branches, state, op_info), - None, - ), - initial_state, - ops_info, - )[0] From 5e1229ea3b182898e14349c46f2e2b976776b7dc Mon Sep 17 00:00:00 2001 From: Gabriel Bottrill Date: Mon, 29 Jul 2024 15:56:57 -0700 Subject: [PATCH 24/26] Added jitting to more functions --- pennylane/devices/default_mixed.py | 3 ++- .../qtcorgi_helper/apply_operations.py | 19 +++++++++++++++++++ pennylane/devices/qutrit_mixed/simulate.py | 1 + 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/pennylane/devices/default_mixed.py b/pennylane/devices/default_mixed.py index b1d72071561..195f4c1ad32 100644 --- a/pennylane/devices/default_mixed.py +++ b/pennylane/devices/default_mixed.py @@ -72,6 +72,7 @@ "TRZ_12", ] + def get_qubit_final_state_from_initial(operations, initial_state): """ TODO @@ -87,7 +88,6 @@ def get_qubit_final_state_from_initial(operations, initial_state): """ ops_type_indices, ops_wires, ops_param = [[], []], [[], []], [] - two_qubit_ops = False for op in operations: @@ -132,6 +132,7 @@ def get_qubit_final_state_from_initial(operations, initial_state): } branches = qubit_branches[: 2 + two_qubit_ops] + @jax.jit def switch_function(state, op_info): return jax.lax.switch(op_info["type_indices"][0], branches, state, op_info), None diff --git a/pennylane/devices/qtcorgi_helper/apply_operations.py b/pennylane/devices/qtcorgi_helper/apply_operations.py index 2ec6536b2b2..5f2ef6564ca 100644 --- a/pennylane/devices/qtcorgi_helper/apply_operations.py +++ b/pennylane/devices/qtcorgi_helper/apply_operations.py @@ -14,6 +14,7 @@ alphabet_array = np.array(list(alphabet)) +@partial(jax.jit, static_argnames=["reverse"]) def swap_axes(op, start, fin, reverse=False): axes = jnp.arange(op.ndim) for s, f in reversed(list(zip(start, fin))) if reverse else zip(start, fin): @@ -86,6 +87,7 @@ def get_new_state_einsum_indices(old_indices, new_indices, state_indices): return (...,) + tuple(state_indices) +# @jax.jit def apply_operation_einsum(kraus, swap_inds, state, qudit_dim, num_wires): state = swap_axes(state, *swap_inds) @@ -106,16 +108,19 @@ def apply_operation_einsum(kraus, swap_inds, state, qudit_dim, num_wires): return swap_axes(state, *swap_inds, reverse=True) +@partial(jax.jit, static_argnames=["qudit_dim"]) def apply_single_qudit_operation(kraus, wire, state, qudit_dim): num_wires = state.ndim // 2 swap_inds = (wire, wire + num_wires), (0, num_wires) return apply_operation_einsum(kraus, swap_inds, state, qudit_dim, 1) +@partial(jax.jit, static_argnames=["qudit_dim"]) def apply_two_qudit_operation(kraus, wires, state, qudit_dim): num_wires = state.ndim // 2 # wire_choice = (wires[0] == 1 * wires[1] == 0) + 2 * (wires[0] == 1 * wires[1] != 0) + 3 * (wires[0] != 1 * wires[1] == 0) + @jax.jit def apply_two_qudit_regular(): start = (wires[0], wires[0] + num_wires, wires[1], wires[1] + num_wires) fin = jax.lax.cond( @@ -125,16 +130,19 @@ def apply_two_qudit_regular(): ) return apply_operation_einsum(kraus, (start, fin), state, qudit_dim, 2) + @jax.jit def apply_two_qudit_10(): start = (1, 1 + num_wires) fin = (0, num_wires) return apply_operation_einsum(kraus, (start, fin), state, qudit_dim, 2) + @jax.jit def apply_two_qudit_1x(): start = (1, 1 + num_wires, wires[1], wires[1] + num_wires) fin = (0, num_wires, 1, 1 + num_wires) return apply_operation_einsum(kraus, (start, fin), state, qudit_dim, 2) + @jax.jit def apply_two_qudit_x0(): start = (0, num_wires, wires[0], wires[0] + num_wires) fin = (1, 1 + num_wires, 0, num_wires) @@ -168,17 +176,20 @@ def apply_two_qudit_x0(): ] +@jax.jit def get_qutrit_op_as_qubits(param, op_type): new_mat = jnp.eye(4, dtype=jnp.complex128) return new_mat.at[:3, :3].set(jax.lax.switch(op_type - 1, qubits_qutrit_ops, param)) +@jax.jit def apply_single_qubit_unitary(state, op_info): wire, param = op_info["wires"][0], op_info["param"] kraus_mat = [jax.lax.switch(op_info["type_indices"][1], single_qubit_ops, param)] return apply_single_qudit_operation(kraus_mat, wire, state, 2) +@jax.jit def apply_two_qubit_unitary(state, op_info): wires, param = op_info["wires"], op_info["param"] op_type = op_info["type_indices"][1] @@ -194,12 +205,14 @@ def apply_two_qubit_unitary(state, op_info): return apply_two_qudit_operation(kraus_mats, wires, state, 2) +@jax.jit def apply_qubit_depolarizing_channel(state, op_info): wire, param = op_info["wires"][0], op_info["param"] kraus_mats = qml.DepolarizingChannel.compute_kraus_matrices(param) return apply_single_qudit_operation(kraus_mats, wire, state, 2) +@jax.jit def apply_qubit_flipping_channel(state, op_info): wire, param = op_info["wires"][0], op_info["param"] kraus_mats = jax.lax.cond( @@ -211,6 +224,7 @@ def apply_qubit_flipping_channel(state, op_info): return apply_single_qudit_operation(kraus_mats, wire, state, 2) +@jax.jit def apply_single_qubit_channel(state, op_info): return jax.lax.cond( op_info["type_indices"][1] == 0, @@ -251,6 +265,7 @@ def apply_single_qubit_channel(state, op_info): two_qutrits_ops = [qml.TAdd.compute_matrix, lambda: jnp.conj(qml.TAdd.compute_matrix().T)] +@jax.jit def apply_single_qutrit_unitary(state, op_info): wire, param = op_info["wires"][0], op_info["params"][0] subspace_index, op_type = op_info["wires"][1], op_info["type_indices"][1] @@ -258,18 +273,21 @@ def apply_single_qutrit_unitary(state, op_info): return apply_single_qudit_operation(kraus_mats, wire, state, 3) +@jax.jit def apply_two_qutrit_unitary(state, op_info): wires = op_info["wires"] kraus_mats = [jax.lax.switch(op_info["type_indices"][1], two_qutrits_ops)] return apply_two_qudit_operation(kraus_mats, wires, state, 3) +@jax.jit def apply_qutrit_depolarizing_channel(state, op_info): wire, param = op_info["wires"][0], op_info["params"][0] kraus_mats = qml.QutritDepolarizingChannel.compute_kraus_matrices(param) return apply_single_qudit_operation(kraus_mats, wire, state, 3) +@jax.jit def apply_qutrit_subspace_channel(state, op_info): wire, params = op_info["wires"][0], op_info["params"] kraus_mats = jax.lax.cond( @@ -281,6 +299,7 @@ def apply_qutrit_subspace_channel(state, op_info): return apply_single_qudit_operation(kraus_mats, wire, state, 3) +@jax.jit def apply_single_qutrit_channel(state, op_info): return jax.lax.cond( op_info["type_indices"][1] == 0, diff --git a/pennylane/devices/qutrit_mixed/simulate.py b/pennylane/devices/qutrit_mixed/simulate.py index 0d00875c382..561ff721ea7 100644 --- a/pennylane/devices/qutrit_mixed/simulate.py +++ b/pennylane/devices/qutrit_mixed/simulate.py @@ -105,6 +105,7 @@ def get_qutrit_final_state_from_initial(operations, initial_state): } branches = qutrit_branches[: 2 + two_qutrit_ops] + @jax.jit def switch_function(state, op_info): return jax.lax.switch(op_info["type_indices"][0], branches, state, op_info), None From 9a0277d035e5fc6ca3e5dc5075bc067815458c2e Mon Sep 17 00:00:00 2001 From: gabrielLydian Date: Tue, 30 Jul 2024 17:37:58 -0700 Subject: [PATCH 25/26] Fixed index of THadamard param --- pennylane/devices/preprocess.py | 1 - pennylane/devices/qubit/sampling.py | 1 - pennylane/devices/qutrit_mixed/simulate.py | 2 +- 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/pennylane/devices/preprocess.py b/pennylane/devices/preprocess.py index 819e375df54..22b7d229f38 100644 --- a/pennylane/devices/preprocess.py +++ b/pennylane/devices/preprocess.py @@ -357,7 +357,6 @@ def decomposer(op): if all(stopping_condition(op) for op in tape.operations[len(prep_op) :]): return (tape,), null_postprocessing try: - new_ops = [ final_op for op in tape.operations[len(prep_op) :] diff --git a/pennylane/devices/qubit/sampling.py b/pennylane/devices/qubit/sampling.py index 2d87d3e504e..7c3f5f6a805 100644 --- a/pennylane/devices/qubit/sampling.py +++ b/pennylane/devices/qubit/sampling.py @@ -132,7 +132,6 @@ def _get_num_wire_groups_for_expval_H(obs): def _get_num_executions_for_sum(obs): - if obs.grouping_indices: return len(obs.grouping_indices) diff --git a/pennylane/devices/qutrit_mixed/simulate.py b/pennylane/devices/qutrit_mixed/simulate.py index 05a9be82715..5594870346b 100644 --- a/pennylane/devices/qutrit_mixed/simulate.py +++ b/pennylane/devices/qutrit_mixed/simulate.py @@ -77,7 +77,7 @@ def get_qutrit_final_state_from_initial(operations, initial_state): ops_type_indices[0].append(0) ops_type_indices[1].append([qml.THadamard, qml.TRX, qml.TRY, qml.TRZ].index(type(op))) subspace_index = [None, (0, 1), (0, 2), (1, 2)].index(op.subspace) - if ops_type_indices[1][-1] == 3: + if ops_type_indices[1][-1] == 0: params = [0.0, 0.0, 0.0] else: params = list(op.parameters) + [0.0, 0.0] From 7a0653c6e4fdd01a66ec855cf53f3fb9931c0826 Mon Sep 17 00:00:00 2001 From: Gabriel Bottrill Date: Wed, 31 Jul 2024 14:48:34 -0700 Subject: [PATCH 26/26] Changed method for dealing with TRX,TRY, and TRZ ops --- pennylane/devices/default_mixed.py | 6 +- .../qtcorgi_helper/apply_operations.py | 124 +++++++++++++++--- 2 files changed, 106 insertions(+), 24 deletions(-) diff --git a/pennylane/devices/default_mixed.py b/pennylane/devices/default_mixed.py index 195f4c1ad32..40d751958a3 100644 --- a/pennylane/devices/default_mixed.py +++ b/pennylane/devices/default_mixed.py @@ -108,11 +108,9 @@ def get_qubit_final_state_from_initial(operations, initial_state): else: op_index = two_gates.index(op.id) + 1 if op_index < 2: - ops_param.append(0) - elif op_index < 8: - ops_param.append(jnp.acos(op.matrix()[0, 0])) + ops_param.append(0.0) else: - ops_param.append(jnp.acos(op.matrix()[1, 1])) + ops_param.append(op.phi) two_qubit_ops = True ops_type_indices[1].append(op_index) diff --git a/pennylane/devices/qtcorgi_helper/apply_operations.py b/pennylane/devices/qtcorgi_helper/apply_operations.py index 5f2ef6564ca..5b80bb32d22 100644 --- a/pennylane/devices/qtcorgi_helper/apply_operations.py +++ b/pennylane/devices/qtcorgi_helper/apply_operations.py @@ -12,6 +12,7 @@ from functools import partial, reduce alphabet_array = np.array(list(alphabet)) +stack_last = partial(qml.math.stack, axis=-1) @partial(jax.jit, static_argnames=["reverse"]) @@ -162,17 +163,108 @@ def apply_two_qudit_x0(): qml.RZ.compute_matrix, lambda _param: jnp.complex128(qml.Hadamard.compute_matrix()), ] -qubits_qutrit_ops = [ - lambda _param: qml.THadamard.compute_matrix(), - qml.TRX.compute_matrix, - qml.TRY.compute_matrix, - qml.TRZ.compute_matrix, - partial(qml.TRX.compute_matrix, subspace=[0, 2]), - partial(qml.TRY.compute_matrix, subspace=[0, 2]), - partial(qml.TRZ.compute_matrix, subspace=[0, 2]), - partial(qml.TRX.compute_matrix, subspace=[1, 2]), - partial(qml.TRY.compute_matrix, subspace=[1, 2]), - partial(qml.TRZ.compute_matrix, subspace=[1, 2]), + + +def get_qubit_TRX_matrix_func(subspace): + def qubit_TRX_matrix(theta): + c = qml.math.cos(theta / 2) + s = qml.math.sin(theta / 2) + + # The following avoids casting an imaginary quantity to reals when backpropagating + c = (1 + 0j) * c + js = -1j * s + one = qml.math.ones_like(c) + z = qml.math.zeros_like(c) + + diags = [one, one, one] + diags[subspace[0]] = c + diags[subspace[1]] = c + + off_diags = [z, z, z] + off_diags[qml.math.sum(subspace) - 1] = js + + return qml.math.stack( + [ + stack_last([diags[0], off_diags[0], off_diags[1], z]), + stack_last([off_diags[0], diags[1], off_diags[2], z]), + stack_last([off_diags[1], off_diags[2], diags[2], z]), + stack_last([z, z, z, one]), + ], + axis=-2, + ) + + return qubit_TRX_matrix + + +def get_qubit_TRY_matrix_func(subspace): + def qubit_TRY_matrix(theta): + c = qml.math.cos(theta / 2) + s = qml.math.sin(theta / 2) + + # The following avoids casting an imaginary quantity to reals when backpropagating + c = (1 + 0j) * c + s = (1 + 0j) * s + one = qml.math.ones_like(c) + z = qml.math.zeros_like(c) + + diags = [one, one, one] + diags[subspace[0]] = c + diags[subspace[1]] = c + + off_diags = [z, z, z] + off_diags[qml.math.sum(subspace) - 1] = s + + return qml.math.stack( + [ + stack_last([diags[0], -off_diags[0], -off_diags[1], z]), + stack_last([off_diags[0], diags[1], -off_diags[2], z]), + stack_last([off_diags[1], off_diags[2], diags[2], z]), + stack_last([z, z, z, one]), + ], + axis=-2, + ) + + return qubit_TRY_matrix + + +def get_qubit_TRZ_matrix_func(subspace): + def qubit_TRZ_matrix(theta): + p = qml.math.exp(-1j * theta / 2) + one = qml.math.ones_like(p) + z = qml.math.zeros_like(p) + + diags = [one, one, one] + diags[subspace[0]] = p + diags[subspace[1]] = qml.math.conj(p) + + return qml.math.stack( + [ + stack_last([diags[0], z, z, z]), + stack_last([z, diags[1], z, z]), + stack_last([z, z, diags[2], z]), + stack_last([z, z, z, one]), + ], + axis=-2, + ) + + return qubit_TRZ_matrix + + +OMEGA = np.exp(2 * np.pi * 1j / 3) + +two_qubit_ops = [ + lambda _param: jnp.complex128(qml.CNOT.compute_matrix()), + lambda _param: (-1j / np.sqrt(3)) + * np.array([[1, 1, 1, 0], [1, OMEGA, OMEGA**2, 0], [1, OMEGA**2, OMEGA, 0], [0, 0, 0, 1]]), + get_qubit_TRX_matrix_func([0, 1]), + get_qubit_TRY_matrix_func([0, 1]), + get_qubit_TRZ_matrix_func([0, 1]), + get_qubit_TRX_matrix_func([0, 2]), + get_qubit_TRY_matrix_func([0, 2]), + get_qubit_TRZ_matrix_func([0, 2]), + get_qubit_TRX_matrix_func([1, 2]), + get_qubit_TRY_matrix_func([1, 2]), + get_qubit_TRZ_matrix_func([1, 2]), ] @@ -193,15 +285,7 @@ def apply_single_qubit_unitary(state, op_info): def apply_two_qubit_unitary(state, op_info): wires, param = op_info["wires"], op_info["param"] op_type = op_info["type_indices"][1] - kraus_mats = [ - jax.lax.cond( - op_type == 0, - lambda *_args: jnp.complex128(qml.CNOT.compute_matrix()), - get_qutrit_op_as_qubits, - param, - op_type, - ) - ] + kraus_mats = [jax.lax.switch(op_type, two_qubit_ops, param)] return apply_two_qudit_operation(kraus_mats, wires, state, 2)