diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 4bfec46955e..bbeed1e3e34 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -567,6 +567,9 @@ same information.
Bug fixes 🐛
+* `qml.counts` returns all outcomes when the `all_outcomes` argument is `True` and mid-circuit measurements are present.
+ [(#6732)](https://github.com/PennyLaneAI/pennylane/pull/6732)
+
* `qml.ControlledQubitUnitary` has consistent behaviour with program capture enabled.
[(#6719)](https://github.com/PennyLaneAI/pennylane/pull/6719)
diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py
index a9fc151b75a..4328c21bdd2 100644
--- a/pennylane/transforms/dynamic_one_shot.py
+++ b/pennylane/transforms/dynamic_one_shot.py
@@ -385,6 +385,17 @@ def gather_non_mcm(measurement, samples, is_valid, postselect_mode=None):
"""
if isinstance(measurement, CountsMP):
tmp = Counter()
+
+ if measurement.all_outcomes:
+ if isinstance(measurement.mv, Sequence):
+ values = [list(m.branches.values()) for m in measurement.mv]
+ values = list(itertools.product(*values))
+ tmp = Counter({"".join(map(str, v)): 0 for v in values})
+ else:
+ values = [list(measurement.mv.branches.values())]
+ values = list(itertools.product(*values))
+ tmp = Counter({float(*v): 0 for v in values})
+
for i, d in enumerate(samples):
tmp.update(
{k if isinstance(k, str) else float(k): v * is_valid[i] for k, v in d.items()}
diff --git a/tests/measurements/test_counts.py b/tests/measurements/test_counts.py
index e49ab042718..bb919524ee9 100644
--- a/tests/measurements/test_counts.py
+++ b/tests/measurements/test_counts.py
@@ -396,6 +396,34 @@ def test_counts_binsize(self):
class TestCountsIntegration:
# pylint:disable=too-many-public-methods,not-an-iterable
+ def test_counts_all_outcomes_with_mcm(self):
+ n_sample = 10
+
+ dev = qml.device("default.qubit", shots=n_sample)
+
+ @qml.qnode(device=dev, mcm_method="one-shot")
+ def single_mcm():
+ m = qml.measure(0)
+ return qml.counts(m, all_outcomes=True)
+
+ res = single_mcm()
+
+ assert list(res.keys()) == [0.0, 1.0]
+ assert sum(res.values()) == n_sample
+ assert res[0.0] == n_sample
+
+ @qml.qnode(device=dev, mcm_method="one-shot")
+ def double_mcm():
+ m1 = qml.measure(0)
+ m2 = qml.measure(1)
+ return qml.counts([m1, m2], all_outcomes=True)
+
+ res = double_mcm()
+
+ assert list(res.keys()) == ["00", "01", "10", "11"]
+ assert sum(res.values()) == n_sample
+ assert res["00"] == n_sample
+
def test_counts_dimension(self):
"""Test that the counts function outputs counts of the right size"""
n_sample = 10
diff --git a/tests/test_compiler.py b/tests/test_compiler.py
index 27a50a837f8..38d61cd86cf 100644
--- a/tests/test_compiler.py
+++ b/tests/test_compiler.py
@@ -864,8 +864,7 @@ class TestCatalystMCMs:
@pytest.mark.parametrize(
"measure_f",
[
- # https://github.com/PennyLaneAI/pennylane/issues/6700
- pytest.param(qml.counts, marks=pytest.mark.xfail),
+ qml.counts,
qml.expval,
qml.probs,
],
@@ -898,7 +897,7 @@ def ref_func(x, y):
meas_key = "wires" if isinstance(meas_obj, list) else "op"
meas_value = m0 if isinstance(meas_obj, str) else meas_obj
kwargs = {meas_key: meas_value}
- if measure_f == qml.counts:
+ if measure_f is qml.counts:
kwargs["all_outcomes"] = True
return measure_f(**kwargs)
@@ -919,17 +918,20 @@ def ansatz():
meas_key = "wires" if isinstance(meas_obj, list) else "op"
meas_value = m0 if isinstance(meas_obj, str) else meas_obj
kwargs = {meas_key: meas_value}
- if measure_f == qml.counts:
+ if measure_f is qml.counts:
kwargs["all_outcomes"] = True
return measure_f(**kwargs)
params = jnp.pi / 4 * jnp.ones(2)
results0 = ref_func(*params)
results1 = func(*params)
- if measure_f == qml.counts:
- ndim = 2 # both [0] and m0 are on one wire only
- results1 = {format(int(state), f"0{ndim}b"): count for state, count in zip(*results1)}
- if measure_f == qml.sample:
+ if measure_f is qml.counts:
+
+ def fname(x):
+ return format(x, f"0{len(meas_obj)}b") if isinstance(meas_obj, list) else x
+
+ results1 = {fname(int(state)): count for state, count in zip(*results1)}
+ if measure_f is qml.sample:
results0 = results0[results0 != fill_in_value]
results1 = results1[results1 != fill_in_value]
mcm_utils.validate_measurements(measure_f, shots, results1, results0)