Skip to content

Commit

Permalink
Introduce sve function for matrix multiplication (#3348)
Browse files Browse the repository at this point in the history
Summary:

Add inlined SVE assembly routine, used to multiply matrixes with one column within FBgemmFP16 function.

The function is compatible with CPUs having any register resize.
By design, it will always load memory in 128-bits chunks.
Only SVE is required, no need for SVE2.

Some benchmarks:

case m=1:
ARM baseline: 9.1 Gflops
ARM SVE: 17.5 Gflops
=> Over 92% performance uplift

case m=2:
ARM baseline: 16.4 Gflops
ARM SVE: 29.2 Gflops
=> Over 78% performance uplift

case m=3:
ARM baseline: 21 Gflops
ARM SVE: 40 Gflops
=> Over 90% performance uplift

case m=4:
ARM baseline: 24.8 Gflops
ARM SVE: 47.5 Gflops
=> Over 91% performance uplift

case m=5:
ARM baseline: 27.3 Gflops
ARM SVE: 52.5 Gflops
=> Over 92% performance uplift

case m=6:
ARM baseline: 29.3 Gflops
ARM SVE: 56 Gflops
=> Over 91% performance uplift

When the benchmark is let run, for m>=100 the optimized version converges at 58GFlops

Reviewed By: jianyuh

Differential Revision: D65720003
  • Loading branch information
Nicoshev authored and facebook-github-bot committed Nov 16, 2024
1 parent 8812a95 commit 9eeccde
Show file tree
Hide file tree
Showing 10 changed files with 2,665 additions and 8 deletions.
13 changes: 13 additions & 0 deletions defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,19 @@ def get_fbgemm_inline_avx512_srcs(msvc = False, buck = False):
})
return asm_srcs if not msvc else intrinsics_srcs

def get_fbgemm_inline_sve_srcs(msvc = False, buck = False):
intrinsics_srcs = ["src/FbgemmFP16UKernelsSve128.cc"]

#FP16 kernels contain inline assembly and inline assembly syntax for MSVC is different.
asm_srcs = ["src/FbgemmFP16UKernelsSve128.cc"]
if buck:
return select({
"DEFAULT": asm_srcs,
"ovr_config//compiler:cl": intrinsics_srcs,
"ovr_config//cpu:arm64": intrinsics_srcs,
})
return asm_srcs if not msvc else intrinsics_srcs

