diff --git a/pennylane/devices/qtcorgi_helper/apply_operations.py b/pennylane/devices/qtcorgi_helper/apply_operations.py index 71c9a42edf5..7a16f054efa 100644 --- a/pennylane/devices/qtcorgi_helper/apply_operations.py +++ b/pennylane/devices/qtcorgi_helper/apply_operations.py @@ -134,7 +134,7 @@ def get_CNOT_matrix(_param): def apply_single_qubit_unitary(state, op_info): - wires, param = op_info["wires"][:0], op_info["params"][0] + wires, param = op_info["wires"][:1], op_info["params"][0] kraus_mat = jax.lax.switch(op_info["type_indices"][1], single_qubit_ops, param) return apply_operation_einsum(kraus_mat, wires, state) @@ -146,37 +146,36 @@ def apply_two_qubit_unitary(state, op_info): def apply_single_qubit_channel(state, op_info): - wires, param = op_info["wires"][:0], op_info["params"][0] + wires, param = op_info["wires"][:1], op_info["params"][0] kraus_mats = [jax.lax.switch(op_info["type_indices"][1], single_qubit_channels, param)] return apply_operation_einsum(kraus_mats, wires, state) qubit_branches = [apply_single_qubit_unitary, apply_two_qubit_unitary, apply_single_qubit_channel] -single_qutrit_ops_subspace_0 = [None, None, None, lambda _param: qml.THadamard.compute_matrix()] -single_qutrit_ops_subspace_1 = [ +single_qutrit_ops_subspace_01 = [ qml.TRX.compute_matrix, qml.TRY.compute_matrix, qml.TRZ.compute_matrix, - partial(qml.THadamard.compute_matrix, subspace=[0, 1]), + lambda _param: qml.THadamard.compute_matrix(subspace=[0, 1]), ] -single_qutrit_ops_subspace_2 = [ +single_qutrit_ops_subspace_02 = [ partial(qml.TRX.compute_matrix, subspace=[0, 2]), partial(qml.TRY.compute_matrix, subspace=[0, 2]), partial(qml.TRZ.compute_matrix, subspace=[0, 2]), - partial(qml.THadamard.compute_matrix, subspace=[0, 2]), + lambda _param: qml.THadamard.compute_matrix(subspace=[0, 2]), ] -single_qutrit_ops_subspace_3 = [ +single_qutrit_ops_subspace_12 = [ partial(qml.TRX.compute_matrix, subspace=[1, 2]), partial(qml.TRY.compute_matrix, subspace=[1, 2]), partial(qml.TRZ.compute_matrix, subspace=[1, 2]), - partial(qml.THadamard.compute_matrix, subspace=[0, 2]), + lambda _param: qml.THadamard.compute_matrix(subspace=[0, 2]), ] single_qutrit_ops = [ - single_qutrit_ops_subspace_0, - single_qutrit_ops_subspace_1, - single_qutrit_ops_subspace_2, - single_qutrit_ops_subspace_3, + lambda _op_type, _param: qml.THadamard.compute_matrix(), + lambda op_type, param: jax.lax.switch(op_type, single_qutrit_ops_subspace_01, param), + lambda op_type, param: jax.lax.switch(op_type, single_qutrit_ops_subspace_02, param), + lambda op_type, param: jax.lax.switch(op_type, single_qutrit_ops_subspace_12, param), ] two_qutrits_ops = [qml.TAdd.compute_matrix, qml.adjoint(qml.TAdd).compute_matrix] @@ -189,9 +188,9 @@ def apply_single_qubit_channel(state, op_info): def apply_single_qutrit_unitary(state, op_info): - wires, param, subspace_index = op_info["wires"][:0], op_info["params"][0], op_info["params"][1] - mat_funcs = jax.lax.switch(subspace_index, single_qutrit_ops, param) - kraus_mats = [jax.lax.switch(op_info["type_indices"][1], mat_funcs, param)] + wires, param = op_info["wires"][:1], op_info["params"][0] + subspace_index, op_type = op_info["params"][1], op_info["type_indices"][1] + kraus_mats = [jax.lax.switch(subspace_index, single_qutrit_ops, op_type, param)] return apply_operation_einsum(kraus_mats, wires, state) @@ -202,7 +201,7 @@ def apply_two_qutrit_unitary(state, op_info): def apply_single_qutrit_channel(state, op_info): - wires, params = op_info["wires"][:0], op_info["params"] # TODO qutrit channels take 3 params + wires, params = op_info["wires"][:1], op_info["params"] # TODO qutrit channels take 3 params kraus_mats = [jax.lax.switch(op_info["type_indices"][1], single_qutrit_channels, *params)] return apply_operation_einsum(kraus_mats, wires, state) diff --git a/pennylane/devices/qutrit_mixed/simulate.py b/pennylane/devices/qutrit_mixed/simulate.py index edbf74fb6df..3fcabca670d 100644 --- a/pennylane/devices/qutrit_mixed/simulate.py +++ b/pennylane/devices/qutrit_mixed/simulate.py @@ -74,7 +74,7 @@ def get_qutrit_final_state_from_initial(operations, initial_state): elif len(wires) == 1: ops_type_indices[0].append(0) ops_type_indices[1].append([qml.TRX, qml.TRY, qml.TRZ, qml.THadamard].index(type(op))) - subspace_index = op.subspace.index([None, (0, 1), (0, 2), (1, 2)]) + subspace_index = [None, (0, 1), (0, 2), (1, 2)].index(op.subspace) if ops_type_indices[1][-1] == 3: params = [0, subspace_index, 0] else: