Add a GPU implementation of lax.linalg.eig.
This feature has been in the queue for a long time (see #1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (#24255; this should be investigated separately).

This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, it is implemented by calling LAPACK directly on the host, which has good performance for small to moderately sized problems (up to roughly 2048x2048). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I haven't done a huge amount of benchmarking yet, but 2048 was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)
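As a usage sketch (assuming a CUDA- or ROCm-backed jaxlib built from this change; the `use_magma` argument is the new per-call override added here):

```python
import jax
from jax import lax

# A batch of general (nonsymmetric) matrices, placed on the default device.
a = jax.random.normal(jax.random.key(0), (8, 512, 512))

# Default GPU path: the new hybrid kernel calls LAPACK geev on the host.
w, vl, vr = lax.linalg.eig(a)

# Per-call override: use MAGMA for this call only (requires a MAGMA install).
w_m, vl_m, vr_m = lax.linalg.eig(a, use_magma=True)
```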

We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to `dlopen` `libmagma.so`, but the path to a non-standard installation can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.
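For example (a sketch; the `libmagma.so` path below is just an illustration of a non-standard install location):

```python
import os

# Optional: tell JAX where to find MAGMA before it tries to dlopen the library.
os.environ["JAX_GPU_MAGMA_PATH"] = "/opt/magma/lib/libmagma.so"

import jax

# Opt in to the MAGMA-backed eig ("off" and "auto" are the other accepted values).
jax.config.update("jax_gpu_use_magma", "on")
```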

For reasons that I don't yet totally understand, the MAGMA implementation segfaults deep in the MAGMA internals for complex128 inputs, so I've disabled that configuration for now.

PiperOrigin-RevId: 691072237
dfm authored and Google-ML-Automation committed Nov 7, 2024
1 parent 1a544b6 commit 1dd1a7b
Showing 17 changed files with 1,094 additions and 57 deletions.
11 changes: 11 additions & 0 deletions jax/_src/config.py
@@ -1953,3 +1953,14 @@ def _update_garbage_collection_guard(state, key, val):
),
include_in_jit_key=True,
)

gpu_use_magma = enum_state(
name='jax_gpu_use_magma',
enum_values=['off', 'on', 'auto'],
default='off',
help=(
'Enable experimental support for MAGMA-backed lax.linalg.eig on GPU. '
'See the documentation for lax.linalg.eig for more details about how '
'to use this feature.'
),
)
122 changes: 113 additions & 9 deletions jax/_src/lax/linalg.py
@@ -121,10 +121,29 @@ def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array:


def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True,
compute_right_eigenvectors: bool = True) -> list[Array]:
compute_right_eigenvectors: bool = True,
use_magma: bool | None = None) -> list[Array]:
"""Eigendecomposition of a general matrix.
Nonsymmetric eigendecomposition is at present only implemented on CPU.
Nonsymmetric eigendecomposition is only implemented on CPU and GPU. On GPU,
the default implementation calls LAPACK directly on the host CPU, but an
experimental GPU implementation using `MAGMA <https://icl.utk.edu/magma/>`_
is also available. The MAGMA implementation is typically slower than the
equivalent LAPACK implementation for small matrices (less than about 2048),
but it may perform better for larger matrices.

To enable the MAGMA implementation, you must install MAGMA yourself (there
are Debian and conda-forge packages, or you can build from source). Then set
the `jax_gpu_use_magma` configuration variable to `"on"`:

.. code-block:: python

    jax.config.update('jax_gpu_use_magma', 'on')

JAX will try to ``dlopen`` the installed MAGMA shared library, raising an
error if it is not found. To explicitly specify the path to the MAGMA
library, set the environment variable `JAX_GPU_MAGMA_PATH` to the full
installation path.

Args:
x: A batch of square matrices with shape ``[..., n, n]``.
@@ -142,7 +161,8 @@ def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True,
for that batch element.
"""
return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)
compute_right_eigenvectors=compute_right_eigenvectors,
use_magma=use_magma)


def eigh(
@@ -678,12 +698,14 @@ def _symmetric_product_jax_fn(a, c, *, alpha, beta):

# Asymmetric eigendecomposition

def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors):
def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors,
use_magma):
return dispatch.apply_primitive(
eig_p,
operand,
compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors,
use_magma=use_magma,
)

def eig_lower(*args, **kw):
@@ -692,7 +714,8 @@ def eig_lower(*args, **kw):
"If your matrix is symmetric or Hermitian, you should use eigh instead.")

def eig_abstract_eval(operand, *, compute_left_eigenvectors,
compute_right_eigenvectors):
compute_right_eigenvectors, use_magma):
del use_magma # unused
if isinstance(operand, ShapedArray):
if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
raise ValueError("Argument to nonsymmetric eigendecomposition must have "
@@ -716,7 +739,8 @@ def eig_abstract_eval(operand, *, compute_left_eigenvectors,
return tuple(output)

def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
compute_right_eigenvectors):
compute_right_eigenvectors, use_magma):
del use_magma # unused
operand_aval, = ctx.avals_in
out_aval = ctx.avals_out[0]
batch_dims = operand_aval.shape[:-2]
@@ -763,18 +787,94 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
return output


def _eig_gpu_impl(target_name_prefix, x, *, compute_left_eigenvectors,
compute_right_eigenvectors, use_magma):
gpu_solver.initialize_hybrid_kernels()
dtype = x.dtype
is_real = dtype == np.float32 or dtype == np.float64
if is_real:
target_name = f"{target_name_prefix}hybrid_eig_real"
complex_dtype = np.complex64 if dtype == np.float32 else np.complex128
else:
target_name = f"{target_name_prefix}hybrid_eig_comp"
assert dtype == np.complex64 or dtype == np.complex128
complex_dtype = dtype

batch_dims = x.shape[:-2]
n, m = x.shape[-2:]
assert n == m
num_batch_dims = len(batch_dims)

layout = tuple(range(num_batch_dims)) + (num_batch_dims + 1, num_batch_dims)
out_types = [
api.ShapeDtypeStruct(batch_dims + (n,), dtype),
api.ShapeDtypeStruct(batch_dims + (n, n), complex_dtype),
api.ShapeDtypeStruct(batch_dims + (n, n), complex_dtype),
api.ShapeDtypeStruct(batch_dims, np.int32),
]
out_layouts = [None, layout, layout, None]
if is_real:
out_types = [api.ShapeDtypeStruct(batch_dims + (n,), dtype)] + out_types
out_layouts = [None] + out_layouts

magma = config.gpu_use_magma.value
if use_magma is not None:
magma = "on" if use_magma else "off"
fun = ffi.ffi_call(target_name, out_types, input_layouts=[layout],
output_layouts=out_layouts)
*w, vl, vr, info = fun(x, magma=magma, left=compute_left_eigenvectors,
right=compute_right_eigenvectors)
if is_real:
assert len(w) == 2
w = lax.complex(*w)
else:
assert len(w) == 1
w = w[0]
ok = lax.eq(info, lax.zeros_like_array(info))
ok = _broadcast_to(ok[..., None], w.shape)
w = lax.select(ok, w, lax.full_like(w, np.nan + np.nan * 1j))
ok = _broadcast_to(ok[..., None], x.shape)
output = [w]
if compute_left_eigenvectors:
vl = lax.select(ok, vl, lax.full_like(vl, np.nan + np.nan * 1j))
output.append(vl)
if compute_right_eigenvectors:
vr = lax.select(ok, vr, lax.full_like(vr, np.nan + np.nan * 1j))
output.append(vr)
return output
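
To make the post-processing at the end of `_eig_gpu_impl` concrete, here is a small standalone sketch (with made-up kernel outputs) of how the real-dtype eigenvalue halves are packed into complex values and how batch elements with a nonzero `info` are masked with NaNs:

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

# Hypothetical raw outputs for a real-dtype batch of two 3x3 problems: separate
# real/imaginary eigenvalue parts plus a per-batch-element info code.
wr = jnp.array([[1.0, 2.0, 3.0], [0.5, 0.5, -1.0]])
wi = jnp.array([[0.0, 4.0, -4.0], [0.0, 0.0, 0.0]])
info = jnp.array([0, 1], dtype=jnp.int32)  # pretend the second problem failed

w = lax.complex(wr, wi)                    # combine into complex eigenvalues
ok = lax.eq(info, lax.full_like(info, 0))  # shape (2,)
ok = jnp.broadcast_to(ok[..., None], w.shape)
w = lax.select(ok, w, lax.full_like(w, np.nan + np.nan * 1j))
```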


def _eig_gpu_lowering(target_name_prefix, ctx, operand, *,
compute_left_eigenvectors, compute_right_eigenvectors,
use_magma):
if ctx.is_forward_compat():
raise NotImplementedError(
"Export of nonsymmetric eigendecomposition on GPU is not supported "
"because of forward compatibility. The "
"'jax_export_ignore_forward_compatibility' configuration option can be "
"used to disable this check.")
rule = mlir.lower_fun(partial(
_eig_gpu_impl, target_name_prefix,
compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors,
use_magma=use_magma), multiple_results=True)
return rule(ctx, operand)


def eig_batching_rule(batched_args, batch_dims, *, compute_left_eigenvectors,
compute_right_eigenvectors):
compute_right_eigenvectors, use_magma):
x, = batched_args
bd, = batch_dims
x = batching.moveaxis(x, bd, 0)

return (eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors),
compute_right_eigenvectors=compute_right_eigenvectors,
use_magma=use_magma),
(0,) * (1 + compute_left_eigenvectors + compute_right_eigenvectors))
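
Because the batching rule just moves the mapped axis to the front and re-binds the primitive with the same flags, `jax.vmap` composes with `eig`; a minimal sketch (assuming a backend where `eig` is implemented, i.e. CPU or, with this change, GPU):

```python
import jax
from jax import lax

mats = jax.random.normal(jax.random.key(1), (5, 16, 16))

# vmap over the leading axis; each output picks up the corresponding batch axis.
w, vl, vr = jax.vmap(lambda m: lax.linalg.eig(m))(mats)
print(w.shape)   # (5, 16)
print(vr.shape)  # (5, 16, 16)
```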

def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors,
compute_right_eigenvectors):
compute_right_eigenvectors, use_magma):
del use_magma # unused
if compute_left_eigenvectors or compute_right_eigenvectors:
raise NotImplementedError(
'The derivatives of eigenvectors are not implemented, only '
@@ -793,6 +893,10 @@ def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors,
eig_p.def_abstract_eval(eig_abstract_eval)
mlir.register_lowering(eig_p, eig_lower)
mlir.register_lowering(eig_p, _eig_cpu_lowering, platform='cpu')
mlir.register_lowering(eig_p, partial(_eig_gpu_lowering, 'cu'),
platform='cuda')
mlir.register_lowering(eig_p, partial(_eig_gpu_lowering, 'hip'),
platform='rocm')
batching.primitive_batchers[eig_p] = eig_batching_rule
ad.primitive_jvps[eig_p] = eig_jvp_rule
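
Since the JVP rule covers eigenvalues only, forward-mode differentiation works as long as eigenvectors are not requested; a minimal sketch:

```python
import jax
import jax.numpy as jnp
from jax import lax

def eigvals_only(a):
  # Request eigenvalues only so that the JVP rule applies.
  w, = lax.linalg.eig(a, compute_left_eigenvectors=False,
                      compute_right_eigenvectors=False)
  return w

a = jnp.array([[1.0, 2.0], [3.0, 4.0]])
da = jnp.eye(2)

# Forward-mode derivative of the eigenvalues along the direction `da`.
w, dw = jax.jvp(eigvals_only, (a,), (da,))
```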

4 changes: 3 additions & 1 deletion jax/_src/numpy/linalg.py
@@ -731,7 +731,9 @@ def eig(a: ArrayLike) -> tuple[Array, Array]:
- This differs from :func:`numpy.linalg.eig` in that the return type of
:func:`jax.numpy.linalg.eig` is always complex64 for 32-bit input, and complex128
for 64-bit input.
- At present, non-symmetric eigendecomposition is only implemented on the CPU backend.
- At present, non-symmetric eigendecomposition is only implemented on the CPU and
GPU backends. For more details about the GPU implementation, see the
documentation for :func:`jax.lax.linalg.eig`.
See also:
- :func:`jax.numpy.linalg.eigh`: eigenvectors and eigenvalues of a Hermitian matrix.
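
For the numpy-style API documented here, a minimal usage sketch (assuming a GPU backend):

```python
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.key(0), (64, 64))
# With this change, jnp.linalg.eig also works on CUDA and ROCm backends.
w, v = jnp.linalg.eig(a)
print(w.dtype)  # complex64 for 32-bit input
```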
28 changes: 0 additions & 28 deletions jaxlib/cpu/lapack_kernels.cc
@@ -1094,34 +1094,6 @@ template struct EigenvalueDecompositionSymmetric<ffi::DataType::F64>;
template struct EigenvalueDecompositionHermitian<ffi::DataType::C64>;
template struct EigenvalueDecompositionHermitian<ffi::DataType::C128>;

// LAPACK uses a packed representation to represent a mixture of real
// eigenvectors and complex conjugate pairs. This helper unpacks the
// representation into regular complex matrices.
template <typename T>
static void UnpackEigenvectors(lapack_int n, const T* eigenvals_imag,
const T* packed, std::complex<T>* unpacked) {
for (int j = 0; j < n;) {
if (eigenvals_imag[j] == 0. || std::isnan(eigenvals_imag[j])) {
// Real values in each row without imaginary part
// Second row of the imaginary part is not provided
for (int i = 0; i < n; ++i) {
unpacked[j * n + i] = {packed[j * n + i], 0.};
}
++j;
} else {
// Complex values where the real part is in the jth row
// and the imaginary part is in the next row (j + 1)
for (int i = 0; i < n; ++i) {
const T real_part = packed[j * n + i];
const T imag_part = packed[(j + 1) * n + i];
unpacked[j * n + i] = {real_part, imag_part};
unpacked[(j + 1) * n + i] = {real_part, -imag_part};
}
j += 2;
}
}
}

// lapack geev

template <typename T>
29 changes: 29 additions & 0 deletions jaxlib/cpu/lapack_kernels.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef JAXLIB_CPU_LAPACK_KERNELS_H_
#define JAXLIB_CPU_LAPACK_KERNELS_H_

#include <complex>
#include <cstdint>
#include <optional>
#include <type_traits>
@@ -462,6 +463,34 @@ struct EigenvalueDecompositionHermitian {

// lapack geev

// LAPACK uses a packed representation to represent a mixture of real
// eigenvectors and complex conjugate pairs. This helper unpacks the
// representation into regular complex matrices.
template <typename T, typename Int=lapack_int>
static void UnpackEigenvectors(Int n, const T* eigenvals_imag,
const T* packed, std::complex<T>* unpacked) {
for (int j = 0; j < n;) {
if (eigenvals_imag[j] == 0. || std::isnan(eigenvals_imag[j])) {
// Real values in each row without imaginary part
// Second row of the imaginary part is not provided
for (int i = 0; i < n; ++i) {
unpacked[j * n + i] = {packed[j * n + i], 0.};
}
++j;
} else {
// Complex values where the real part is in the jth row
// and the imaginary part is in the next row (j + 1)
for (int i = 0; i < n; ++i) {
const T real_part = packed[j * n + i];
const T imag_part = packed[(j + 1) * n + i];
unpacked[j * n + i] = {real_part, imag_part};
unpacked[(j + 1) * n + i] = {real_part, -imag_part};
}
j += 2;
}
}
}
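
For reference, the same unpacking logic expressed as a small NumPy sketch (a direct mirror of the C++ helper above, not part of the change itself):

```python
import numpy as np

def unpack_eigenvectors(eigvals_imag, packed):
  """NumPy mirror of UnpackEigenvectors: `packed` is the (n, n) real geev
  output in LAPACK's packed layout; returns an (n, n) complex array."""
  n = packed.shape[0]
  unpacked = np.zeros((n, n), dtype=complex)
  j = 0
  while j < n:
    if eigvals_imag[j] == 0.0 or np.isnan(eigvals_imag[j]):
      # Real eigenvalue: row j is already a real eigenvector.
      unpacked[j] = packed[j]
      j += 1
    else:
      # Conjugate pair: row j holds the real part, row j + 1 the imaginary
      # part; the pair's eigenvectors are complex conjugates of each other.
      unpacked[j] = packed[j] + 1j * packed[j + 1]
      unpacked[j + 1] = packed[j] - 1j * packed[j + 1]
      j += 2
  return unpacked
```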

template <typename T>
struct RealGeev {
using FnType = void(char* jobvl, char* jobvr, lapack_int* n, T* a,
50 changes: 50 additions & 0 deletions jaxlib/cuda/BUILD
@@ -476,6 +476,55 @@ pybind_extension(
],
)

cc_library(
name = "cuda_hybrid_kernels",
srcs = ["//jaxlib/gpu:hybrid_kernels.cc"],
hdrs = ["//jaxlib/gpu:hybrid_kernels.h"],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib:ffi_helpers",
"//jaxlib/cpu:lapack_kernels",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@xla//xla/ffi/api:ffi",
],
)

pybind_extension(
name = "_hybrid",
srcs = ["//jaxlib/gpu:hybrid.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
linkopts = select({
"@xla//xla/python:use_jax_cuda_pip_rpaths": [
"-Wl,-rpath,$$ORIGIN/../../nvidia/cuda_runtime/lib",
],
"//conditions:default": [],
}),
module_name = "_hybrid",
deps = [
":cuda_gpu_kernel_helpers",
":cuda_hybrid_kernels",
":cuda_vendor",
"//jaxlib:kernel_nanobind_helpers",
"//jaxlib/cpu:lapack_kernels",
"@local_config_cuda//cuda:cuda_headers",
"@nanobind",
"@xla//xla/ffi/api:ffi",
"@xla//xla/tsl/cuda:cudart",
],
)

cc_library(
name = "cuda_gpu_kernels",
srcs = ["//jaxlib/gpu:gpu_kernels.cc"],
@@ -633,6 +682,7 @@ py_library(
name = "cuda_gpu_support",
deps = [
":_blas",
":_hybrid",
":_linalg",
":_prng",
":_rnn",
3 changes: 3 additions & 0 deletions jaxlib/gpu/BUILD
@@ -37,6 +37,9 @@ exports_files(srcs = [
"gpu_kernel_helpers.cc",
"gpu_kernel_helpers.h",
"gpu_kernels.cc",
"hybrid.cc",
"hybrid_kernels.cc",
"hybrid_kernels.h",
"linalg.cc",
"linalg_kernels.cc",
"linalg_kernels.cu.cc",