Skip to content

Commit

Permalink
Added test for new observable stopping condition
Browse files Browse the repository at this point in the history
  • Loading branch information
Gabriel-Bottrill committed Apr 26, 2024
1 parent 25a2b20 commit 51ac8bb
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 6 deletions.
8 changes: 5 additions & 3 deletions pennylane/devices/default_qutrit_mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,14 @@

def observable_stopping_condition(obs: qml.operation.Operator) -> bool:
"""Specifies whether an observable is accepted by DefaultQutritMixed."""
if obs.name in {"SProd", "Prod", "Sum"}:
if isinstance(obs, qml.operation.Tensor):
return all(observable_stopping_condition(observable) for observable in obs.obs)
if obs.name in {"Prod", "Sum"}:
return all(observable_stopping_condition(observable) for observable in obs.operands)
if obs.name in {"LinearCombination", "Hamiltonian"}:
return all(observable_stopping_condition(observable) for observable in obs.terms()[1])
if obs.name == "Tensor":
return all(observable_stopping_condition(observable) for observable in obs.obs)
if obs.name == "SProd":
return observable_stopping_condition(obs.base)

return obs.name in observables

Expand Down
27 changes: 24 additions & 3 deletions tests/devices/qutrit_mixed/test_qutrit_mixed_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@

import pennylane as qml
from pennylane.devices import ExecutionConfig
from pennylane.devices.default_qutrit_mixed import DefaultQutritMixed, stopping_condition
from pennylane.devices.default_qutrit_mixed import (
DefaultQutritMixed,
stopping_condition,
observable_stopping_condition,
)


class NoMatOp(qml.operation.Operation):
Expand Down Expand Up @@ -124,11 +128,28 @@ def test_measurement_is_swapped_out(self, mp_fn, mp_cls, shots):
(qml.TRX(1.1, 0), True),
],
)
def test_accepted_operator(self, op, expected): # TODO: Add channel ops once added.
"""Test that _accepted_operator works correctly"""
def test_accepted_observables(self, op, expected):
"""Test that stopping_condition works correctly"""
res = stopping_condition(op)
assert res == expected

@pytest.mark.parametrize(
"obs, expected",
[
(qml.TShift(0), False),
(qml.GellMann(0, 1), True),
(qml.Snapshot(), False),
(qml.operation.Tensor(qml.GellMann(0, 1), qml.GellMann(3, 3)), True),
(qml.ops.op_math.SProd(1.2, qml.GellMann(0, 1)), True),
(qml.sum(qml.ops.op_math.SProd(1.2, qml.GellMann(0, 1)), qml.GellMann(1, 3)), True),
(qml.ops.op_math.Prod(qml.GellMann(0, 1), qml.GellMann(3, 3)), True),
],
)
def test_accepted_operator(self, obs, expected):
"""Test that observable_stopping_condition works correctly"""
res = observable_stopping_condition(obs)
assert res == expected


class TestPreprocessingIntegration:
"""Test preprocess produces output that can be executed by the device."""
Expand Down

0 comments on commit 51ac8bb

Please sign in to comment.