Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jittable mixed simulation #4

Draft
wants to merge 28 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
a0b1f17
Need to push quick
Gabriel-Bottrill Jun 27, 2024
89d9e4e
Worked on getting simulate setup, need to do apply gates
Gabriel-Bottrill Jun 30, 2024
91f2cf3
Working towards new devices
Gabriel-Bottrill Jul 2, 2024
2670761
Added logic to apply_operations and working on splitting qubit and qu…
Gabriel-Bottrill Jul 2, 2024
f1679f4
Linked new jittable simulation to old device
Gabriel-Bottrill Jul 3, 2024
a11206c
Fixed indexing and linked apply operation.
Gabriel-Bottrill Jul 5, 2024
f1c0749
Removed extra files
Gabriel-Bottrill Jul 5, 2024
432dc78
removed makefile to stop CI
Gabriel-Bottrill Jul 5, 2024
da2e63d
Removed commenting out Makefile
Gabriel-Bottrill Jul 5, 2024
9044145
Removed simulate file and added simulator to devices for more one to one
Gabriel-Bottrill Jul 5, 2024
223ed07
[ci skip] Added support for Adjoint of TAdd
Gabriel-Bottrill Jul 8, 2024
417d2b0
[ci skip] Merge changes for adjoint
Gabriel-Bottrill Jul 8, 2024
c0f5a58
Worked on apply operations einsum, I think that generating indexing n…
Gabriel-Bottrill Jul 15, 2024
d969cd6
Changed einsum indices to not index via scan abstract values
Gabriel-Bottrill Jul 16, 2024
2cb5985
Changed how different matrices are made
Gabriel-Bottrill Jul 16, 2024
aac0ed8
Got to einsum, not working...
Gabriel-Bottrill Jul 16, 2024
ebcaf57
Fixed jittability speed for qutrit need to finish qubit
Gabriel-Bottrill Jul 24, 2024
1cf95dd
Added logic for most qubit gates
Gabriel-Bottrill Jul 25, 2024
f25acc4
Got everything working
Gabriel-Bottrill Jul 26, 2024
9166ea1
Reformatted
Gabriel-Bottrill Jul 26, 2024
5c761ca
Added qubitunitary to operators
Gabriel-Bottrill Jul 26, 2024
cd332c0
Removed print
Gabriel-Bottrill Jul 26, 2024
73160f4
Removed unused function from simulate
Gabriel-Bottrill Jul 29, 2024
a16997f
Removed unused module qtcorgi simulator
Gabriel-Bottrill Jul 29, 2024
5e1229e
Added jitting to more functions
Gabriel-Bottrill Jul 29, 2024
85fe88c
Merged master back in
Gabriel-Bottrill Jul 30, 2024
9a0277d
Fixed index of THadamard param
Gabriel-Bottrill Jul 31, 2024
7a0653c
Changed method for dealing with TRX,TRY, and TRZ ops
Gabriel-Bottrill Jul 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 94 additions & 58 deletions pennylane/devices/default_mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,94 @@
from pennylane.wires import Wires

from .._version import __version__
import jax
import jax.numpy as jnp
from .qtcorgi_helper.apply_operations import qubit_branches


logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

ABC_ARRAY = np.array(list(ABC))
tolerance = 1e-10

two_gates = [
"THadamard",
"TRX_01",
"TRY_01",
"TRZ_01",
"TRX_02",
"TRY_02",
"TRZ_02",
"TRX_12",
"TRY_12",
"TRZ_12",
]


def get_qubit_final_state_from_initial(operations, initial_state):
"""
TODO

Args:
operations ():TODO
initial_state ():TODO

Returns:
Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and
whether the state has a batch dimension.

"""
ops_type_indices, ops_wires, ops_param = [[], []], [[], []], []

two_qubit_ops = False
for op in operations:

wires = op.wires

if isinstance(op, Channel):
ops_type_indices[0].append(1)
ops_type_indices[1].append(
[qml.DepolarizingChannel, qml.AmplitudeDamping, qml.BitFlip].index(type(op))
)
elif len(wires) == 1:
ops_type_indices[0].append(0)
ops_type_indices[1].append([qml.RX, qml.RY, qml.RZ, qml.Hadamard].index(type(op)))
elif len(wires) == 2:
ops_type_indices[0].append(2)
if isinstance(op, qml.CNOT):
op_index = 0
else:
op_index = two_gates.index(op.id) + 1
if op_index < 2:
ops_param.append(0.0)
else:
ops_param.append(op.phi)
two_qubit_ops = True
ops_type_indices[1].append(op_index)

