diff --git a/pennylane/devices/default_mixed.py b/pennylane/devices/default_mixed.py index 0e2d0fbfcf0..40d751958a3 100644 --- a/pennylane/devices/default_mixed.py +++ b/pennylane/devices/default_mixed.py @@ -48,6 +48,10 @@ from pennylane.wires import Wires from .._version import __version__ +import jax +import jax.numpy as jnp +from .qtcorgi_helper.apply_operations import qubit_branches + logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) @@ -55,6 +59,83 @@ ABC_ARRAY = np.array(list(ABC)) tolerance = 1e-10 +two_gates = [ + "THadamard", + "TRX_01", + "TRY_01", + "TRZ_01", + "TRX_02", + "TRY_02", + "TRZ_02", + "TRX_12", + "TRY_12", + "TRZ_12", +] + + +def get_qubit_final_state_from_initial(operations, initial_state): + """ + TODO + + Args: + operations ():TODO + initial_state ():TODO + + Returns: + Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and + whether the state has a batch dimension. + + """ + ops_type_indices, ops_wires, ops_param = [[], []], [[], []], [] + + two_qubit_ops = False + for op in operations: + + wires = op.wires + + if isinstance(op, Channel): + ops_type_indices[0].append(1) + ops_type_indices[1].append( + [qml.DepolarizingChannel, qml.AmplitudeDamping, qml.BitFlip].index(type(op)) + ) + elif len(wires) == 1: + ops_type_indices[0].append(0) + ops_type_indices[1].append([qml.RX, qml.RY, qml.RZ, qml.Hadamard].index(type(op))) + elif len(wires) == 2: + ops_type_indices[0].append(2) + if isinstance(op, qml.CNOT): + op_index = 0 + else: + op_index = two_gates.index(op.id) + 1 + if op_index < 2: + ops_param.append(0.0) + else: + ops_param.append(op.phi) + two_qubit_ops = True + ops_type_indices[1].append(op_index) + + else: + raise ValueError("TODO") + + if len(wires) == 1: + wires = [wires[0], -1] + ops_param.append(op.parameters[0]) + ops_wires[0].append(wires[0]) + ops_wires[1].append(wires[1]) + + ops_info = { + "type_indices": jnp.array(ops_type_indices).T, + "wires": [jnp.array(ops_wires[0]), jnp.array(ops_wires[1])], + "param": jnp.array(ops_param), + } + branches = qubit_branches[: 2 + two_qubit_ops] + + @jax.jit + def switch_function(state, op_info): + return jax.lax.switch(op_info["type_indices"][0], branches, state, op_info), None + + return jax.lax.scan(switch_function, initial_state, ops_info)[0] + class DefaultMixed(QubitDevice): """Default qubit device for performing mixed-state computations in PennyLane. @@ -92,73 +173,19 @@ class DefaultMixed(QubitDevice): author = "Xanadu Inc." operations = { - "Identity", - "Snapshot", "BasisState", "QubitStateVector", "StatePrep", "QubitDensityMatrix", - "QubitUnitary", - "ControlledQubitUnitary", - "BlockEncode", - "MultiControlledX", - "DiagonalQubitUnitary", - "SpecialUnitary", - "PauliX", - "PauliY", - "PauliZ", - "MultiRZ", "Hadamard", - "S", - "T", - "SX", "CNOT", - "SWAP", - "ISWAP", - "CSWAP", - "Toffoli", - "CCZ", - "CY", - "CZ", - "CH", - "PhaseShift", - "PCPhase", - "ControlledPhaseShift", - "CPhaseShift00", - "CPhaseShift01", - "CPhaseShift10", + "QubitUnitary", "RX", "RY", "RZ", - "Rot", - "CRX", - "CRY", - "CRZ", - "CRot", "AmplitudeDamping", - "GeneralizedAmplitudeDamping", - "PhaseDamping", "DepolarizingChannel", "BitFlip", - "PhaseFlip", - "PauliError", - "ResetError", - "QubitChannel", - "SingleExcitation", - "SingleExcitationPlus", - "SingleExcitationMinus", - "DoubleExcitation", - "DoubleExcitationPlus", - "DoubleExcitationMinus", - "QubitCarry", - "QubitSum", - "OrbitalRotation", - "FermionicSWAP", - "QFT", - "ThermalRelaxationError", - "ECR", - "ParametrizedEvolution", - "GlobalPhase", } _reshape = staticmethod(qnp.reshape) @@ -781,9 +808,18 @@ def apply(self, operations, rotations=None, **kwargs): f"on a {self.short_name} device." ) - for operation in operations: - self._apply_operation(operation) - + prep = False + if len(operations) > 0: + if ( + isinstance(operations[0], StatePrep) + or isinstance(operations[0], BasisState) + or isinstance(operations[0], QubitDensityMatrix) + ): + prep = True + self._apply_operation(operation) + + self._state = jnp.complex128(self._state) + self._state = get_qubit_final_state_from_initial(operations[prep:], self._state) # store the pre-rotated state self._pre_rotated_state = self._state diff --git a/pennylane/devices/default_qutrit_mixed.py b/pennylane/devices/default_qutrit_mixed.py index bf8ae1eea7d..781b19c00a3 100644 --- a/pennylane/devices/default_qutrit_mixed.py +++ b/pennylane/devices/default_qutrit_mixed.py @@ -72,7 +72,15 @@ def observable_stopping_condition(obs: qml.operation.Operator) -> bool: def stopping_condition(op: qml.operation.Operator) -> bool: """Specify whether an Operator object is supported by the device.""" - expected_set = DefaultQutrit.operations | {"Snapshot"} | channels + operations = { + "TAdd", + "Adjoint(TAdd)", + "THadamard", + "TRX", + "TRY", + "TRZ", + } + expected_set = operations | channels return op.name in expected_set diff --git a/pennylane/devices/preprocess.py b/pennylane/devices/preprocess.py index 819e375df54..22b7d229f38 100644 --- a/pennylane/devices/preprocess.py +++ b/pennylane/devices/preprocess.py @@ -357,7 +357,6 @@ def decomposer(op): if all(stopping_condition(op) for op in tape.operations[len(prep_op) :]): return (tape,), null_postprocessing try: - new_ops = [ final_op for op in tape.operations[len(prep_op) :] diff --git a/pennylane/devices/qtcorgi_helper/__init__.py b/pennylane/devices/qtcorgi_helper/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/pennylane/devices/qtcorgi_helper/apply_operations.py b/pennylane/devices/qtcorgi_helper/apply_operations.py new file mode 100644 index 00000000000..5b80bb32d22 --- /dev/null +++ b/pennylane/devices/qtcorgi_helper/apply_operations.py @@ -0,0 +1,401 @@ +import time +import jax +import jax.numpy as jnp +from jax.lax import scan + +jax.config.update("jax_enable_x64", True) +jax.config.update("jax_platforms", "cpu") +import pennylane as qml +from string import ascii_letters as alphabet + +import numpy as np +from functools import partial, reduce + +alphabet_array = np.array(list(alphabet)) +stack_last = partial(qml.math.stack, axis=-1) + + +@partial(jax.jit, static_argnames=["reverse"]) +def swap_axes(op, start, fin, reverse=False): + axes = jnp.arange(op.ndim) + for s, f in reversed(list(zip(start, fin))) if reverse else zip(start, fin): + axes = axes.at[f].set(s) + axes = axes.at[s].set(f) + indices = jnp.mgrid[tuple(slice(s) for s in op.shape)] # TODO can do + indices = indices[axes] + return op[tuple(indices[i] for i in range(indices.shape[0]))] + + +def get_einsum_mapping(wires, state): + r"""Finds the indices for einsum to apply kraus operators to a mixed state + + Args: + wires + state (array[complex]): Input quantum state + + Returns: + tuple(tuple(int)): Indices mapping that defines the einsum + """ + num_ch_wires = len(wires) + num_wires = int(len(qml.math.shape(state)) / 2) + rho_dim = 2 * num_wires + + # Tensor indices of the state. For each qutrit, need an index for rows *and* columns + # TODO this may need to be an input from above + state_indices_list = list(range(rho_dim)) + state_indices = (...,) + tuple(state_indices_list) + + # row indices of the quantum state affected by this operation + row_indices = tuple(wires) + + # column indices are shifted by the number of wires + # TODO replace + col_indices = tuple(w + num_wires for w in wires) # TODO: Should I do an array? + + # indices in einsum must be replaced with new ones + new_row_indices = tuple(range(rho_dim, rho_dim + num_ch_wires)) + new_col_indices = tuple(range(rho_dim + num_ch_wires, rho_dim + 2 * num_ch_wires)) + + # index for summation over Kraus operators + kraus_index = (rho_dim + 2 * num_ch_wires,) + + # apply mapping function + op_1_indices = (...,) + kraus_index + new_row_indices + row_indices + op_2_indices = (...,) + kraus_index + col_indices + new_col_indices + + new_state_indices = get_new_state_einsum_indices( + old_indices=col_indices + row_indices, + new_indices=new_col_indices + new_row_indices, + state_indices=state_indices_list, + ) + # index mapping for einsum, e.g., (...0,1,2,3), (...0,1,2,3), (...0,1,2,3), (...0,1,2,3) + return op_1_indices, state_indices, op_2_indices, new_state_indices + + +def get_new_state_einsum_indices(old_indices, new_indices, state_indices): + """Retrieves the einsum indices string for the new state + + Args: + old_indices tuple(int): indices that are summed + new_indices tuple(int): indices that must be replaced with sums + state_indices tuple(int): indices of the original state + + Returns: + tuple(int): The einsum indices of the new state + """ + for old, new in zip(old_indices, new_indices): + state_indices[old] = new + return (...,) + tuple(state_indices) + + +# @jax.jit +def apply_operation_einsum(kraus, swap_inds, state, qudit_dim, num_wires): + state = swap_axes(state, *swap_inds) + + # Shape kraus operators + kraus_shape = [len(kraus)] + ([qudit_dim] * num_wires * 2) + + kraus = jnp.stack(kraus) + kraus_dagger = jnp.conj(jnp.stack(jnp.moveaxis(kraus, source=-1, destination=-2))) + + kraus = jnp.reshape(kraus, kraus_shape) + kraus_dagger = jnp.reshape(kraus_dagger, kraus_shape) + op_1_indices, state_indices, op_2_indices, new_state_indices = get_einsum_mapping( + list(range(num_wires)), state + ) + state = jnp.einsum( + kraus, op_1_indices, state, state_indices, kraus_dagger, op_2_indices, new_state_indices + ) + return swap_axes(state, *swap_inds, reverse=True) + + +@partial(jax.jit, static_argnames=["qudit_dim"]) +def apply_single_qudit_operation(kraus, wire, state, qudit_dim): + num_wires = state.ndim // 2 + swap_inds = (wire, wire + num_wires), (0, num_wires) + return apply_operation_einsum(kraus, swap_inds, state, qudit_dim, 1) + + +@partial(jax.jit, static_argnames=["qudit_dim"]) +def apply_two_qudit_operation(kraus, wires, state, qudit_dim): + num_wires = state.ndim // 2 + + # wire_choice = (wires[0] == 1 * wires[1] == 0) + 2 * (wires[0] == 1 * wires[1] != 0) + 3 * (wires[0] != 1 * wires[1] == 0) + @jax.jit + def apply_two_qudit_regular(): + start = (wires[0], wires[0] + num_wires, wires[1], wires[1] + num_wires) + fin = jax.lax.cond( + wires[0] < wires[1], + lambda: (0, num_wires, 1, 1 + num_wires), + lambda: (1, 1 + num_wires, 0, num_wires), + ) + return apply_operation_einsum(kraus, (start, fin), state, qudit_dim, 2) + + @jax.jit + def apply_two_qudit_10(): + start = (1, 1 + num_wires) + fin = (0, num_wires) + return apply_operation_einsum(kraus, (start, fin), state, qudit_dim, 2) + + @jax.jit + def apply_two_qudit_1x(): + start = (1, 1 + num_wires, wires[1], wires[1] + num_wires) + fin = (0, num_wires, 1, 1 + num_wires) + return apply_operation_einsum(kraus, (start, fin), state, qudit_dim, 2) + + @jax.jit + def apply_two_qudit_x0(): + start = (0, num_wires, wires[0], wires[0] + num_wires) + fin = (1, 1 + num_wires, 0, num_wires) + return apply_operation_einsum(kraus, (start, fin), state, qudit_dim, 2) + + return jax.lax.cond( + wires[0] == 1, + lambda w1: jax.lax.cond(w1 == 0, apply_two_qudit_10, apply_two_qudit_1x), + lambda w1: jax.lax.cond(w1 == 0, apply_two_qudit_x0, apply_two_qudit_regular), + wires[1], + ) + + +single_qubit_ops = [ + qml.RX.compute_matrix, + qml.RY.compute_matrix, + qml.RZ.compute_matrix, + lambda _param: jnp.complex128(qml.Hadamard.compute_matrix()), +] + + +def get_qubit_TRX_matrix_func(subspace): + def qubit_TRX_matrix(theta): + c = qml.math.cos(theta / 2) + s = qml.math.sin(theta / 2) + + # The following avoids casting an imaginary quantity to reals when backpropagating + c = (1 + 0j) * c + js = -1j * s + one = qml.math.ones_like(c) + z = qml.math.zeros_like(c) + + diags = [one, one, one] + diags[subspace[0]] = c + diags[subspace[1]] = c + + off_diags = [z, z, z] + off_diags[qml.math.sum(subspace) - 1] = js + + return qml.math.stack( + [ + stack_last([diags[0], off_diags[0], off_diags[1], z]), + stack_last([off_diags[0], diags[1], off_diags[2], z]), + stack_last([off_diags[1], off_diags[2], diags[2], z]), + stack_last([z, z, z, one]), + ], + axis=-2, + ) + + return qubit_TRX_matrix + + +def get_qubit_TRY_matrix_func(subspace): + def qubit_TRY_matrix(theta): + c = qml.math.cos(theta / 2) + s = qml.math.sin(theta / 2) + + # The following avoids casting an imaginary quantity to reals when backpropagating + c = (1 + 0j) * c + s = (1 + 0j) * s + one = qml.math.ones_like(c) + z = qml.math.zeros_like(c) + + diags = [one, one, one] + diags[subspace[0]] = c + diags[subspace[1]] = c + + off_diags = [z, z, z] + off_diags[qml.math.sum(subspace) - 1] = s + + return qml.math.stack( + [ + stack_last([diags[0], -off_diags[0], -off_diags[1], z]), + stack_last([off_diags[0], diags[1], -off_diags[2], z]), + stack_last([off_diags[1], off_diags[2], diags[2], z]), + stack_last([z, z, z, one]), + ], + axis=-2, + ) + + return qubit_TRY_matrix + + +def get_qubit_TRZ_matrix_func(subspace): + def qubit_TRZ_matrix(theta): + p = qml.math.exp(-1j * theta / 2) + one = qml.math.ones_like(p) + z = qml.math.zeros_like(p) + + diags = [one, one, one] + diags[subspace[0]] = p + diags[subspace[1]] = qml.math.conj(p) + + return qml.math.stack( + [ + stack_last([diags[0], z, z, z]), + stack_last([z, diags[1], z, z]), + stack_last([z, z, diags[2], z]), + stack_last([z, z, z, one]), + ], + axis=-2, + ) + + return qubit_TRZ_matrix + + +OMEGA = np.exp(2 * np.pi * 1j / 3) + +two_qubit_ops = [ + lambda _param: jnp.complex128(qml.CNOT.compute_matrix()), + lambda _param: (-1j / np.sqrt(3)) + * np.array([[1, 1, 1, 0], [1, OMEGA, OMEGA**2, 0], [1, OMEGA**2, OMEGA, 0], [0, 0, 0, 1]]), + get_qubit_TRX_matrix_func([0, 1]), + get_qubit_TRY_matrix_func([0, 1]), + get_qubit_TRZ_matrix_func([0, 1]), + get_qubit_TRX_matrix_func([0, 2]), + get_qubit_TRY_matrix_func([0, 2]), + get_qubit_TRZ_matrix_func([0, 2]), + get_qubit_TRX_matrix_func([1, 2]), + get_qubit_TRY_matrix_func([1, 2]), + get_qubit_TRZ_matrix_func([1, 2]), +] + + +@jax.jit +def get_qutrit_op_as_qubits(param, op_type): + new_mat = jnp.eye(4, dtype=jnp.complex128) + return new_mat.at[:3, :3].set(jax.lax.switch(op_type - 1, qubits_qutrit_ops, param)) + + +@jax.jit +def apply_single_qubit_unitary(state, op_info): + wire, param = op_info["wires"][0], op_info["param"] + kraus_mat = [jax.lax.switch(op_info["type_indices"][1], single_qubit_ops, param)] + return apply_single_qudit_operation(kraus_mat, wire, state, 2) + + +@jax.jit +def apply_two_qubit_unitary(state, op_info): + wires, param = op_info["wires"], op_info["param"] + op_type = op_info["type_indices"][1] + kraus_mats = [jax.lax.switch(op_type, two_qubit_ops, param)] + return apply_two_qudit_operation(kraus_mats, wires, state, 2) + + +@jax.jit +def apply_qubit_depolarizing_channel(state, op_info): + wire, param = op_info["wires"][0], op_info["param"] + kraus_mats = qml.DepolarizingChannel.compute_kraus_matrices(param) + return apply_single_qudit_operation(kraus_mats, wire, state, 2) + + +@jax.jit +def apply_qubit_flipping_channel(state, op_info): + wire, param = op_info["wires"][0], op_info["param"] + kraus_mats = jax.lax.cond( + op_info["type_indices"][1] == 1, + qml.AmplitudeDamping.compute_kraus_matrices, + qml.BitFlip.compute_kraus_matrices, + param, + ) + return apply_single_qudit_operation(kraus_mats, wire, state, 2) + + +@jax.jit +def apply_single_qubit_channel(state, op_info): + return jax.lax.cond( + op_info["type_indices"][1] == 0, + apply_qubit_depolarizing_channel, + apply_qubit_flipping_channel, + state, + op_info, + ) + + +qubit_branches = [apply_single_qubit_unitary, apply_single_qubit_channel, apply_two_qubit_unitary] + +single_qutrit_ops_subspace_01 = [ + lambda _param: qml.THadamard.compute_matrix(subspace=[0, 1]), + qml.TRX.compute_matrix, + qml.TRY.compute_matrix, + qml.TRZ.compute_matrix, +] +single_qutrit_ops_subspace_02 = [ + lambda _param: qml.THadamard.compute_matrix(subspace=[0, 2]), + partial(qml.TRX.compute_matrix, subspace=[0, 2]), + partial(qml.TRY.compute_matrix, subspace=[0, 2]), + partial(qml.TRZ.compute_matrix, subspace=[0, 2]), +] +single_qutrit_ops_subspace_12 = [ + lambda _param: qml.THadamard.compute_matrix(subspace=[1, 2]), + partial(qml.TRX.compute_matrix, subspace=[1, 2]), + partial(qml.TRY.compute_matrix, subspace=[1, 2]), + partial(qml.TRZ.compute_matrix, subspace=[1, 2]), +] +single_qutrit_ops = [ + lambda _op_type, _param: qml.THadamard.compute_matrix(), + lambda op_type, param: jax.lax.switch(op_type, single_qutrit_ops_subspace_01, param), + lambda op_type, param: jax.lax.switch(op_type, single_qutrit_ops_subspace_02, param), + lambda op_type, param: jax.lax.switch(op_type, single_qutrit_ops_subspace_12, param), +] + +two_qutrits_ops = [qml.TAdd.compute_matrix, lambda: jnp.conj(qml.TAdd.compute_matrix().T)] + + +@jax.jit +def apply_single_qutrit_unitary(state, op_info): + wire, param = op_info["wires"][0], op_info["params"][0] + subspace_index, op_type = op_info["wires"][1], op_info["type_indices"][1] + kraus_mats = [jax.lax.switch(subspace_index, single_qutrit_ops, op_type, param)] + return apply_single_qudit_operation(kraus_mats, wire, state, 3) + + +@jax.jit +def apply_two_qutrit_unitary(state, op_info): + wires = op_info["wires"] + kraus_mats = [jax.lax.switch(op_info["type_indices"][1], two_qutrits_ops)] + return apply_two_qudit_operation(kraus_mats, wires, state, 3) + + +@jax.jit +def apply_qutrit_depolarizing_channel(state, op_info): + wire, param = op_info["wires"][0], op_info["params"][0] + kraus_mats = qml.QutritDepolarizingChannel.compute_kraus_matrices(param) + return apply_single_qudit_operation(kraus_mats, wire, state, 3) + + +@jax.jit +def apply_qutrit_subspace_channel(state, op_info): + wire, params = op_info["wires"][0], op_info["params"] + kraus_mats = jax.lax.cond( + op_info["type_indices"][1] == 1, + qml.QutritAmplitudeDamping.compute_kraus_matrices, + qml.TritFlip.compute_kraus_matrices, + *params + ) + return apply_single_qudit_operation(kraus_mats, wire, state, 3) + + +@jax.jit +def apply_single_qutrit_channel(state, op_info): + return jax.lax.cond( + op_info["type_indices"][1] == 0, + apply_qutrit_depolarizing_channel, + apply_qutrit_subspace_channel, + state, + op_info, + ) + + +qutrit_branches = [ + apply_single_qutrit_unitary, + apply_single_qutrit_channel, + apply_two_qutrit_unitary, +] diff --git a/pennylane/devices/qtcorgi_helper/einsum_mapping_old.py b/pennylane/devices/qtcorgi_helper/einsum_mapping_old.py new file mode 100644 index 00000000000..965499c8478 --- /dev/null +++ b/pennylane/devices/qtcorgi_helper/einsum_mapping_old.py @@ -0,0 +1,83 @@ +import time +import jax +import jax.numpy as jnp +from jax.lax import scan + +jax.config.update("jax_enable_x64", True) +jax.config.update("jax_platforms", "cpu") +import pennylane as qml +from string import ascii_letters as alphabet + +import numpy as np +from functools import partial, reduce + +alphabet_array = np.array(list(alphabet)) + + +def get_einsum_mapping(wires, state): + r"""Finds the indices for einsum to apply kraus operators to a mixed state + + Args: + wires + state (array[complex]): Input quantum state + + Returns: + str: Indices mapping that defines the einsum + """ + num_ch_wires = len(wires) + num_wires = int(len(qml.math.shape(state)) / 2) + rho_dim = 2 * num_wires + + # Tensor indices of the state. For each qutrit, need an index for rows *and* columns + state_indices = alphabet[:rho_dim] + + # row indices of the quantum state affected by this operation + row_wires_list = wires + row_indices = "".join(alphabet_array[row_wires_list].tolist()) + + # column indices are shifted by the number of wires + col_wires_list = [w + num_wires for w in row_wires_list] + col_indices = "".join(alphabet_array[col_wires_list].tolist()) + + # indices in einsum must be replaced with new ones + new_row_indices = alphabet[rho_dim : rho_dim + num_ch_wires] + new_col_indices = alphabet[rho_dim + num_ch_wires : rho_dim + 2 * num_ch_wires] + + # index for summation over Kraus operators + kraus_index = alphabet[rho_dim + 2 * num_ch_wires : rho_dim + 2 * num_ch_wires + 1] + print(kraus_index) + + # apply mapping function + op_1_indices = f"{kraus_index}{new_row_indices}{row_indices}" + op_2_indices = f"{kraus_index}{col_indices}{new_col_indices}" + + new_state_indices = get_new_state_einsum_indices( + old_indices=col_indices + row_indices, + new_indices=new_col_indices + new_row_indices, + state_indices=state_indices, + ) + # index mapping for einsum, e.g., '...iga,...abcdef,...idh->...gbchef' + return f"...{op_1_indices},...{state_indices},...{op_2_indices}->...{new_state_indices}" + + +def get_new_state_einsum_indices(old_indices, new_indices, state_indices): + """Retrieves the einsum indices string for the new state + + Args: + old_indices (str): indices that are summed + new_indices (str): indices that must be replaced with sums + state_indices (str): indices of the original state + + Returns: + str: The einsum indices of the new state + """ + return reduce( + lambda old_string, idx_pair: old_string.replace(idx_pair[0], idx_pair[1]), + zip(old_indices, new_indices), + state_indices, + ) + + +QUDIT_DIM = 3 + +print(get_einsum_mapping([0, 1], np.zeros((3, 3, 3, 3, 3, 3, 3, 3)))) diff --git a/pennylane/devices/qubit/sampling.py b/pennylane/devices/qubit/sampling.py index 2d87d3e504e..7c3f5f6a805 100644 --- a/pennylane/devices/qubit/sampling.py +++ b/pennylane/devices/qubit/sampling.py @@ -132,7 +132,6 @@ def _get_num_wire_groups_for_expval_H(obs): def _get_num_executions_for_sum(obs): - if obs.grouping_indices: return len(obs.grouping_indices) diff --git a/pennylane/devices/qutrit_mixed/__init__.py b/pennylane/devices/qutrit_mixed/__init__.py index 192b5a1b65a..e171789454e 100644 --- a/pennylane/devices/qutrit_mixed/__init__.py +++ b/pennylane/devices/qutrit_mixed/__init__.py @@ -34,4 +34,4 @@ from .initialize_state import create_initial_state from .measure import measure from .sampling import sample_state, measure_with_samples -from .simulate import simulate, get_final_state, measure_final_state +from .simulate import simulate, measure_final_state diff --git a/pennylane/devices/qutrit_mixed/simulate.py b/pennylane/devices/qutrit_mixed/simulate.py index 69fd4c4c8e4..5594870346b 100644 --- a/pennylane/devices/qutrit_mixed/simulate.py +++ b/pennylane/devices/qutrit_mixed/simulate.py @@ -18,11 +18,12 @@ import pennylane as qml from pennylane.typing import Result -from .apply_operation import apply_operation from .initialize_state import create_initial_state from .measure import measure from .sampling import measure_with_samples -from .utils import QUDIT_DIM +from ..qtcorgi_helper.apply_operations import qutrit_branches +import jax +import jax.numpy as jnp INTERFACE_TO_LIKE = { # map interfaces known by autoray to themselves @@ -45,57 +46,70 @@ } -def get_final_state(circuit, debugger=None, interface=None, **kwargs): +def get_qutrit_final_state_from_initial(operations, initial_state): """ - Get the final state that results from executing the given quantum script. - - This is an internal function that will be called by ``default.qutrit.mixed``. + TODO Args: - circuit (.QuantumScript): The single circuit to simulate - debugger (._Debugger): The debugger to use - interface (str): The machine learning interface to create the initial state with + operations ():TODO + initial_state ():TODO Returns: Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and whether the state has a batch dimension. """ - circuit = circuit.map_to_standard_wires() - - prep = None - if len(circuit) > 0 and isinstance(circuit[0], qml.operation.StatePrepBase): - prep = circuit[0] - - state = create_initial_state(sorted(circuit.op_wires), prep, like=INTERFACE_TO_LIKE[interface]) - - # initial state is batched only if the state preparation (if it exists) is batched - is_state_batched = bool(prep and prep.batch_size is not None) - for op in circuit.operations[bool(prep) :]: - state = apply_operation( - op, - state, - is_state_batched=is_state_batched, - debugger=debugger, - tape_shots=circuit.shots, - **kwargs, - ) - - # new state is batched if i) the old state is batched, or ii) the new op adds a batch dim - is_state_batched = is_state_batched or op.batch_size is not None - - num_operated_wires = len(circuit.op_wires) - for i in range(len(circuit.wires) - num_operated_wires): - # If any measured wires are not operated on, we pad the density matrix with zeros. - # We know they belong at the end because the circuit is in standard wire-order - # Since it is a dm, we must pad it with 0s on the last row and last column - current_axis = num_operated_wires + i + is_state_batched - state = qml.math.stack( - ([state] + [qml.math.zeros_like(state)] * (QUDIT_DIM - 1)), axis=current_axis - ) - state = qml.math.stack(([state] + [qml.math.zeros_like(state)] * (QUDIT_DIM - 1)), axis=-1) - - return state, is_state_batched + ops_type_indices, ops_subspace, ops_wires, ops_params = [[], []], [], [[], []], [[], [], []] + two_qutrit_ops = False + for op in operations: + wires = op.wires + + if isinstance(op, qml.operation.Channel): + ops_type_indices[0].append(1) + ops_type_indices[1].append( + [qml.QutritDepolarizingChannel, qml.QutritAmplitudeDamping, qml.TritFlip].index( + type(op) + ) + ) + params = op.parameters + ([0] * (3 - op.num_params)) + wires = [wires[0], -1] + elif len(wires) == 1: + ops_type_indices[0].append(0) + ops_type_indices[1].append([qml.THadamard, qml.TRX, qml.TRY, qml.TRZ].index(type(op))) + subspace_index = [None, (0, 1), (0, 2), (1, 2)].index(op.subspace) + if ops_type_indices[1][-1] == 0: + params = [0.0, 0.0, 0.0] + else: + params = list(op.parameters) + [0.0, 0.0] + wires = [wires[0], subspace_index] + elif len(wires) == 2: + ops_type_indices[0].append(2) + ops_type_indices[1].append(0 if isinstance(op, qml.TAdd) else 1) + params = [0.0, 0.0, 0.0] + two_qutrit_ops = True + else: + raise ValueError("TODO") + ops_params[0].append(params[0]) + ops_params[1].append(params[1]) + ops_params[2].append(params[2]) + + if len(wires) == 1: + wires = [wires[0], -1] + ops_wires[0].append(wires[0]) + ops_wires[1].append(wires[1]) + + ops_info = { + "type_indices": jnp.array(ops_type_indices).T, + "wires": [jnp.array(ops_wires[0]), jnp.array(ops_wires[1])], + "params": [jnp.array(ops_params[0]), jnp.array(ops_params[1]), jnp.array(ops_params[2])], + } + branches = qutrit_branches[: 2 + two_qutrit_ops] + + @jax.jit + def switch_function(state, op_info): + return jax.lax.switch(op_info["type_indices"][0], branches, state, op_info), None + + return jax.lax.scan(switch_function, initial_state, ops_info)[0] def measure_final_state( # pylint: disable=too-many-arguments @@ -157,6 +171,29 @@ def measure_final_state( # pylint: disable=too-many-arguments return results +def get_final_state_qutrit(circuit, **kwargs): + """ + TODO + + Args: + circuit (.QuantumScript): The single circuit to simulate + + Returns: + Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and + whether the state has a batch dimension. + + """ + + circuit = circuit.map_to_standard_wires() + + prep = None + if len(circuit) > 0 and isinstance(circuit[0], qml.operation.StatePrepBase): + prep = circuit[0] + + state = jnp.complex128(create_initial_state(sorted(circuit.op_wires), prep, like="jax")) + return get_qutrit_final_state_from_initial(circuit.operations[bool(prep) :], state) + + def simulate( # pylint: disable=too-many-arguments circuit: qml.tape.QuantumScript, rng=None, @@ -195,14 +232,9 @@ def simulate( # pylint: disable=too-many-arguments tensor([0.68117888, 0. , 0. , 0.31882112, 0. , 0. ], requires_grad=True)) """ - state, is_state_batched = get_final_state( + state = get_final_state_qutrit( circuit, debugger=debugger, interface=interface, rng=rng, prng_key=prng_key ) return measure_final_state( - circuit, - state, - is_state_batched, - rng=rng, - prng_key=prng_key, - readout_errors=readout_errors, + circuit, state, False, rng=rng, prng_key=prng_key, readout_errors=readout_errors )