diff --git a/.github/workflows/interface-unit-tests.yml b/.github/workflows/interface-unit-tests.yml index 0aa27f71139..f2f9f6b0fd1 100644 --- a/.github/workflows/interface-unit-tests.yml +++ b/.github/workflows/interface-unit-tests.yml @@ -36,7 +36,7 @@ on: Indicate if a lightened version of the CI should be run instead of the entire suite. The lightened version of the CI includes the following changes: - - Only Python 3.9 is tested against, instead of 3.9, 3.10, 3.11 + - Only Python 3.9 is tested against, instead of 3.9, 3.10, 3.11, 3.12 required: false type: boolean default: false @@ -70,10 +70,10 @@ jobs: else cat >python_versions.json <<-EOF { - "default": ["3.9", "3.10", "3.11"], + "default": ["3.9", "3.10", "3.11", "3.12"], "torch-tests": ["3.9", "3.11"], - "tf-tests": ["3.9", "3.10"], - "jax-tests": ["3.9", "3.11"], + "tf-tests": ["3.9", "3.11"], + "jax-tests": ["3.9", "3.12"], "all-interfaces-tests": ["3.9"], "external-libraries-tests": ["3.9"], "qcut-tests": ["3.9"], diff --git a/doc/development/deprecations.rst b/doc/development/deprecations.rst index c96e6853874..1f2bbfd0544 100644 --- a/doc/development/deprecations.rst +++ b/doc/development/deprecations.rst @@ -9,13 +9,13 @@ deprecations are listed below. Pending deprecations -------------------- -* `qml.transforms.one_qubit_decomposition` and `qml.transforms.two_qubit_decomposition` are deprecated. Instead, - you should use `qml.ops.one_qubit_decomposition` and `qml.ops.two_qubit_decomposition` . +* ``qml.transforms.one_qubit_decomposition`` and ``qml.transforms.two_qubit_decomposition`` are deprecated. Instead, + you should use ``qml.ops.one_qubit_decomposition`` and ``qml.ops.two_qubit_decomposition`` . - Deprecated in v0.34 - Will be removed in v0.35 -* `Observable.return_type` is deprecated. Instead, you should inspect the type +* ``Observable.return_type`` is deprecated. Instead, you should inspect the type of the surrounding measurement process. - Deprecated in v0.34 diff --git a/doc/introduction/inspecting_circuits.rst b/doc/introduction/inspecting_circuits.rst index 5b916f65790..50eac348a88 100644 --- a/doc/introduction/inspecting_circuits.rst +++ b/doc/introduction/inspecting_circuits.rst @@ -136,7 +136,7 @@ During normal execution, the snapshots are ignored: @qml.qnode(dev, interface=None) def circuit(): - qml.Snapshot(measurement=qml.expval(qml.PauliZ(0)) + qml.Snapshot(measurement=qml.expval(qml.PauliZ(0))) qml.Hadamard(wires=0) qml.Snapshot("very_important_state") qml.CNOT(wires=[0, 1]) @@ -149,8 +149,8 @@ results. >>> qml.snapshots(circuit)() {0: 1.0, -'very_important_state': array([0.70710678, 0., 0.70710678, 0.]), -2: array([0.70710678, 0., 0., 0.70710678]), +'very_important_state': array([0.707+0.j, 0.+0.j, 0.707+0.j, 0.+0.j]), +2: array([0.707+0.j, 0.+0.j, 0.+0.j, 0.707+0.j]), 'execution_results': 0.0} Graph representation diff --git a/doc/releases/changelog-0.34.0.md b/doc/releases/changelog-0.34.0.md index 2222e4d6f35..3a8443d11fb 100644 --- a/doc/releases/changelog-0.34.0.md +++ b/doc/releases/changelog-0.34.0.md @@ -368,8 +368,6 @@ * The function ``qml.draw_mpl`` now accept a keyword argument ``fig`` to specify the output figure window. [(#4956)](https://github.com/PennyLaneAI/pennylane/pull/4956) -

Better support for batching

- * `qml.AmplitudeEmbedding` now supports batching when used with Tensorflow. [(#4818)](https://github.com/PennyLaneAI/pennylane/pull/4818) @@ -444,6 +442,9 @@

Other improvements

+* PennyLane now supports Python 3.12. + [(#4985)](https://github.com/PennyLaneAI/pennylane/pull/4985) + * `SampleMeasurement` now has an optional method `process_counts` for computing the measurement results from a counts dictionary. [(#4941)](https://github.com/PennyLaneAI/pennylane/pull/4941/) @@ -606,6 +607,10 @@

Bug fixes 🐛

+* `TransformDispatcher` now stops queuing when performing the transform when applying it to a qfunc. + Only the output of the transform will be queued. + [(#4983)](https://github.com/PennyLaneAI/pennylane/pull/4983) + * `qml.map_wires` now works properly with `qml.cond` and `qml.measure`. [(#4884)](https://github.com/PennyLaneAI/pennylane/pull/4884) @@ -668,6 +673,9 @@ wire order. [(#4781)](https://github.com/PennyLaneAI/pennylane/pull/4781) +* `qml.compile` will now always decompose to `expand_depth`, even if a target basis set is not specified. + [(#4800)](https://github.com/PennyLaneAI/pennylane/pull/4800) + * `qml.transforms.transpile` can now handle measurements that are broadcasted onto all wires. [(#4793)](https://github.com/PennyLaneAI/pennylane/pull/4793) @@ -704,8 +712,14 @@ [(#4951)](https://github.com/PennyLaneAI/pennylane/pull/4951) * `MPLDrawer` does not add the bonus space for classical wires when no classical wires are present. - [(#5987)](https://github.com/PennyLaneAI/pennylane/pull/4987) + [(#4987)](https://github.com/PennyLaneAI/pennylane/pull/4987) + +* `Projector` now works with parameter-broadcasting. + [(#4993)](https://github.com/PennyLaneAI/pennylane/pull/4993) +* The jax-jit interface can now be used with float32 mode. + [(#4990)](https://github.com/PennyLaneAI/pennylane/pull/4990) +

Contributors ✍️

This release contains contributions from (in alphabetical order): diff --git a/pennylane/_qubit_device.py b/pennylane/_qubit_device.py index dcb1957c439..643755dc5d3 100644 --- a/pennylane/_qubit_device.py +++ b/pennylane/_qubit_device.py @@ -52,7 +52,6 @@ VnEntropyMP, Shots, ) -from pennylane.ops.qubit.observables import BasisStateProjector from pennylane.resource import Resources from pennylane.operation import operation_derivative, Operation from pennylane.tape import QuantumTape @@ -1307,14 +1306,6 @@ def marginal_prob(self, prob, wires=None): return self._reshape(prob, flat_shape) def expval(self, observable, shot_range=None, bin_size=None): - if isinstance(observable, BasisStateProjector): - # branch specifically to handle the basis state projector observable - idx = int("".join(str(i) for i in observable.parameters[0]), 2) - probs = self.probability( - wires=observable.wires, shot_range=shot_range, bin_size=bin_size - ) - return probs[idx] - # exact expectation value if self.shots is None: try: @@ -1343,14 +1334,6 @@ def expval(self, observable, shot_range=None, bin_size=None): return np.squeeze(np.mean(samples, axis=axis)) def var(self, observable, shot_range=None, bin_size=None): - if isinstance(observable, BasisStateProjector): - # branch specifically to handle the basis state projector observable - idx = int("".join(str(i) for i in observable.parameters[0]), 2) - probs = self.probability( - wires=observable.wires, shot_range=shot_range, bin_size=bin_size - ) - return probs[idx] - probs[idx] ** 2 - # exact variance value if self.shots is None: try: diff --git a/pennylane/debugging.py b/pennylane/debugging.py index 3a7249dc9ee..feb0639ab5a 100644 --- a/pennylane/debugging.py +++ b/pennylane/debugging.py @@ -81,8 +81,8 @@ def circuit(): >>> qml.snapshots(circuit)() {0: 1.0, - 'very_important_state': array([0.70710678, 0. , 0.70710678, 0. ]), - 2: array([0.70710678, 0. , 0. , 0.70710678]), + 'very_important_state': array([0.70710678+0.j, 0. +0.j, 0.70710678+0.j, 0. +0.j]), + 2: array([0.70710678+0.j, 0. +0.j, 0. +0.j, 0.70710678+0.j]), 'execution_results': 0.0} """ diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index 4cdf73bfa0f..c4640873e36 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -138,9 +138,8 @@ def adjoint_state_measurements(tape: QuantumTape) -> (Tuple[QuantumTape], Callab params = tape.get_parameters() complex_data = [qml.math.cast(p, complex) for p in params] tape = tape.bind_new_parameters(complex_data, list(range(len(params)))) - state_tape = qml.tape.QuantumScript( - tape.operations, [qml.measurements.StateMP(wires=tape.wires)] - ) + new_mp = qml.measurements.StateMP(wires=tape.wires) + state_tape = qml.tape.QuantumScript(tape.operations, [new_mp]) return (state_tape,), partial( all_state_postprocessing, measurements=tape.measurements, wire_order=tape.wires ) diff --git a/pennylane/devices/qubit/apply_operation.py b/pennylane/devices/qubit/apply_operation.py index 69913d5b6e2..cd6e4949e86 100644 --- a/pennylane/devices/qubit/apply_operation.py +++ b/pennylane/devices/qubit/apply_operation.py @@ -365,7 +365,8 @@ def apply_snapshot(op: qml.Snapshot, state, is_state_batched: bool = False, debu if measurement: snapshot = qml.devices.qubit.measure(measurement, state) else: - snapshot = math.flatten(state) + flat_shape = (math.shape(state)[0], -1) if is_state_batched else (-1,) + snapshot = math.cast(math.reshape(state, flat_shape), complex) if op.tag: debugger.snapshots[op.tag] = snapshot else: diff --git a/pennylane/interfaces/jax_jit.py b/pennylane/interfaces/jax_jit.py index 5b064af7196..d5c2523ed62 100644 --- a/pennylane/interfaces/jax_jit.py +++ b/pennylane/interfaces/jax_jit.py @@ -46,7 +46,6 @@ from .jax import _NonPytreeWrapper -dtype = jnp.float64 Zero = jax.custom_derivatives.SymbolicZero @@ -71,18 +70,28 @@ def _set_parameters_on_copy(tapes, params): return tuple(t.bind_new_parameters(a, list(range(len(a)))) for t, a in zip(tapes, params)) +def _jax_dtype(m_type): + if m_type == int: + return jnp.int64 if jax.config.jax_enable_x64 else jnp.int32 + if m_type == float: + return jnp.float64 if jax.config.jax_enable_x64 else jnp.float32 + if m_type == complex: + return jnp.complex128 if jax.config.jax_enable_x64 else jnp.complex64 + return jnp.dtype(m_type) + + def _result_shape_dtype_struct(tape: "qml.tape.QuantumScript", device: "qml.Device"): """Auxiliary function for creating the shape and dtype object structure given a tape.""" shape = tape.shape(device) if len(tape.measurements) == 1: - tape_dtype = jnp.dtype(tape.numeric_type) + m_dtype = _jax_dtype(tape.measurements[0].numeric_type) if tape.shots.has_partitioned_shots: - return tuple(jax.ShapeDtypeStruct(s, tape_dtype) for s in shape) - return jax.ShapeDtypeStruct(tuple(shape), tape_dtype) + return tuple(jax.ShapeDtypeStruct(s, m_dtype) for s in shape) + return jax.ShapeDtypeStruct(tuple(shape), m_dtype) - tape_dtype = tuple(jnp.dtype(elem) for elem in tape.numeric_type) + tape_dtype = tuple(_jax_dtype(m.numeric_type) for m in tape.measurements) if tape.shots.has_partitioned_shots: return tuple( tuple(jax.ShapeDtypeStruct(tuple(s), d) for s, d in zip(si, tape_dtype)) for si in shape @@ -129,7 +138,7 @@ def _execute_wrapper(params, tapes, execute_fn, _, device) -> ResultBatch: def pure_callback_wrapper(p): new_tapes = _set_parameters_on_copy(tapes.vals, p) - res = tuple(execute_fn(new_tapes)) + res = tuple(_to_jax(execute_fn(new_tapes))) # When executed under `jax.vmap` the `result_shapes_dtypes` will contain # the shape without the vmap dimensions, while the function here will be # executed with objects containing the vmap dimensions. So res[i].ndim @@ -143,7 +152,7 @@ def pure_callback_wrapper(p): return jax.tree_map(lambda r, s: r.T if r.ndim > s.ndim else r, res, shape_dtype_structs) out = jax.pure_callback(pure_callback_wrapper, shape_dtype_structs, params, vectorized=True) - return _to_jax(out) + return out def _execute_and_compute_jvp(tapes, execute_fn, jpc, device, primals, tangents): @@ -165,7 +174,7 @@ def _execute_and_compute_jvp(tapes, execute_fn, jpc, device, primals, tangents): def wrapper(inner_params): new_tapes = _set_parameters_on_copy(tapes.vals, inner_params) - return jpc.execute_and_compute_jacobian(new_tapes) + return _to_jax(jpc.execute_and_compute_jacobian(new_tapes)) res_struct = tuple(_result_shape_dtype_struct(t, device) for t in tapes.vals) jac_struct = tuple(_jac_shape_dtype_struct(t, device) for t in tapes.vals) @@ -173,7 +182,7 @@ def wrapper(inner_params): jvps = _compute_jvps(jacobians, tangents_trainable, tapes.vals) - return _to_jax(results), _to_jax(jvps) + return results, jvps def _vjp_fwd(params, tapes, execute_fn, jpc, device): @@ -186,7 +195,7 @@ def _vjp_bwd(tapes, execute_fn, jpc, device, params, dy): def wrapper(inner_params, inner_dy): new_tapes = _set_parameters_on_copy(tapes.vals, inner_params) - return tuple(jpc.compute_vjp(new_tapes, inner_dy)) + return _to_jax(jpc.compute_vjp(new_tapes, inner_dy)) vjp_shape = _pytree_shape_dtype_struct(params) return (jax.pure_callback(wrapper, vjp_shape, params, dy),) diff --git a/pennylane/measurements/expval.py b/pennylane/measurements/expval.py index 3f935dc1aa1..74bf3c5b15b 100644 --- a/pennylane/measurements/expval.py +++ b/pennylane/measurements/expval.py @@ -19,7 +19,6 @@ import pennylane as qml from pennylane.operation import Operator -from pennylane.ops.qubit.observables import BasisStateProjector from pennylane.wires import Wires from .measurements import Expectation, SampleMeasurement, StateMeasurement @@ -107,15 +106,6 @@ def process_samples( shot_range: Tuple[int] = None, bin_size: int = None, ): - if isinstance(self.obs, BasisStateProjector): - # branch specifically to handle the basis state projector observable - idx = int("".join(str(i) for i in self.obs.parameters[0]), 2) - with qml.queuing.QueuingManager.stop_recording(): - probs = qml.probs(wires=self.wires).process_samples( - samples=samples, wire_order=wire_order, shot_range=shot_range, bin_size=bin_size - ) - return probs[idx] - # estimate the ev op = self.mv if self.mv is not None else self.obs with qml.queuing.QueuingManager.stop_recording(): @@ -130,15 +120,6 @@ def process_samples( return qml.math.squeeze(qml.math.mean(samples, axis=axis)) def process_state(self, state: Sequence[complex], wire_order: Wires): - if isinstance(self.obs, BasisStateProjector): - # branch specifically to handle the basis state projector observable - idx = int("".join(str(i) for i in self.obs.parameters[0]), 2) - with qml.queuing.QueuingManager.stop_recording(): - probs = qml.probs(wires=self.wires).process_state( - state=state, wire_order=wire_order - ) - return probs[idx] - # This also covers statistics for mid-circuit measurements manipulated using # arithmetic operators eigvals = qml.math.asarray(self.eigvals(), dtype="float64") diff --git a/pennylane/measurements/shots.py b/pennylane/measurements/shots.py index f5220f77e2f..448f796e732 100644 --- a/pennylane/measurements/shots.py +++ b/pennylane/measurements/shots.py @@ -114,6 +114,13 @@ class Shots: >>> shots.total_shots, shots.shot_vector (1210, (ShotCopies(10 shots x 1), ShotCopies(100 shots x 4), ShotCopies(200 shots x 4))) + Example constructing a Shots instance by multiplying an existing one by an int or float: + + >>> Shots(100) * 2 + Shots(total_shots=200, shot_vector=(ShotCopies(200 shots x 1),)) + >>> Shots([7, (100, 2)]) * 1.5 + Shots(total_shots=310, shot_vector=(ShotCopies(10 shots x 1), ShotCopies(150 shots x 2))) + One should also note that specifying a single tuple of length 2 is considered two different shot values, and *not* a tuple-pair representing shots and copies to avoid special behaviour depending on the iterable type: diff --git a/pennylane/measurements/var.py b/pennylane/measurements/var.py index 131e153ab71..798a47d6c58 100644 --- a/pennylane/measurements/var.py +++ b/pennylane/measurements/var.py @@ -20,7 +20,6 @@ import pennylane as qml from pennylane.operation import Operator -from pennylane.ops.qubit.observables import BasisStateProjector from pennylane.wires import Wires from .measurements import SampleMeasurement, StateMeasurement, Variance @@ -107,16 +106,6 @@ def process_samples( shot_range: Tuple[int] = None, bin_size: int = None, ): - if isinstance(self.obs, BasisStateProjector): - # branch specifically to handle the basis state projector observable - idx = int("".join(str(i) for i in self.obs.parameters[0]), 2) - # we use ``self.wires`` instead of ``self.obs`` because the observable was - # already applied before the sampling - probs = qml.probs(wires=self.wires).process_samples( - samples=samples, wire_order=wire_order, shot_range=shot_range, bin_size=bin_size - ) - return probs[idx] - probs[idx] ** 2 - # estimate the variance op = self.mv if self.mv is not None else self.obs with qml.queuing.QueuingManager.stop_recording(): @@ -131,17 +120,6 @@ def process_samples( return qml.math.squeeze(qml.math.var(samples, axis=axis)) def process_state(self, state: Sequence[complex], wire_order: Wires): - if isinstance(self.obs, BasisStateProjector): - # branch specifically to handle the basis state projector observable - idx = int("".join(str(i) for i in self.obs.parameters[0]), 2) - # we use ``self.wires`` instead of ``self.obs`` because the observable was - # already applied to the state - with qml.queuing.QueuingManager.stop_recording(): - probs = qml.probs(wires=self.wires).process_state( - state=state, wire_order=wire_order - ) - return probs[idx] - probs[idx] ** 2 - # This also covers statistics for mid-circuit measurements manipulated using # arithmetic operators eigvals = qml.math.asarray(self.eigvals(), dtype="float64") diff --git a/pennylane/transforms/compile.py b/pennylane/transforms/compile.py index 0635354a3a2..2e7b48ae88f 100644 --- a/pennylane/transforms/compile.py +++ b/pennylane/transforms/compile.py @@ -52,14 +52,15 @@ def compile( tape and/or quantum function transforms to apply. basis_set (list[str]): A list of basis gates. When expanding the tape, expansion will continue until gates in the specific set are - reached. If no basis set is specified, no expansion will be done. + reached. If no basis set is specified, a default of + ``pennylane.ops.__all__`` will be used. This decomposes templates and + operator arithmetic. num_passes (int): The number of times to apply the set of transforms in ``pipeline``. The default is to perform each transform once; however, doing so may produce a new circuit where applying the set of transforms again may yield further improvement, so the number of such passes can be adjusted. - expand_depth (int): When ``basis_set`` is specified, the depth to use - for tape expansion into the basis gates. + expand_depth (int): The depth to use for tape expansion into the basis gates. Returns: qnode (QNode) or quantum function (Callable) or tuple[List[QuantumTape], function]: The compiled circuit. The output type is explained in :func:`qml.transform `. @@ -181,17 +182,12 @@ def qfunc(x, y, z): # don't queue anything as a result of the expansion or transform pipeline with QueuingManager.stop_recording(): - if basis_set is not None: - expanded_tape = tape.expand( - depth=expand_depth, stop_at=lambda obj: obj.name in basis_set - ) - else: - # Expands out anything that is not a single operation (i.e., the templates) - # expand barriers when `only_visual=True` - def stop_at(obj): - return (obj.name in all_ops) and (not getattr(obj, "only_visual", False)) - - expanded_tape = tape.expand(stop_at=stop_at) + basis_set = basis_set or all_ops + + def stop_at(obj): + return obj.name in basis_set and (not getattr(obj, "only_visual", False)) + + expanded_tape = tape.expand(depth=expand_depth, stop_at=stop_at) # Apply the full set of compilation transforms num_passes times for _ in range(num_passes): diff --git a/pennylane/transforms/core/transform_dispatcher.py b/pennylane/transforms/core/transform_dispatcher.py index dba4f9578c7..3eea403e87e 100644 --- a/pennylane/transforms/core/transform_dispatcher.py +++ b/pennylane/transforms/core/transform_dispatcher.py @@ -238,7 +238,8 @@ def qfunc_transformed(*args, **kwargs): qfunc_output = qfunc(*args, **kwargs) tape = qml.tape.QuantumScript.from_queue(q) - transformed_tapes, processing_fn = self._transform(tape, *targs, **tkwargs) + with qml.QueuingManager.stop_recording(): + transformed_tapes, processing_fn = self._transform(tape, *targs, **tkwargs) if len(transformed_tapes) != 1: raise TransformError( diff --git a/setup.py b/setup.py index 5d203ee2d14..91577010f5b 100644 --- a/setup.py +++ b/setup.py @@ -85,6 +85,7 @@ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3 :: Only", "Topic :: Scientific/Engineering :: Physics", ] diff --git a/tests/devices/qubit/test_apply_operation.py b/tests/devices/qubit/test_apply_operation.py index 18929bc1bc3..9135ec0a192 100644 --- a/tests/devices/qubit/test_apply_operation.py +++ b/tests/devices/qubit/test_apply_operation.py @@ -624,6 +624,17 @@ def test_measurement(self, ml_framework): assert debugger.snapshots[0].shape == () assert debugger.snapshots[0] == qml.devices.qubit.measure(measurement, initial_state) + def test_batched_state(self, ml_framework): + """Test that batched states create batched snapshots.""" + initial_state = qml.math.asarray([[1.0, 0.0], [0.0, 0.1]], like=ml_framework) + debugger = self.Debugger() + new_state = apply_operation( + qml.Snapshot(), initial_state, is_state_batched=True, debugger=debugger + ) + assert new_state.shape == initial_state.shape + assert set(debugger.snapshots) == {0} + assert np.array_equal(debugger.snapshots[0], initial_state) + @pytest.mark.parametrize("method", methods) class TestRXCalcGrad: diff --git a/tests/devices/test_preprocess.py b/tests/devices/test_preprocess.py index 006db9f4a82..4fd018a0ee3 100644 --- a/tests/devices/test_preprocess.py +++ b/tests/devices/test_preprocess.py @@ -298,12 +298,12 @@ def test_finite_shots_with_state(self, measurements): validate_measurements(tape, lambda obj: True) -class TestExpandFnTransformations: - """Tests for the behavior of the `expand_fn` helper.""" +class TestDecomposeTransformations: + """Tests for the behavior of the `decompose` helper.""" @pytest.mark.parametrize("shots", [None, 100]) def test_decompose_expand_unsupported_op(self, shots): - """Test that expand_fn expands the tape when unsupported operators are present""" + """Test that decompose expands the tape when unsupported operators are present""" ops = [qml.Hadamard(0), NoMatOp(1), qml.RZ(0.123, wires=1)] measurements = [qml.expval(qml.PauliZ(0)), qml.probs()] tape = QuantumScript(ops=ops, measurements=measurements, shots=shots) diff --git a/tests/interfaces/default_qubit_2_integration/test_jax_jit_qnode_default_qubit_2.py b/tests/interfaces/default_qubit_2_integration/test_jax_jit_qnode_default_qubit_2.py index 0f830f4de3d..12b520db2b6 100644 --- a/tests/interfaces/default_qubit_2_integration/test_jax_jit_qnode_default_qubit_2.py +++ b/tests/interfaces/default_qubit_2_integration/test_jax_jit_qnode_default_qubit_2.py @@ -2903,3 +2903,60 @@ def circuit(a, b): else: assert np.allclose(jac[0], expected[0], atol=tol) assert np.allclose(jac[1], expected[1], atol=tol) + + +class TestSinglePrecision: + @pytest.mark.parametrize("diff_method", ("adjoint", "parameter-shift")) + def test_float32_return(self, diff_method): + """Test that jax jit works when float64 mode is disabled.""" + jax.config.update("jax_enable_x64", False) + + try: + + @jax.jit + @qml.qnode(qml.device("default.qubit"), diff_method=diff_method) + def circuit(x): + qml.RX(x, wires=0) + return qml.expval(qml.PauliZ(0)) + + grad = jax.grad(circuit)(jax.numpy.array(0.1)) + assert qml.math.allclose(grad, -np.sin(0.1)) + finally: + jax.config.update("jax_enable_x64", True) + + @pytest.mark.parametrize("diff_method", ("adjoint", "finite-diff")) + def test_complex64_return(self, diff_method): + """Test that jax jit works with differentiating the state.""" + jax.config.update("jax_enable_x64", False) + + try: + tol = 2e-2 if diff_method == "finite-diff" else 1e-6 + + @jax.jit + @qml.qnode(qml.device("default.qubit", wires=1), diff_method=diff_method) + def circuit(x): + qml.RX(x, wires=0) + return qml.state() + + j = jax.jacobian(circuit, holomorphic=True)(jax.numpy.array(0.1 + 0j)) + assert qml.math.allclose(j, [-np.sin(0.05) / 2, -np.cos(0.05) / 2 * 1j], atol=tol) + + finally: + jax.config.update("jax_enable_x64", True) + + def test_int32_return(self): + """Test that jax jit forward execution works with samples and int32""" + + jax.config.update("jax_enable_x64", False) + + try: + + @jax.jit + @qml.qnode(qml.device("default.qubit", shots=10), diff_method=qml.gradients.param_shift) + def circuit(x): + qml.RX(x, wires=0) + return qml.sample(wires=0) + + _ = circuit(jax.numpy.array(0.1)) + finally: + jax.config.update("jax_enable_x64", True) diff --git a/tests/interfaces/test_tensorflow_qnode.py b/tests/interfaces/test_tensorflow_qnode.py index da68bbf7f2f..6bcf28139dd 100644 --- a/tests/interfaces/test_tensorflow_qnode.py +++ b/tests/interfaces/test_tensorflow_qnode.py @@ -527,15 +527,14 @@ def circuit(weights): # execute with shots=100 circuit(weights, shots=100) # pylint: disable=unexpected-keyword-arg - spy.assert_called() + spy.assert_called_once() assert spy.spy_return.shape == (100,) # device state has been unaffected assert dev.shots is None - spy = mocker.spy(dev, "sample") res = circuit(weights) assert np.allclose(res, -np.cos(a) * np.sin(b), atol=tol, rtol=0) - spy.assert_not_called() + spy.assert_called_once() @pytest.mark.xfail(reason="TODO: shot-vector support for param shift") def test_gradient_integration(self, interface): diff --git a/tests/ops/op_math/test_composite.py b/tests/ops/op_math/test_composite.py index 7ffd93e5d9d..0b07a82db27 100644 --- a/tests/ops/op_math/test_composite.py +++ b/tests/ops/op_math/test_composite.py @@ -72,9 +72,7 @@ class TestConstruction: def test_direct_initialization_fails(self): """Test directly initializing a CompositeOp fails""" - with pytest.raises( - TypeError, match="Can't instantiate abstract class CompositeOp with abstract methods" - ): + with pytest.raises(TypeError, match="Can't instantiate abstract class CompositeOp"): _ = CompositeOp(*self.simple_operands) # pylint:disable=abstract-class-instantiated def test_raise_error_fewer_than_2_operands(self): diff --git a/tests/ops/qubit/test_observables.py b/tests/ops/qubit/test_observables.py index 525567a91d0..6b5d2a6f9c5 100644 --- a/tests/ops/qubit/test_observables.py +++ b/tests/ops/qubit/test_observables.py @@ -637,6 +637,19 @@ def test_matrix_representation(self, basis_state, expected, n_wires, tol): assert np.allclose(res_dynamic, expected, atol=tol) assert np.allclose(res_static, expected, atol=tol) + @pytest.mark.parametrize( + "dev", (qml.device("default.qubit"), qml.device("default.qubit.legacy", wires=1)) + ) + def test_integration_batched_state(self, dev): + @qml.qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + return qml.expval(qml.Projector([0], wires=0)) + + x = np.array([0.4, 0.8, 1.2]) + res = circuit(x) + assert qml.math.allclose(res, np.cos(x / 2) ** 2) + class TestStateVectorProjector: """Tests for state vector projector observable.""" diff --git a/tests/transforms/test_compile.py b/tests/transforms/test_compile.py index 31c54a277be..b22e3e56283 100644 --- a/tests/transforms/test_compile.py +++ b/tests/transforms/test_compile.py @@ -302,9 +302,12 @@ def test_compile_template(self): # Push commuting gates to the right and merging rotations gives a circuit # with alternating RX and CNOT gates + # pylint: disable=expression-not-assigned def qfunc(x, params): qml.templates.AngleEmbedding(x, wires=range(3)) - qml.templates.BasicEntanglerLayers(params, wires=range(3)) + qml.adjoint( + qml.adjoint(qml.templates.BasicEntanglerLayers(params, wires=range(3))) + ) ** 2 return qml.expval(qml.PauliZ(wires=2)) dev = qml.device("default.qubit", wires=3) @@ -321,7 +324,7 @@ def qfunc(x, params): transformed_result = transformed_qnode(x, params) assert np.allclose(original_result, transformed_result) - names_expected = ["RX", "CNOT"] * 6 + names_expected = ["RX", "CNOT"] * 12 wires_expected = [ Wires(0), Wires([0, 1]), @@ -329,7 +332,7 @@ def qfunc(x, params): Wires([1, 2]), Wires(2), Wires([2, 0]), - ] * 2 + ] * 4 compare_operation_lists(transformed_qnode.qtape.operations, names_expected, wires_expected) diff --git a/tests/transforms/test_experimental/test_transform_dispatcher.py b/tests/transforms/test_experimental/test_transform_dispatcher.py index ac5fa93b33a..8e77553e069 100644 --- a/tests/transforms/test_experimental/test_transform_dispatcher.py +++ b/tests/transforms/test_experimental/test_transform_dispatcher.py @@ -100,6 +100,7 @@ def first_valid_transform( """A valid transform.""" tape = tape.copy() tape._ops.pop(index) # pylint:disable=protected-access + _ = (qml.PauliX(0), qml.S(0)) return [tape], lambda x: x