Skip to content

Commit

Permalink
Linked new jittable simulation to old device
Browse files Browse the repository at this point in the history
  • Loading branch information
Gabriel-Bottrill committed Jul 3, 2024
1 parent 2670761 commit f1679f4
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 179 deletions.
15 changes: 12 additions & 3 deletions pennylane/devices/default_mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from pennylane.wires import Wires

from .._version import __version__
from .qtcorgi_helper.qtcorgi_simulator import get_qubit_final_state_from_initial

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
Expand Down Expand Up @@ -781,9 +782,17 @@ def apply(self, operations, rotations=None, **kwargs):
f"on a {self.short_name} device."
)

for operation in operations:
self._apply_operation(operation)

prep = False
if len(operations) > 0:
if (
isinstance(operations[0], StatePrep)
or isinstance(operations[0], BasisState)
or isinstance(operations[0], QubitDensityMatrix)
):
prep = True
self._apply_operation(operation)

self._state = get_qubit_final_state_from_initial(operations[prep:], self._state)
# store the pre-rotated state
self._pre_rotated_state = self._state

Expand Down
75 changes: 36 additions & 39 deletions pennylane/devices/qtcorgi_helper/apply_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
import jax
import jax.numpy as jnp
from jax.lax import scan

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platforms", "cpu")
import pennylane as qml
from string import ascii_letters as alphabet

import numpy as np

alphabet_array = np.array(list(alphabet))


def get_einsum_mapping(wires, state):
r"""Finds the indices for einsum to apply kraus operators to a mixed state
Expand Down Expand Up @@ -52,9 +55,8 @@ def get_einsum_mapping(wires, state):
state_indices=state_indices,
)
# index mapping for einsum, e.g., '...iga,...abcdef,...idh->...gbchef'
return (
f"...{op_1_indices},...{state_indices},...{op_2_indices}->...{new_state_indices}"
)
return f"...{op_1_indices},...{state_indices},...{op_2_indices}->...{new_state_indices}"


def get_new_state_einsum_indices(old_indices, new_indices, state_indices):
"""Retrieves the einsum indices string for the new state
Expand All @@ -73,7 +75,10 @@ def get_new_state_einsum_indices(old_indices, new_indices, state_indices):
state_indices,
)


QUDIT_DIM = 3


