diff --git a/pennylane/devices/qtcorgi_helper/apply_operations.py b/pennylane/devices/qtcorgi_helper/apply_operations.py index 47874dd4976..71c9a42edf5 100644 --- a/pennylane/devices/qtcorgi_helper/apply_operations.py +++ b/pennylane/devices/qtcorgi_helper/apply_operations.py @@ -153,16 +153,30 @@ def apply_single_qubit_channel(state, op_info): qubit_branches = [apply_single_qubit_unitary, apply_two_qubit_unitary, apply_single_qubit_channel] - -single_qutrit_ops = [ +single_qutrit_ops_subspace_0 = [None, None, None, lambda _param: qml.THadamard.compute_matrix()] +single_qutrit_ops_subspace_1 = [ qml.TRX.compute_matrix, qml.TRY.compute_matrix, qml.TRZ.compute_matrix, - lambda params: ( - qml.THadamard.compute_matrix() - if params[1] != 0 - else qml.THadamard.compute_matrix(subspace=params[1:]) - ), + partial(qml.THadamard.compute_matrix, subspace=[0, 1]), +] +single_qutrit_ops_subspace_2 = [ + partial(qml.TRX.compute_matrix, subspace=[0, 2]), + partial(qml.TRY.compute_matrix, subspace=[0, 2]), + partial(qml.TRZ.compute_matrix, subspace=[0, 2]), + partial(qml.THadamard.compute_matrix, subspace=[0, 2]), +] +single_qutrit_ops_subspace_3 = [ + partial(qml.TRX.compute_matrix, subspace=[1, 2]), + partial(qml.TRY.compute_matrix, subspace=[1, 2]), + partial(qml.TRZ.compute_matrix, subspace=[1, 2]), + partial(qml.THadamard.compute_matrix, subspace=[0, 2]), +] +single_qutrit_ops = [ + single_qutrit_ops_subspace_0, + single_qutrit_ops_subspace_1, + single_qutrit_ops_subspace_2, + single_qutrit_ops_subspace_3, ] two_qutrits_ops = [qml.TAdd.compute_matrix, qml.adjoint(qml.TAdd).compute_matrix] @@ -175,8 +189,9 @@ def apply_single_qubit_channel(state, op_info): def apply_single_qutrit_unitary(state, op_info): - wires, param = op_info["wires"][:0], op_info["params"][0] - kraus_mats = [jax.lax.switch(op_info["type_indices"][1], single_qutrit_ops, param)] + wires, param, subspace_index = op_info["wires"][:0], op_info["params"][0], op_info["params"][1] + mat_funcs = jax.lax.switch(subspace_index, single_qutrit_ops, param) + kraus_mats = [jax.lax.switch(op_info["type_indices"][1], mat_funcs, param)] return apply_operation_einsum(kraus_mats, wires, state) diff --git a/pennylane/devices/qutrit_mixed/simulate.py b/pennylane/devices/qutrit_mixed/simulate.py index 13811cff468..edbf74fb6df 100644 --- a/pennylane/devices/qutrit_mixed/simulate.py +++ b/pennylane/devices/qutrit_mixed/simulate.py @@ -61,8 +61,7 @@ def get_qutrit_final_state_from_initial(operations, initial_state): """ ops_type_indices, ops_subspace, ops_wires, ops_params = [[], []], [], [[], []], [[], [], []] for op in operations: - - wires = op.wires() + wires = op.wires if isinstance(op, qml.operation.Channel): ops_type_indices[0].append(2) @@ -75,10 +74,11 @@ def get_qutrit_final_state_from_initial(operations, initial_state): elif len(wires) == 1: ops_type_indices[0].append(0) ops_type_indices[1].append([qml.TRX, qml.TRY, qml.TRZ, qml.THadamard].index(type(op))) + subspace_index = op.subspace.index([None, (0, 1), (0, 2), (1, 2)]) if ops_type_indices[1][-1] == 3: - params = [0] + list(op.subspace) if op.subspace is not None else [0, 0] + params = [0, subspace_index, 0] else: - params = list(op.params) + list(op.subspace) + params = list(op.parameters) + [subspace_index, 0] elif len(wires) == 2: ops_type_indices[0].append(1) ops_type_indices[1].append(0) # Assume always TAdd