def get_fbgemm_autovec_srcs():
return [
"src/EmbeddingSpMDMAutovec.cc",
Expand Down
23 changes: 18 additions & 5 deletions include/fbgemm/FbgemmFPCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ namespace fbgemm {
using partition_array_t = std::array<std::array<std::array<int, 2>, 2>, 121>;
extern partition_array_t partition_avx2;
extern partition_array_t partition_avx512;
extern partition_array_t partition_sve128;

template <typename T>
struct GemmParams {
Expand Down Expand Up @@ -118,7 +119,7 @@ void cblas_gemm_compute(
? simd_info<inst_set_t::avx512>::WIDTH_32BIT_ELEMS
: simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
#else
simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
simd_info<inst_set_t::sve>::WIDTH_32BIT_ELEMS;
(void)kernel_ncol_blocks;
(void)kernels;
#endif
Expand Down Expand Up @@ -182,7 +183,11 @@ void cblas_gemm_compute(
gp.b_block_cols = jb_end - jb_begin;
if (gp.b_block_cols) {
#ifdef FBGEMM_USE_REF_KERNEL
ref_kernel<T>(kernel_nrows, &gp, C, m, n, simd_width);
if constexpr (std::is_same<T, float16>::value) {
kernels[kernel_nrows](&gp);
} else {
ref_kernel<T>(kernel_nrows, &gp, C, m, n, simd_width);
}
#else
kernels[kernel_nrows](&gp);
#endif
Expand All @@ -198,7 +203,11 @@ void cblas_gemm_compute(
gp.b_block_cols = jb_end - jb_begin;
if (gp.b_block_cols) {
#ifdef FBGEMM_USE_REF_KERNEL
ref_kernel(kernel_nrows, &gp, C, m, n, simd_width);
if constexpr (std::is_same<T, float16>::value) {
kernels[kernel_nrows](&gp);
} else {
ref_kernel(kernel_nrows, &gp, C, m, n, simd_width);
}
#else
kernels[kernel_nrows](&gp);
#endif
Expand All @@ -225,8 +234,12 @@ void cblas_gemm_compute(
gp.ldc = Bp.blockColSize() * sizeof(C[0]);
gp.b_block_cols = 1;
#ifdef FBGEMM_USE_REF_KERNEL
ref_kernel<T>(
kernel_nrows, &gp, c_tmp.data(), 14, 32, simd_width);
if constexpr (std::is_same<T, float16>::value) {
kernels[kernel_nrows](&gp);
} else {
ref_kernel<T>(
kernel_nrows, &gp, c_tmp.data(), 14, 32, simd_width);
}
#else
kernels[kernel_nrows](&gp);
#endif
Expand Down
9 changes: 9 additions & 0 deletions include/fbgemm/SimdUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ struct simd_info<inst_set_t::avx2> {
using vec_reg_t = asmjit::x86::Ymm;
};

template <>
struct simd_info<inst_set_t::sve> {
// Implementation is unrolled to match params used on avx2
static constexpr int WIDTH_BITS = 256;
static constexpr int WIDTH_BYTES = 32;
static constexpr int WIDTH_32BIT_ELEMS = 8;
static constexpr int NUM_VEC_REGS = 32;
};

template <>
struct simd_info<inst_set_t::avx512> {
static constexpr int WIDTH_BITS = 512;
Expand Down
8 changes: 7 additions & 1 deletion include/fbgemm/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ enum class inst_set_t {
avx512,
avx512_ymm,
avx512_vnni,
avx512_vnni_ymm
avx512_vnni_ymm,
sve
};

/**
Expand Down Expand Up @@ -147,6 +148,11 @@ FBGEMM_API bool fbgemmHasAvx512VnniSupport();
*/
FBGEMM_API bool fbgemmHasArmNeonSupport();

/**
* @brief Are we running on a ARM SVE supported cpu?
*/
FBGEMM_API bool fbgemmHasArmSveSupport();

/**
* @brief Are we running on a ARM SVE2 supported cpu?
*/
Expand Down
24 changes: 24 additions & 0 deletions src/FbgemmFP16.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "./FbgemmFP16UKernelsAvx2.h"
#include "./FbgemmFP16UKernelsAvx512.h"
#include "./FbgemmFP16UKernelsAvx512_256.h"
#include "./FbgemmFP16UKernelsSve128.h"
#include "fbgemm/Fbgemm.h"
#include "fbgemm/FbgemmFPCommon.h"

Expand All @@ -33,6 +34,25 @@ constexpr kernel_array_t<float16> kernel_fp16_avx2 = {
gemmkernel_5x2_Avx2_fp16_fA0fB0fC0,
gemmkernel_6x2_Avx2_fp16_fA0fB0fC0};

constexpr kernel_array_t<float16> kernel_fp16_sve128 = {
nullptr,
#ifdef __aarch64__
gemmkernel_1x2_Sve128_fp16_fA0fB0fC0,
gemmkernel_2x2_Sve128_fp16_fA0fB0fC0,
gemmkernel_3x2_Sve128_fp16_fA0fB0fC0,
gemmkernel_4x2_Sve128_fp16_fA0fB0fC0,
gemmkernel_5x2_Sve128_fp16_fA0fB0fC0,
gemmkernel_6x2_Sve128_fp16_fA0fB0fC0,
#else
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
#endif
};

constexpr kernel_array_t<float16> kernel_fp16_avx512_256 = {
nullptr,
gemmkernel_1x2_Avx2_fp16_fA0fB0fC0,
Expand Down Expand Up @@ -82,8 +102,12 @@ const isa_descriptor<float16>& getIsaHandlers(inst_set_t isa, float16) {
std::make_tuple(kernel_fp16_avx512, partition_avx512);
static isa_descriptor<float16> avx512_256_descriptor =
std::make_tuple(kernel_fp16_avx512_256, partition_avx512);
static isa_descriptor<float16> sve128_descriptor =
std::make_tuple(kernel_fp16_sve128, partition_sve128);

switch (isa) {
case inst_set_t::sve:
return sve128_descriptor;
case inst_set_t::anyarch:
case inst_set_t::avx2:
return avx2_descriptor;
Expand Down
Loading

0 comments on commit 9eeccde

Please sign in to comment.