def apply_operation_einsum(kraus, wires, state):
r"""Apply a quantum channel specified by a list of Kraus operators to subsystems of the
quantum state. For a unitary gate, there is a single Kraus operator.
Expand Down Expand Up @@ -108,15 +113,17 @@ def get_two_qubit_unitary_matrix():


def get_CNOT_matrix(params):
return jnp.array([[1,0,0,0],
[0,1,0,0],
[0,0,0,1],
[0,0,1,0]])
return jnp.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]])


single_qubit_ops = [qml.RX.compute_matrix, qml.RY.compute_matrix, qml.RZ.compute_matrix]
two_qubit_ops = [get_CNOT_matrix, get_two_qubit_unitary_matrix]
single_qubit_channels = [qml.DepolarizingChannel.compute_kraus_matrices, qml.AmplitudeDamping.compute_kraus_matrices, qml.BitFlip.compute_kraus_matrices]
single_qubit_channels = [
qml.DepolarizingChannel.compute_kraus_matrices,
qml.AmplitudeDamping.compute_kraus_matrices,
qml.BitFlip.compute_kraus_matrices,
]


def apply_single_qubit_unitary(state, op_info):
wire, param = op_info["wires"][0], op_info["params"][0]
Expand All @@ -136,6 +143,9 @@ def apply_single_qubit_channel(state, op_info):
pass


qubit_branches = [apply_single_qubit_unitary, apply_two_qubit_unitary, apply_single_qubit_channel]


single_qutrit_ops = [qml.TRX.compute_matrix, qml.TRY.compute_matrix, qml.TRZ.compute_matrix]
single_qutrit_channels = [
lambda params: qml.QutritDepolarizingChannel.compute_kraus_matrices(params[0]),
Expand All @@ -152,15 +162,19 @@ def apply_single_qutrit_unitary(state, op_info):

def apply_two_qutrit_unitary(state, op_info):
wires = op_info["wires"]
kraus_mat = jnp.array([[1,0,0,0,0,0,0,0,0],
[0,1,0,0,0,0,0,0,0],
[0,0,1,0,0,0,0,0,0],
[0,0,0,0,0,1,0,0,0],
[0,0,0,1,0,0,0,0,0],
[0,0,0,0,1,0,0,0,0],
[0,0,0,0,0,0,0,1,0],
[0,0,0,0,0,0,0,0,1],
[0,0,0,0,0,0,1,0,0]])
kraus_mat = jnp.array(
[
[1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 1, 0, 0],
]
)
pass


Expand All @@ -170,25 +184,8 @@ def apply_single_qutrit_channel(state, op_info):
pass


def get_operation_applier(qudit_dim):
qubit_type_branches = [apply_single_qubit_unitary, apply_two_qubit_unitary,
apply_single_qubit_channel]
qutrit_type_branches = [apply_single_qutrit_unitary, apply_two_qutrit_unitary,
apply_single_qutrit_channel]
if qudit_dim == 2:
def operation_applier(state, op_info):
index_cutoffs = [0, 0, 0]
op_i = op_info["type_index"]
op_class = op_i // index_cutoffs[0] + op_i // index_cutoffs[1] + op_i // index_cutoffs[2]
return jax.lax.switch(op_class, qubit_type_branches, state, op_info), None
elif qudit_dim == 3:
def operation_applier(state, op_info):
index_cutoffs = [0, 0, 0]
op_i = op_info["type_index"]
op_class = op_i // index_cutoffs[0] + op_i // index_cutoffs[1] + op_i // index_cutoffs[2]
return jax.lax.switch(op_class, qutrit_type_branches, state, op_info), None
else:
raise ValueError("Only qubit and qutrit simulators are allowed")

return operation_applier

qutrit_branches = [
apply_single_qutrit_unitary,
apply_two_qutrit_unitary,
apply_single_qutrit_channel,
]
109 changes: 27 additions & 82 deletions pennylane/devices/qtcorgi_helper/qtcorgi_simulator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import functools
import jax
import jax.numpy as jnp
import pennylane as qml
from pennylane.tape import QuantumScript
from .apply_operations import get_operation_applier
from ..qutrit_mixed.simulate import measure_final_state
from .apply_operations import qubit_branches, qutrit_branches
from ..qutrit_mixed.initialize_state import create_initial_state

op_types = []
Expand All @@ -15,7 +13,7 @@ def get_qubit_final_state_from_initial(operations, initial_state):
TODO
Args:
circuit (.QuantumScript): The single circuit to simulate
TODO
Returns:
Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and
Expand All @@ -27,7 +25,6 @@ def get_qubit_final_state_from_initial(operations, initial_state):

wires = op.wires()


if isinstance(op, qml.operation.Channel):
ops_type_indices[0].append(2)
ops_type_indices[1].append([].index(type(op)))
Expand All @@ -38,44 +35,39 @@ def get_qubit_final_state_from_initial(operations, initial_state):
ops_type_indices[0].append(0)
ops_type_indices[1].append([].index(type(op)))



if len(wires) == 1:
wires = [wires[0], -1]
params = op.parameters + ([0] * (3 - op.num_params))
if len(wires) == 2:
ops_wires[0].append(wires[0])
ops_wires[0].append(wires[0])
ops_wires[1].append(wires[1])

ops_param[0].append(params[0])

op_index = op_types.index(type(op))
ops_type_index.append(op_index)



if qudit_dim == 2 and op_index <= 2:
ops_subspace.append([(0,1), (0,2), (1,2)].index(op.subspace))
else:
ops_subspace.append(0)

ops_info = {
"type_index": jnp.array(ops_type_index),
"wires": [jnp.array(ops_wires[0]), jnp.array(ops_wires[1])],
"params": [jnp.array(ops_param)]
"params": [jnp.array(ops_param)],
}

return jax.lax.scan(
lambda state, op_info: (
jax.lax.switch(op_info["branch"], qubit_branches, state, op_info),
None,
),
initial_state,
ops_info,
)[0]


return jax.lax.scan(get_operation_applier(qudit_dim), initial_state, ops_info)[0]


def get_qutrit_final_state_from_initial(operations, initial_state, qudit_dim):
def get_qutrit_final_state_from_initial(operations, initial_state):
"""
TODO
Args:
circuit (.QuantumScript): The single circuit to simulate
TODO
Returns:
Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and
Expand All @@ -84,59 +76,46 @@ def get_qutrit_final_state_from_initial(operations, initial_state, qudit_dim):
"""
ops_type_index, ops_subspace, ops_wires, ops_params = [], [], [[], []], [[], [], []]
for op in operations:

# op_index = None
# for i, op_type in enumerate(op_types):
# if isinstance(op, op_type):
# op_index = i
# if op_index is None:
# raise ValueError("This simulator only supports")

# op_index = op_types.index(type(op))
# wires = op.wires()
# if len(wires) == 1:
# wires = [-1, wires[0]]
# if len(wires) == 2:
# wires = list(wires)
# params = op.parameters + ([0] * (3-op.num_params))
# op_array.append([[op_index] + wires, params])

op_index = op_types.index(type(op))
ops_type_index.append(op_index)

wires = op.wires()
if len(wires) == 1:
wires = [wires[0], -1]
params = op.parameters + ([0] * (3 - op.num_params))
if len(wires) == 2:

ops_wires[0].append(wires[0])
ops_wires[1].append(wires[1])

ops_params[0].append(params[0])
ops_params[1].append(params[1])
ops_params[2].append(params[2])

if qudit_dim == 2 and op_index <= 2:
ops_subspace.append([(0,1), (0,2), (1,2)].index(op.subspace))
if op_index <= 2:
ops_subspace.append([(0, 1), (0, 2), (1, 2)].index(op.subspace))
else:
ops_subspace.append(0)

ops_info = {
"type_index": jnp.array(ops_type_index),
"wires": [jnp.array(ops_wires[0]), jnp.array(ops_wires[1])],
"params": [jnp.array(ops_params[0]), jnp.array(ops_params[1]), jnp.array(ops_params[2])]
"params": [jnp.array(ops_params[0]), jnp.array(ops_params[1]), jnp.array(ops_params[2])],
}
return jax.lax.scan(get_operation_applier(qudit_dim), initial_state, ops_info)[0]
op_branch = jnp.nan

return jax.lax.scan(
lambda state, op_info: (jax.lax.switch(op_info["branch"], qutrit_branches, state, x), None),
initial_state,
ops_info,
)[0]

def get_final_state(circuit):

def get_final_state_qutrit(circuit):
"""
TODO
Args:
circuit (.QuantumScript): The single circuit to simulate
qudit_dim (): TODO
interface (str): The machine learning interface to create the initial state with
Returns:
Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and
Expand All @@ -151,38 +130,4 @@ def get_final_state(circuit):
prep = circuit[0]

