Skip to content

Commit

Permalink
Got to einsum, not working...
Browse files Browse the repository at this point in the history
  • Loading branch information
Gabriel-Bottrill committed Jul 16, 2024
1 parent 2cb5985 commit aac0ed8
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 18 deletions.
33 changes: 16 additions & 17 deletions pennylane/devices/qtcorgi_helper/apply_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def get_CNOT_matrix(_param):


def apply_single_qubit_unitary(state, op_info):
wires, param = op_info["wires"][:0], op_info["params"][0]
wires, param = op_info["wires"][:1], op_info["params"][0]
kraus_mat = jax.lax.switch(op_info["type_indices"][1], single_qubit_ops, param)
return apply_operation_einsum(kraus_mat, wires, state)

Expand All @@ -146,37 +146,36 @@ def apply_two_qubit_unitary(state, op_info):


def apply_single_qubit_channel(state, op_info):
wires, param = op_info["wires"][:0], op_info["params"][0]
wires, param = op_info["wires"][:1], op_info["params"][0]
kraus_mats = [jax.lax.switch(op_info["type_indices"][1], single_qubit_channels, param)]
return apply_operation_einsum(kraus_mats, wires, state)


qubit_branches = [apply_single_qubit_unitary, apply_two_qubit_unitary, apply_single_qubit_channel]

single_qutrit_ops_subspace_0 = [None, None, None, lambda _param: qml.THadamard.compute_matrix()]
single_qutrit_ops_subspace_1 = [
single_qutrit_ops_subspace_01 = [
qml.TRX.compute_matrix,
qml.TRY.compute_matrix,
qml.TRZ.compute_matrix,
partial(qml.THadamard.compute_matrix, subspace=[0, 1]),
lambda _param: qml.THadamard.compute_matrix(subspace=[0, 1]),
]
single_qutrit_ops_subspace_2 = [
single_qutrit_ops_subspace_02 = [
partial(qml.TRX.compute_matrix, subspace=[0, 2]),
partial(qml.TRY.compute_matrix, subspace=[0, 2]),
partial(qml.TRZ.compute_matrix, subspace=[0, 2]),
partial(qml.THadamard.compute_matrix, subspace=[0, 2]),
lambda _param: qml.THadamard.compute_matrix(subspace=[0, 2]),
]
single_qutrit_ops_subspace_3 = [
single_qutrit_ops_subspace_12 = [
partial(qml.TRX.compute_matrix, subspace=[1, 2]),
partial(qml.TRY.compute_matrix, subspace=[1, 2]),
partial(qml.TRZ.compute_matrix, subspace=[1, 2]),
partial(qml.THadamard.compute_matrix, subspace=[0, 2]),
lambda _param: qml.THadamard.compute_matrix(subspace=[0, 2]),
]
single_qutrit_ops = [
single_qutrit_ops_subspace_0,
single_qutrit_ops_subspace_1,
single_qutrit_ops_subspace_2,
single_qutrit_ops_subspace_3,
lambda _op_type, _param: qml.THadamard.compute_matrix(),
lambda op_type, param: jax.lax.switch(op_type, single_qutrit_ops_subspace_01, param),
lambda op_type, param: jax.lax.switch(op_type, single_qutrit_ops_subspace_02, param),
lambda op_type, param: jax.lax.switch(op_type, single_qutrit_ops_subspace_12, param),
]

two_qutrits_ops = [qml.TAdd.compute_matrix, qml.adjoint(qml.TAdd).compute_matrix]
Expand All @@ -189,9 +188,9 @@ def apply_single_qubit_channel(state, op_info):


def apply_single_qutrit_unitary(state, op_info):
wires, param, subspace_index = op_info["wires"][:0], op_info["params"][0], op_info["params"][1]
mat_funcs = jax.lax.switch(subspace_index, single_qutrit_ops, param)
kraus_mats = [jax.lax.switch(op_info["type_indices"][1], mat_funcs, param)]
wires, param = op_info["wires"][:1], op_info["params"][0]
subspace_index, op_type = op_info["params"][1], op_info["type_indices"][1]
kraus_mats = [jax.lax.switch(subspace_index, single_qutrit_ops, op_type, param)]
return apply_operation_einsum(kraus_mats, wires, state)


Expand All @@ -202,7 +201,7 @@ def apply_two_qutrit_unitary(state, op_info):


def apply_single_qutrit_channel(state, op_info):
wires, params = op_info["wires"][:0], op_info["params"] # TODO qutrit channels take 3 params
wires, params = op_info["wires"][:1], op_info["params"] # TODO qutrit channels take 3 params
kraus_mats = [jax.lax.switch(op_info["type_indices"][1], single_qutrit_channels, *params)]
return apply_operation_einsum(kraus_mats, wires, state)

Expand Down
2 changes: 1 addition & 1 deletion pennylane/devices/qutrit_mixed/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def get_qutrit_final_state_from_initial(operations, initial_state):
elif len(wires) == 1:
ops_type_indices[0].append(0)
ops_type_indices[1].append([qml.TRX, qml.TRY, qml.TRZ, qml.THadamard].index(type(op)))
subspace_index = op.subspace.index([None, (0, 1), (0, 2), (1, 2)])
subspace_index = [None, (0, 1), (0, 2), (1, 2)].index(op.subspace)
if ops_type_indices[1][-1] == 3:
params = [0, subspace_index, 0]
else:
Expand Down

0 comments on commit aac0ed8

Please sign in to comment.