else:
raise ValueError("TODO")

if len(wires) == 1:
wires = [wires[0], -1]
ops_param.append(op.parameters[0])
ops_wires[0].append(wires[0])
ops_wires[1].append(wires[1])

ops_info = {
"type_indices": jnp.array(ops_type_indices).T,
"wires": [jnp.array(ops_wires[0]), jnp.array(ops_wires[1])],
"param": jnp.array(ops_param),
}
branches = qubit_branches[: 2 + two_qubit_ops]

@jax.jit
def switch_function(state, op_info):
return jax.lax.switch(op_info["type_indices"][0], branches, state, op_info), None

return jax.lax.scan(switch_function, initial_state, ops_info)[0]


class DefaultMixed(QubitDevice):
"""Default qubit device for performing mixed-state computations in PennyLane.
Expand Down Expand Up @@ -92,73 +173,19 @@ class DefaultMixed(QubitDevice):
author = "Xanadu Inc."

operations = {
"Identity",
"Snapshot",
"BasisState",
"QubitStateVector",
"StatePrep",
"QubitDensityMatrix",
"QubitUnitary",
"ControlledQubitUnitary",
"BlockEncode",
"MultiControlledX",
"DiagonalQubitUnitary",
"SpecialUnitary",
"PauliX",
"PauliY",
"PauliZ",
"MultiRZ",
"Hadamard",
"S",
"T",
"SX",
"CNOT",
"SWAP",
"ISWAP",
"CSWAP",
"Toffoli",
"CCZ",
"CY",
"CZ",
"CH",
"PhaseShift",
"PCPhase",
"ControlledPhaseShift",
"CPhaseShift00",
"CPhaseShift01",
"CPhaseShift10",
"QubitUnitary",
"RX",
"RY",
"RZ",
"Rot",
"CRX",
"CRY",
"CRZ",
"CRot",
"AmplitudeDamping",
"GeneralizedAmplitudeDamping",
"PhaseDamping",
"DepolarizingChannel",
"BitFlip",
"PhaseFlip",
"PauliError",
"ResetError",
"QubitChannel",
"SingleExcitation",
"SingleExcitationPlus",
"SingleExcitationMinus",
"DoubleExcitation",
"DoubleExcitationPlus",
"DoubleExcitationMinus",
"QubitCarry",
"QubitSum",
"OrbitalRotation",
"FermionicSWAP",
"QFT",
"ThermalRelaxationError",
"ECR",
"ParametrizedEvolution",
"GlobalPhase",
}

_reshape = staticmethod(qnp.reshape)
Expand Down Expand Up @@ -781,9 +808,18 @@ def apply(self, operations, rotations=None, **kwargs):
f"on a {self.short_name} device."
)

for operation in operations:
self._apply_operation(operation)

prep = False
if len(operations) > 0:
if (
isinstance(operations[0], StatePrep)
or isinstance(operations[0], BasisState)
or isinstance(operations[0], QubitDensityMatrix)
):
prep = True
self._apply_operation(operation)

self._state = jnp.complex128(self._state)
self._state = get_qubit_final_state_from_initial(operations[prep:], self._state)
# store the pre-rotated state
self._pre_rotated_state = self._state

Expand Down
10 changes: 9 additions & 1 deletion pennylane/devices/default_qutrit_mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,15 @@ def observable_stopping_condition(obs: qml.operation.Operator) -> bool:

def stopping_condition(op: qml.operation.Operator) -> bool:
"""Specify whether an Operator object is supported by the device."""
expected_set = DefaultQutrit.operations | {"Snapshot"} | channels
operations = {
"TAdd",
"Adjoint(TAdd)",
"THadamard",
"TRX",
"TRY",
"TRZ",
}
expected_set = operations | channels
return op.name in expected_set


Expand Down
1 change: 0 additions & 1 deletion pennylane/devices/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,6 @@ def decomposer(op):
if all(stopping_condition(op) for op in tape.operations[len(prep_op) :]):
return (tape,), null_postprocessing
try:

new_ops = [
final_op
for op in tape.operations[len(prep_op) :]
Expand Down
Empty file.
Loading
Loading