Skip to content

Commit

Permalink
Improving support of counts and sample with default.mixed (Penn…
Browse files Browse the repository at this point in the history
…yLaneAI#5514)

**Context:** `Counts` should work with interfaces (torch, tf, jax) and
`default.mixed` but it is not the case. This PR fixes it.

**Description of the Change:**

**Benefits:**

**Possible Drawbacks:**

**Related GitHub Issues:** PennyLaneAI#2984

---------

Co-authored-by: Jay Soni <[email protected]>
  • Loading branch information
obliviateandsurrender and Jaybsoni committed Apr 19, 2024
1 parent b1cc066 commit 313de85
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 5 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@
* `qml.transforms.hamiltonian_expand` can now handle multi-term observables with a constant offset.
[(#5414)](https://github.com/PennyLaneAI/pennylane/pull/5414)

* `default.mixed` has improved support for sampling-based measurements with non-numpy interfaces.
[(#5514)](https://github.com/PennyLaneAI/pennylane/pull/5514)

* The `qml.qchem.hf_state` function is upgraded to be compatible with the parity and Bravyi-Kitaev bases.
[(#5472)](https://github.com/PennyLaneAI/pennylane/pull/5472)

Expand Down
8 changes: 3 additions & 5 deletions pennylane/_qubit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,15 +872,13 @@ def sample_basis_states(self, number_of_states, state_probability):
)

shots = self.shots

state_probs = qml.math.unwrap(state_probability)
basis_states = np.arange(number_of_states)
if self._ndim(state_probability) == 2:
# np.random.choice does not support broadcasting as needed here.
return np.array(
[np.random.choice(basis_states, shots, p=prob) for prob in state_probability]
)
return np.array([np.random.choice(basis_states, shots, p=prob) for prob in state_probs])

return np.random.choice(basis_states, shots, p=state_probability)
return np.random.choice(basis_states, shots, p=state_probs)

@staticmethod
def generate_basis_states(num_wires, dtype=np.uint32):
Expand Down
50 changes: 50 additions & 0 deletions tests/devices/test_default_mixed_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,3 +660,53 @@ def circuit(weights):

grad = qml.grad(circuit)(weights)
assert grad.shape == weights.shape


class TestMeasurements:
"""Tests for measurements with default.mixed"""

@pytest.mark.parametrize(
"measurement",
[
qml.counts(qml.PauliZ(0)),
qml.counts(wires=[0]),
qml.sample(qml.PauliX(0)),
qml.sample(wires=[1]),
],
)
def test_measurements_tf(self, measurement):
"""Test sampling-based measurements work with `default.mixed` for trainable interfaces"""
num_shots = 1024
dev = qml.device("default.mixed", wires=2, shots=num_shots)

@qml.qnode(dev, interface="autograd")
def circuit(x):
qml.Hadamard(wires=[0])
qml.CRX(x, wires=[0, 1])
return qml.apply(measurement)

res = circuit(np.array(0.5))

assert len(res) == 2 if isinstance(measurement, qml.measurements.CountsMP) else num_shots

@pytest.mark.parametrize(
"meas_op",
[qml.PauliX(0), qml.PauliZ(0)],
)
def test_measurement_diff(self, meas_op):
"""Test sequence of single-shot expectation values work for derivatives"""
num_shots = 64
dev = qml.device("default.mixed", shots=[(1, num_shots)], wires=2)

@qml.qnode(dev, diff_method="parameter-shift")
def circuit(angle):
qml.RX(angle, wires=0)
return qml.expval(meas_op)

def cost(angle):
return qml.math.hstack(circuit(angle))

angle = np.array(0.1234)

assert isinstance(qml.jacobian(cost)(angle), np.ndarray)
assert len(cost(angle)) == num_shots
49 changes: 49 additions & 0 deletions tests/devices/test_default_mixed_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,3 +774,52 @@ def circuit(p):
expected.append(circuit(x_indiv))

assert np.allclose(expected, res)


class TestMeasurements:
"""Tests for measurements with default.mixed"""

@pytest.mark.parametrize(
"measurement",
[
qml.counts(qml.PauliZ(0)),
qml.counts(wires=[0]),
qml.sample(qml.PauliX(0)),
qml.sample(wires=[1]),
],
)
def test_measurements_jax(self, measurement):
"""Test sampling-based measurements work with `default.mixed` for trainable interfaces"""
num_shots = 1024
dev = qml.device("default.mixed", wires=2, shots=num_shots)

@qml.qnode(dev, interface="jax")
def circuit(x):
qml.Hadamard(wires=[0])
qml.CRX(x, wires=[0, 1])
return qml.apply(measurement)

res = circuit(jnp.array(0.5))

assert len(res) == 2 if isinstance(measurement, qml.measurements.CountsMP) else num_shots

@pytest.mark.parametrize(
"meas_op",
[qml.PauliX(0), qml.PauliZ(0)],
)
def test_measurement_diff(self, meas_op):
"""Test sequence of single-shot expectation values work for derivatives"""
num_shots = 64
dev = qml.device("default.mixed", shots=[(1, num_shots)], wires=2)

@qml.qnode(dev, diff_method="parameter-shift")
def circuit(angle):
qml.RX(angle, wires=0)
return qml.expval(meas_op)

def cost(angle):
return qml.math.hstack(circuit(angle))

angle = jnp.array(0.1234)
assert isinstance(jax.jacobian(cost)(angle), jax.Array)
assert len(cost(angle)) == num_shots
54 changes: 54 additions & 0 deletions tests/devices/test_default_mixed_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,3 +780,57 @@ def circuit(p):

# compare results to results of non-decorated circuit
assert np.allclose(circuit(x), res)


class TestMeasurements:
"""Tests for measurements with default.mixed"""

@pytest.mark.parametrize(
"measurement",
[
qml.counts(qml.PauliZ(0)),
qml.counts(wires=[0]),
qml.sample(qml.PauliX(0)),
qml.sample(wires=[1]),
],
)
def test_measurements_tf(self, measurement):
"""Test sampling-based measurements work with `default.mixed` for trainable interfaces"""
num_shots = 1024
dev = qml.device("default.mixed", wires=2, shots=num_shots)

@qml.qnode(dev, interface="tf")
def circuit(x):
qml.Hadamard(wires=[0])
qml.CRX(x, wires=[0, 1])
return qml.apply(measurement)

res = circuit(tf.Variable(0.5))

assert len(res) == 2 if isinstance(measurement, qml.measurements.CountsMP) else num_shots

@pytest.mark.parametrize(
"meas_op",
[qml.PauliX(0), qml.PauliZ(0)],
)
def test_measurement_diff(self, meas_op):
"""Test sequence of single-shot expectation values work for derivatives"""
num_shots = 64
dev = qml.device("default.mixed", shots=[(1, num_shots)], wires=2)

@qml.qnode(dev, diff_method="parameter-shift")
def circuit(angle):
qml.RX(angle, wires=0)
return qml.expval(meas_op)

def cost(angle):
return qml.math.hstack(circuit(angle))

angle = tf.Variable(0.1234)
with tf.GradientTape(persistent=True) as tape:
res = cost(angle)

assert isinstance(res, tf.Tensor)
assert isinstance(tape.gradient(res, angle), tf.Tensor)
assert isinstance(tape.jacobian(res, angle), tf.Tensor)
assert len(res) == num_shots
51 changes: 51 additions & 0 deletions tests/devices/test_default_mixed_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,3 +718,54 @@ def circuit(weights):

assert isinstance(weights.grad, torch.Tensor)
assert weights.grad.shape == weights.shape


class TestMeasurements:
"""Tests for measurements with default.mixed"""

@pytest.mark.parametrize(
"measurement",
[
qml.counts(qml.PauliZ(0)),
qml.counts(wires=[0]),
qml.sample(qml.PauliX(0)),
qml.sample(wires=[1]),
],
)
def test_measurements_torch(self, measurement):
"""Test sampling-based measurements work with `default.mixed` for trainable interfaces"""
num_shots = 1024
dev = qml.device("default.mixed", wires=2, shots=num_shots)

@qml.qnode(dev, interface="torch")
def circuit(x):
qml.Hadamard(wires=[0])
qml.CRX(x, wires=[0, 1])
return qml.apply(measurement)

res = circuit(torch.tensor(0.5, requires_grad=True))

assert len(res) == 2 if isinstance(measurement, qml.measurements.CountsMP) else num_shots

@pytest.mark.parametrize(
"meas_op",
[qml.PauliX(0), qml.PauliZ(0)],
)
def test_measurement_diff(self, meas_op):
"""Test sequence of single-shot expectation values work for derivatives"""
num_shots = 64
dev = qml.device("default.mixed", shots=[(1, num_shots)], wires=2)

@qml.qnode(dev, diff_method="parameter-shift")
def circuit(angle):
qml.RX(angle, wires=0)
return qml.expval(meas_op)

def cost(angle):
return qml.math.hstack(circuit(angle))

angle = torch.tensor(0.1234, requires_grad=True)
res = torch.autograd.functional.jacobian(cost, angle)

assert isinstance(res, torch.Tensor)
assert len(res) == num_shots

0 comments on commit 313de85

Please sign in to comment.