state = create_initial_state(sorted(circuit.op_wires), prep, like="jax")
get_final_state_from_initial(circuit.operations[bool(prep):], state, 3)





def simulate(circuit: QuantumScript, rng=None, prng_key=None):
"""TODO
Args:
circuit (QuantumTape): The single circuit to simulate
rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]): A
seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``.
If no value is provided, a default RNG will be used.
prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is
the key to the JAX pseudo random number generator. If None, a random key will be
generated. Only for simulation using JAX.
debugger (_Debugger): The debugger to use
interface (str): The machine learning interface to create the initial state with
Returns:
tuple(TensorLike): The results of the simulation
Note that this function can return measurements for non-commuting observables simultaneously.
This function assumes that all operations provide matrices.
>>> qs = qml.tape.QuantumScript([qml.TRX(1.2, wires=0)], [qml.expval(qml.GellMann(0, 3)), qml.probs(wires=(0,1))])
>>> simulate(qs)
(0.36235775447667357,
tensor([0.68117888, 0. , 0. , 0.31882112, 0. , 0. ], requires_grad=True))
"""
state, is_state_batched = get_final_state(circuit)
return measure_final_state(circuit, state, is_state_batched, rng=rng, prng_key=prng_key)
return get_qutrit_final_state_from_initial(circuit.operations[bool(prep) :], state), False
Loading

0 comments on commit f1679f4

Please sign in to comment.