Skip to content

Commit

Permalink
Changed einsum indices to not index via scan abstract values
Browse files Browse the repository at this point in the history
  • Loading branch information
Gabriel-Bottrill committed Jul 16, 2024
1 parent c0f5a58 commit d969cd6
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 20 deletions.
39 changes: 22 additions & 17 deletions pennylane/devices/qtcorgi_helper/apply_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
alphabet_array = np.array(list(alphabet))


def get_einsum_mapping_one_wire(wires, state, ):
def get_einsum_mapping(wires, state):
r"""Finds the indices for einsum to apply kraus operators to a mixed state
Args:
Expand All @@ -29,13 +29,16 @@ def get_einsum_mapping_one_wire(wires, state, ):
rho_dim = 2 * num_wires

# Tensor indices of the state. For each qutrit, need an index for rows *and* columns
state_indices = (..., ) + tuple(range(rho_dim)) # TODO this may need to be an input from above
# TODO this may need to be an input from above
state_indices_list = list(range(rho_dim))
state_indices = (...,) + tuple(state_indices_list)

# row indices of the quantum state affected by this operation
row_indices = tuple(wires)

# column indices are shifted by the number of wires
col_indices = tuple(w + num_wires for w in wires) # TODO replace
# TODO replace
col_indices = tuple(w + num_wires for w in wires) # TODO: Should I do an array?

# indices in einsum must be replaced with new ones
new_row_indices = tuple(range(rho_dim, rho_dim + num_ch_wires))
Expand All @@ -48,10 +51,10 @@ def get_einsum_mapping_one_wire(wires, state, ):
op_1_indices = (...,) + kraus_index + new_row_indices + row_indices
op_2_indices = (...,) + kraus_index + col_indices + new_col_indices

new_state_indices = (...,) + get_new_state_einsum_indices(
new_state_indices = get_new_state_einsum_indices(
old_indices=col_indices + row_indices,
new_indices=new_col_indices + new_row_indices,
state_indices=state_indices,
state_indices=state_indices_list,
)
# index mapping for einsum, e.g., (...0,1,2,3), (...0,1,2,3), (...0,1,2,3), (...0,1,2,3)
return op_1_indices, state_indices, op_2_indices, new_state_indices
Expand All @@ -68,28 +71,28 @@ def get_new_state_einsum_indices(old_indices, new_indices, state_indices):
Returns:
tuple(int): The einsum indices of the new state
"""
# for old, new in zip(old_indices, new_indices):
# for i in old_indices:
# old_indices[i] = jax.lax.cond()
return reduce( # TODO, redo
lambda old_indices, idx_pair: old_indices[idx_pair[0]],
zip(old_indices, new_indices),
state_indices,
)
for old, new in zip(old_indices, new_indices):
for i, state_index in enumerate(state_indices): # TODO replace with jax.lax.scan
state_indices[i] = jax.lax.cond(old == state_index, lambda: new, lambda: state_index)
return (...,) + tuple(state_indices)
# return reduce( # TODO, redo
# lambda old_indices, idx_pair: old_indices[idx_pair[0]],
# zip(old_indices, new_indices),
# state_indices,
# )


QUDIT_DIM = 3


def apply_operation_einsum(kraus, wires, state, mapping_indices):
def apply_operation_einsum(kraus, wires, state):
r"""Apply a quantum channel specified by a list of Kraus operators to subsystems of the
quantum state. For a unitary gate, there is a single Kraus operator.
Args:
kraus (??): TODO
wires
state (array[complex]): Input quantum state
mapping_indices
Returns:
array[complex]: output_state
Expand All @@ -107,15 +110,17 @@ def apply_operation_einsum(kraus, wires, state, mapping_indices):
kraus = jnp.reshape(kraus, kraus_shape)
kraus_dagger = jnp.reshape(kraus_dagger, kraus_shape)

return jnp.einsum(kraus, op_1_indices, state, state_indices, kraus_dagger, op_2_indices, new_state_indices)
return jnp.einsum(
kraus, op_1_indices, state, state_indices, kraus_dagger, op_2_indices, new_state_indices
)


def get_two_qubit_unitary_matrix(param):
# TODO
pass


def get_CNOT_matrix(param):
def get_CNOT_matrix(_param):
return jnp.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]])


Expand Down
2 changes: 1 addition & 1 deletion pennylane/devices/qtcorgi_helper/einsum_mapping_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,4 @@ def get_new_state_einsum_indices(old_indices, new_indices, state_indices):

QUDIT_DIM = 3

print(get_einsum_mapping([0,1], np.zeros((3,3,3,3,3,3,3,3))))
print(get_einsum_mapping([0, 1], np.zeros((3, 3, 3, 3, 3, 3, 3, 3))))
2 changes: 0 additions & 2 deletions pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def get_qubit_final_state_from_initial(operations, initial_state):
"""
ops_type_indices, ops_wires, ops_param = [[], []], [[], []], []
for op in operations:

wires = op.wires()

if isinstance(op, Channel):
Expand Down Expand Up @@ -74,7 +73,6 @@ def get_qutrit_final_state_from_initial(operations, initial_state):
"""
ops_type_indices, ops_subspace, ops_wires, ops_params = [[], []], [], [[], []], [[], [], []]
for op in operations:

wires = op.wires()

if isinstance(op, Channel):
Expand Down

0 comments on commit d969cd6

Please sign in to comment.