diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 4bfec46955e..bbeed1e3e34 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -567,6 +567,9 @@ same information.

Bug fixes 🐛

+* `qml.counts` returns all outcomes when the `all_outcomes` argument is `True` and mid-circuit measurements are present. + [(#6732)](https://github.com/PennyLaneAI/pennylane/pull/6732) + * `qml.ControlledQubitUnitary` has consistent behaviour with program capture enabled. [(#6719)](https://github.com/PennyLaneAI/pennylane/pull/6719) diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index a9fc151b75a..4328c21bdd2 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -385,6 +385,17 @@ def gather_non_mcm(measurement, samples, is_valid, postselect_mode=None): """ if isinstance(measurement, CountsMP): tmp = Counter() + + if measurement.all_outcomes: + if isinstance(measurement.mv, Sequence): + values = [list(m.branches.values()) for m in measurement.mv] + values = list(itertools.product(*values)) + tmp = Counter({"".join(map(str, v)): 0 for v in values}) + else: + values = [list(measurement.mv.branches.values())] + values = list(itertools.product(*values)) + tmp = Counter({float(*v): 0 for v in values}) + for i, d in enumerate(samples): tmp.update( {k if isinstance(k, str) else float(k): v * is_valid[i] for k, v in d.items()} diff --git a/tests/measurements/test_counts.py b/tests/measurements/test_counts.py index e49ab042718..bb919524ee9 100644 --- a/tests/measurements/test_counts.py +++ b/tests/measurements/test_counts.py @@ -396,6 +396,34 @@ def test_counts_binsize(self): class TestCountsIntegration: # pylint:disable=too-many-public-methods,not-an-iterable + def test_counts_all_outcomes_with_mcm(self): + n_sample = 10 + + dev = qml.device("default.qubit", shots=n_sample) + + @qml.qnode(device=dev, mcm_method="one-shot") + def single_mcm(): + m = qml.measure(0) + return qml.counts(m, all_outcomes=True) + + res = single_mcm() + + assert list(res.keys()) == [0.0, 1.0] + assert sum(res.values()) == n_sample + assert res[0.0] == n_sample + + @qml.qnode(device=dev, mcm_method="one-shot") + def double_mcm(): + m1 = qml.measure(0) + m2 = qml.measure(1) + return qml.counts([m1, m2], all_outcomes=True) + + res = double_mcm() + + assert list(res.keys()) == ["00", "01", "10", "11"] + assert sum(res.values()) == n_sample + assert res["00"] == n_sample + def test_counts_dimension(self): """Test that the counts function outputs counts of the right size""" n_sample = 10 diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 27a50a837f8..38d61cd86cf 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -864,8 +864,7 @@ class TestCatalystMCMs: @pytest.mark.parametrize( "measure_f", [ - # https://github.com/PennyLaneAI/pennylane/issues/6700 - pytest.param(qml.counts, marks=pytest.mark.xfail), + qml.counts, qml.expval, qml.probs, ], @@ -898,7 +897,7 @@ def ref_func(x, y): meas_key = "wires" if isinstance(meas_obj, list) else "op" meas_value = m0 if isinstance(meas_obj, str) else meas_obj kwargs = {meas_key: meas_value} - if measure_f == qml.counts: + if measure_f is qml.counts: kwargs["all_outcomes"] = True return measure_f(**kwargs) @@ -919,17 +918,20 @@ def ansatz(): meas_key = "wires" if isinstance(meas_obj, list) else "op" meas_value = m0 if isinstance(meas_obj, str) else meas_obj kwargs = {meas_key: meas_value} - if measure_f == qml.counts: + if measure_f is qml.counts: kwargs["all_outcomes"] = True return measure_f(**kwargs) params = jnp.pi / 4 * jnp.ones(2) results0 = ref_func(*params) results1 = func(*params) - if measure_f == qml.counts: - ndim = 2 # both [0] and m0 are on one wire only - results1 = {format(int(state), f"0{ndim}b"): count for state, count in zip(*results1)} - if measure_f == qml.sample: + if measure_f is qml.counts: + + def fname(x): + return format(x, f"0{len(meas_obj)}b") if isinstance(meas_obj, list) else x + + results1 = {fname(int(state)): count for state, count in zip(*results1)} + if measure_f is qml.sample: results0 = results0[results0 != fill_in_value] results1 = results1[results1 != fill_in_value] mcm_utils.validate_measurements(measure_f, shots, results1, results0)