Skip to content

Commit

Permalink
Add eigvals support to ExpectationMP, CountsMP, VarianceMP, SampleMP (P…
Browse files Browse the repository at this point in the history
…ennyLaneAI#5463)

**Context:**
`MeasurementProcess` supports specifying `eigvals` instead of `obs`,
which should include sample based measurements such as `ExpectationMP`,
`VarianceMP`. However, `process_samples` of these measurements does not
work if the measurement was created using `eigvals`

**Description of the Change:**
Adds `eigvals` support to variance sample measurements

**Benefits:**
Bugfix

**Possible Drawbacks:**
This fix only makes sure that the `process_samples` method would work.
Specifically for `expval`, as it is also a state measurement, currently
taking a `expval` in a `qnode` with only `eigvals` specified does not
work.

**Related GitHub Issues:**
Fixes PennyLaneAI#5432
[sc-59519]

---------

Co-authored-by: Christina Lee <[email protected]>
  • Loading branch information
astralcai and albi3ro authored Apr 9, 2024
1 parent f10e98f commit 1089bb3
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 35 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,9 @@
differentiation.
[(#5434)](https://github.com/PennyLaneAI/pennylane/pull/5434)

* `SampleMP`, `ExpectationMP`, `CountsMP`, `VarianceMP` constructed with ``eigvals`` can now properly process samples.
[(#5463)](https://github.com/PennyLaneAI/pennylane/pull/5463)

<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):
Expand Down
21 changes: 18 additions & 3 deletions pennylane/measurements/counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
from .mid_measure import MeasurementValue


def counts(op=None, wires=None, all_outcomes=False) -> "CountsMP":
def counts(
op=None,
wires=None,
all_outcomes=False,
) -> "CountsMP":
r"""Sample from the supplied observable, with the number of shots
determined from the ``dev.shots`` attribute of the corresponding device,
returning the number of counts for each sample. If no observable is provided then basis state
Expand Down Expand Up @@ -148,8 +152,8 @@ def circuit():
if wires is not None:
if op is not None:
raise ValueError(
"Cannot specify the wires to sample if an observable is "
"provided. The wires to sample will be determined directly from the observable."
"Cannot specify the wires to sample if an observable is provided. The wires "
"to sample will be determined directly from the observable."
)
wires = Wires(wires)

Expand Down Expand Up @@ -186,6 +190,8 @@ def __init__(
all_outcomes: bool = False,
):
self.all_outcomes = all_outcomes
if wires is not None:
wires = Wires(wires)
super().__init__(obs, wires, eigvals, id)

def _flatten(self):
Expand Down Expand Up @@ -340,6 +346,15 @@ def convert(x):
for state, count in zip(qml.math.unwrap(states), _counts):
outcome_dict[state] = count

def outcome_to_eigval(outcome: str):
return self.eigvals()[int(outcome, 2)]

if self._eigvals is not None:
outcome_dicts = [
{outcome_to_eigval(outcome): count for outcome, count in outcome_dict.items()}
for outcome_dict in outcome_dicts
]

return outcome_dicts if batched else outcome_dicts[0]

# pylint: disable=redefined-outer-name
Expand Down
12 changes: 10 additions & 2 deletions pennylane/measurements/expval.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,15 @@
import pennylane as qml
from pennylane.operation import Operator
from pennylane.wires import Wires

from .measurements import Expectation, SampleMeasurement, StateMeasurement
from .sample import SampleMP
from .mid_measure import MeasurementValue


def expval(op: Union[Operator, MeasurementValue]):
def expval(
op: Union[Operator, MeasurementValue],
):
r"""Expectation value of the supplied observable.
**Example:**
Expand Down Expand Up @@ -114,7 +118,11 @@ def process_samples(
# estimate the ev
op = self.mv if self.mv is not None else self.obs
with qml.queuing.QueuingManager.stop_recording():
samples = qml.sample(op=op).process_samples(
samples = SampleMP(
obs=op,
eigvals=self._eigvals,
wires=self.wires if self._eigvals is not None else None,
).process_samples(
samples=samples, wire_order=wire_order, shot_range=shot_range, bin_size=bin_size
)

Expand Down
61 changes: 35 additions & 26 deletions pennylane/measurements/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@
from .mid_measure import MeasurementValue


def sample(op: Optional[Union[Operator, MeasurementValue]] = None, wires=None) -> "SampleMP":
def sample(
op: Optional[Union[Operator, MeasurementValue]] = None,
wires=None,
) -> "SampleMP":
r"""Sample from the supplied observable, with the number of shots
determined from the ``dev.shots`` attribute of the corresponding device,
returning raw samples. If no observable is provided then basis state samples are returned
Expand Down Expand Up @@ -132,30 +135,6 @@ def circuit(x):
[0, 0]])
"""
if isinstance(op, MeasurementValue):
return SampleMP(obs=op)

if isinstance(op, Sequence):
if not all(isinstance(o, MeasurementValue) and len(o.measurements) == 1 for o in op):
raise qml.QuantumFunctionError(
"Only sequences of single MeasurementValues can be passed with the op argument. "
"MeasurementValues manipulated using arithmetic operators cannot be used when "
"collecting statistics for a sequence of mid-circuit measurements."
)

return SampleMP(obs=op)

if op is not None and not op.is_hermitian: # None type is also allowed for op
warnings.warn(f"{op.name} might not be hermitian.")

if wires is not None:
if op is not None:
raise ValueError(
"Cannot specify the wires to sample if an observable is "
"provided. The wires to sample will be determined directly from the observable."
)
wires = Wires(wires)

return SampleMP(obs=op, wires=wires)


Expand All @@ -177,6 +156,36 @@ class SampleMP(SampleMeasurement):
where the instance has to be identified
"""

def __init__(self, obs=None, wires=None, eigvals=None, id=None):

if isinstance(obs, MeasurementValue):
super().__init__(obs=obs)
return

if isinstance(obs, Sequence):
if not all(isinstance(o, MeasurementValue) and len(o.measurements) == 1 for o in obs):
raise qml.QuantumFunctionError(
"Only sequences of single MeasurementValues can be passed with the op "
"argument. MeasurementValues manipulated using arithmetic operators cannot be "
"used when collecting statistics for a sequence of mid-circuit measurements."
)

super().__init__(obs=obs)
return

if obs is not None and not obs.is_hermitian: # None type is also allowed for op
warnings.warn(f"{obs.name} might not be hermitian.")

if wires is not None:
if obs is not None:
raise ValueError(
"Cannot specify the wires to sample if an observable is provided. The wires "
"to sample will be determined directly from the observable."
)
wires = Wires(wires)

super().__init__(obs=obs, wires=wires, eigvals=eigvals, id=id)

@property
def return_type(self):
return Sample
Expand Down Expand Up @@ -249,7 +258,7 @@ def process_samples(
num_wires = samples.shape[-1] # wires is the last dimension

# If we're sampling wires or a list of mid-circuit measurements
if self.obs is None and not isinstance(self.mv, MeasurementValue):
if self.obs is None and not isinstance(self.mv, MeasurementValue) and self._eigvals is None:
# if no observable was provided then return the raw samples
return samples if bin_size is None else samples.T.reshape(num_wires, bin_size, -1)

Expand Down
9 changes: 8 additions & 1 deletion pennylane/measurements/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
import pennylane as qml
from pennylane.operation import Operator
from pennylane.wires import Wires

from .measurements import SampleMeasurement, StateMeasurement, Variance
from .sample import SampleMP
from .mid_measure import MeasurementValue


Expand Down Expand Up @@ -64,6 +66,7 @@ def circuit(x):

if not op.is_hermitian:
warnings.warn(f"{op.name} might not be hermitian.")

return VarianceMP(obs=op)


Expand Down Expand Up @@ -108,7 +111,11 @@ def process_samples(
# estimate the variance
op = self.mv if self.mv is not None else self.obs
with qml.queuing.QueuingManager.stop_recording():
samples = qml.sample(op=op).process_samples(
samples = SampleMP(
obs=op,
eigvals=self._eigvals,
wires=self.wires if self._eigvals is not None else None,
).process_samples(
samples=samples, wire_order=wire_order, shot_range=shot_range, bin_size=bin_size
)

Expand Down
15 changes: 13 additions & 2 deletions tests/measurements/test_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def test_repr(self):
m2 = CountsMP(obs=qml.PauliX(0), all_outcomes=True)
assert repr(m2) == "CountsMP(X(0), all_outcomes=True)"

m3 = CountsMP(eigvals=(-1, 1), all_outcomes=False)
assert repr(m3) == "CountsMP(eigvals=[-1 1], wires=[], all_outcomes=False)"
m3 = CountsMP(eigvals=(-1, 1), wires=[0], all_outcomes=False)
assert repr(m3) == "CountsMP(eigvals=[-1 1], wires=[0], all_outcomes=False)"

mv = qml.measure(0)
m4 = CountsMP(obs=mv, all_outcomes=False)
Expand Down Expand Up @@ -164,6 +164,17 @@ def test_counts_obs(self):
assert result[1] == np.count_nonzero(samples[:, 0] == 0)
assert result[-1] == np.count_nonzero(samples[:, 0] == 1)

def test_count_eigvals(self):
"""Tests that eigvals are used instead of obs for counts"""

shots = 100
samples = np.random.choice([0, 1], size=(shots, 2)).astype(np.int64)
result = CountsMP(eigvals=[1, -1], wires=0).process_samples(samples, wire_order=[0])
assert len(result) == 2
assert set(result.keys()) == {1, -1}
assert result[1] == np.count_nonzero(samples[:, 0] == 0)
assert result[-1] == np.count_nonzero(samples[:, 0] == 1)

def test_counts_shape_single_measurement_value(self):
"""Test that the counts output is correct for single mid-circuit measurement
values."""
Expand Down
10 changes: 10 additions & 0 deletions tests/measurements/test_expval.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,16 @@ def expected_circuit(phi):
res = func(phi, shots=shots)
assert np.allclose(np.array(res), expected, atol=atol, rtol=0)

def test_eigvals_instead_of_observable(self):
"""Tests process samples with eigvals instead of observables"""

shots = 100
samples = np.random.choice([0, 1], size=(shots, 2)).astype(np.int64)
expected = qml.expval(qml.PauliZ(0)).process_samples(samples, [0, 1])
assert (
ExpectationMP(eigvals=[1, -1], wires=[0]).process_samples(samples, [0, 1]) == expected
)

def test_measurement_value_list_not_allowed(self):
"""Test that measuring a list of measurement values raises an error."""
m0 = qml.measure(0)
Expand Down
10 changes: 9 additions & 1 deletion tests/measurements/test_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import pytest

import pennylane as qml
from pennylane.measurements import Variance, Shots
from pennylane.measurements import Variance, Shots, VarianceMP


class TestVar:
Expand Down Expand Up @@ -135,6 +135,14 @@ def expected_circuit(phi):
res = func(phi, shots=shots)
assert np.allclose(np.array(res), expected, atol=atol, rtol=0)

def test_eigvals_instead_of_observable(self):
"""Tests process samples with eigvals instead of observables"""

shots = 100
samples = np.random.choice([0, 1], size=(shots, 2)).astype(np.int64)
expected = qml.var(qml.PauliZ(0)).process_samples(samples, [0, 1])
assert VarianceMP(eigvals=[1, -1], wires=[0]).process_samples(samples, [0, 1]) == expected

def test_measurement_value_list_not_allowed(self):
"""Test that measuring a list of measurement values raises an error."""
m0 = qml.measure(0)
Expand Down

0 comments on commit 1089bb3

Please sign in to comment.