Block-wise FP8 matmul (pytorch#2780)
Summary:
Pull Request resolved: pytorch#2780

Introduce a CUTLASS-based matmul for block-scaled fp8 tensors.

This is based on the regular ("slow"-accum) fp8 matmul in CUTLASS, with its fp8 accumulator class changed to perform a fused multiply-add into the global accumulator instead of a plain add. This required changes throughout the stack, which is why I ended up copying sizeable chunks of CUTLASS into this diff.

Reviewed By: ipiszy, jiawenliu64

Differential Revision: D57965065

fbshipit-source-id: 0b92b2ac1b3c687f23e820ea05255149dac686dc
lw authored and facebook-github-bot committed Jun 27, 2024
1 parent 24140d5 commit 5a5b0e6
Showing 4 changed files with 1,882 additions and 30 deletions.
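Before the diffs, here is a minimal host-side reference of the blockwise-scaled GEMM semantics that the new kernels appear to implement, as read from the shape and stride checks in the diffs below; the plain-float inputs and the helper name blockwise_gemm_ref are illustrative only, not part of the commit.

#include <cstdint>
#include <vector>

// Assumed semantics: Y[m][n] accumulates XQ[m][k] * WQ[n][k] over k, with each
// product scaled by the fp32 scale of the (block_m x block_k) tile of XQ and
// the (block_n x block_k) tile of WQ that the element falls in. Scales are
// stored row-major with ceil(K / block_k) entries per row, matching the
// TORCH_CHECKs on x_scale / w_scale further down.
std::vector<float> blockwise_gemm_ref(
    const std::vector<float>& XQ,      // M x K, fp8 values widened to float
    const std::vector<float>& WQ,      // N x K, fp8 values widened to float
    const std::vector<float>& x_scale, // ceil(M/block_m) x ceil(K/block_k)
    const std::vector<float>& w_scale, // ceil(N/block_n) x ceil(K/block_k)
    int64_t M, int64_t N, int64_t K,
    int64_t block_m, int64_t block_n, int64_t block_k) {
  const int64_t k_blocks = (K + block_k - 1) / block_k;
  std::vector<float> Y(M * N, 0.0f);
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int64_t k = 0; k < K; ++k) {
        const float sa = x_scale[(m / block_m) * k_blocks + k / block_k];
        const float sb = w_scale[(n / block_n) * k_blocks + k / block_k];
        acc += sa * sb * XQ[m * K + k] * WQ[n * K + k];
      }
      Y[m * N + n] = acc; // the real kernels store this as bf16
    }
  }
  return Y;
}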
@@ -56,7 +56,10 @@ at::Tensor f8f8bf16_blockwise_impl(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
-    at::Tensor w_scale) {
+    at::Tensor w_scale,
+    int64_t block_m,
+    int64_t block_n,
+    int64_t block_k) {
// Check that inputs are valid.
TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous());
TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());
@@ -74,13 +77,13 @@ at::Tensor f8f8bf16_blockwise_impl(
int StrideE = N;

// For now hardcode block size.
-  int ScaleBlockM = 256;
-  int ScaleBlockN = 256;
-  int ScaleBlockK = 256;
+  TORCH_CHECK(block_m == 256);
+  TORCH_CHECK(block_n == 256);
+  TORCH_CHECK(block_k == 256);

-  int ScaleStrideAM = K / ScaleBlockK;
+  int ScaleStrideAM = K / block_k;
int ScaleStrideAK = 1;
-  int ScaleStrideBN = K / ScaleBlockK;
+  int ScaleStrideBN = K / block_k;
int ScaleStrideBK = 1;

using ADataType = ck::f8_t;
@@ -186,9 +189,9 @@ at::Tensor f8f8bf16_blockwise_impl(
StrideE,
reinterpret_cast<AScaleDataType*>(x_scale.data_ptr()),
reinterpret_cast<BScaleDataType*>(w_scale.data_ptr()),
-      ScaleBlockM,
-      ScaleBlockN,
-      ScaleBlockK,
+      block_m,
+      block_n,
+      block_k,
ScaleStrideAM,
ScaleStrideAK,
ScaleStrideBN,
@@ -239,7 +242,10 @@ at::Tensor f8f8bf16_blockwise(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
-    at::Tensor w_scale) {
+    at::Tensor w_scale,
+    int64_t block_m,
+    int64_t block_n,
+    int64_t block_k) {
// Check that input datatypes are valid.
TORCH_CHECK(
(XQ.dtype() == at::kFloat8_e4m3fnuz) &&
@@ -252,24 +258,24 @@
if (pad) {
if (kernel == BlockKernelMode::Small) {
return f8f8bf16_blockwise_impl<128, 32, 128, 128, 1, 2, true>(
-        XQ, WQ, x_scale, w_scale);
+        XQ, WQ, x_scale, w_scale, block_m, block_n, block_k);
} else if (kernel == BlockKernelMode::Large) {
return f8f8bf16_blockwise_impl<256, 256, 128, 64, 4, 2, true>(
-        XQ, WQ, x_scale, w_scale);
+        XQ, WQ, x_scale, w_scale, block_m, block_n, block_k);
} else {
return f8f8bf16_blockwise_impl<256, 128, 128, 64, 2, 2, true>(
-        XQ, WQ, x_scale, w_scale);
+        XQ, WQ, x_scale, w_scale, block_m, block_n, block_k);
}
} else {
if (kernel == BlockKernelMode::Small) {
return f8f8bf16_blockwise_impl<128, 32, 128, 128, 1, 2, true>(
-        XQ, WQ, x_scale, w_scale);
+        XQ, WQ, x_scale, w_scale, block_m, block_n, block_k);
} else if (kernel == BlockKernelMode::Large) {
return f8f8bf16_blockwise_impl<256, 256, 128, 64, 4, 2, false>(
-        XQ, WQ, x_scale, w_scale);
+        XQ, WQ, x_scale, w_scale, block_m, block_n, block_k);
} else {
return f8f8bf16_blockwise_impl<256, 128, 128, 64, 2, 2, false>(
-        XQ, WQ, x_scale, w_scale);
+        XQ, WQ, x_scale, w_scale, block_m, block_n, block_k);
}
}
}
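For the CK path above, the scale strides handed to the kernel are derived from the (for now fixed) 256-wide blocks. Below is a small sketch of how a scale element would be fetched with those strides, assuming the row-major [M/block_m, K/block_k] layout they imply; the helper name lookup_x_scale is hypothetical and not part of the commit.

#include <cstdint>

// Illustrative only: mirrors ScaleStrideAM = K / block_k and ScaleStrideAK = 1
// as passed to the CK kernel above (block_m = block_k = 256 for now).
inline float lookup_x_scale(
    const float* x_scale,
    int64_t m,
    int64_t k,
    int64_t K,
    int64_t block_m,
    int64_t block_k) {
  const int64_t stride_am = K / block_k; // ScaleStrideAM
  const int64_t stride_ak = 1;           // ScaleStrideAK
  return x_scale[(m / block_m) * stride_am + (k / block_k) * stride_ak];
}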
fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu (251 additions, 0 deletions)
@@ -53,6 +53,8 @@
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>

#include "fp8_blockwise_cutlass_helpers.h"

// Each block handles a single batch and head

// Each warp handles separate D dimension.
@@ -302,6 +304,14 @@ class LinearCombinationOnDevice {

} // namespace cutlass::epilogue::thread

namespace {

int64_t ceil_div(int64_t a, int64_t b) {
return (a + b - 1) / b;
}

} // namespace

namespace fbgemm_gpu {

template <int TB_M, int TB_N, int TB_K, int W_M, int W_N, int W_K>
@@ -1438,6 +1448,236 @@ at::Tensor f8f8bf16_rowwise(
}
}

// Cutlass blockwise kernel
template <
int TB_M,
int TB_N,
int TB_K,
int TBS_M,
int TBS_N,
int TBS_K>
at::Tensor f8f8bf16_blockwise_impl(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
at::Tensor x_scale,
at::Tensor w_scale,
int64_t block_m,
int64_t block_n,
int64_t block_k) {
TORCH_CHECK(XQ.dim() == 2);
TORCH_CHECK(WQ.dim() == 2);
int M = XQ.size(0);
int N = WQ.size(0);
int K = XQ.size(1);
TORCH_CHECK(WQ.size(1) == K);
TORCH_CHECK(XQ.stride(0) == K);
TORCH_CHECK(XQ.stride(1) == 1);
TORCH_CHECK(WQ.stride(0) == K);
TORCH_CHECK(WQ.stride(1) == 1);

TORCH_CHECK(block_m % TB_N == 0);
TORCH_CHECK(block_n % TB_M == 0);
TORCH_CHECK(block_k % TB_K == 0);

TORCH_CHECK(x_scale.dim() == 2);
TORCH_CHECK(w_scale.dim() == 2);
TORCH_CHECK(x_scale.size(0) == ceil_div(M, block_m));
TORCH_CHECK(x_scale.size(1) == ceil_div(K, block_k));
TORCH_CHECK(w_scale.size(0) == ceil_div(N, block_n));
TORCH_CHECK(w_scale.size(1) == ceil_div(K, block_k));
TORCH_CHECK(x_scale.stride(0) == ceil_div(K, block_k));
TORCH_CHECK(x_scale.stride(1) == 1);
TORCH_CHECK(w_scale.stride(0) == ceil_div(K, block_k));
TORCH_CHECK(w_scale.stride(1) == 1);

TORCH_CHECK(XQ.dtype() == at::kFloat8_e4m3fn);
TORCH_CHECK(WQ.dtype() == at::kFloat8_e4m3fn);
TORCH_CHECK(XQ.is_cuda());
TORCH_CHECK(WQ.is_cuda());
TORCH_CHECK(XQ.device().index() == WQ.device().index());
TORCH_CHECK(x_scale.dtype() == at::kFloat);
TORCH_CHECK(w_scale.dtype() == at::kFloat);
TORCH_CHECK(x_scale.is_cuda());
TORCH_CHECK(w_scale.is_cuda());
TORCH_CHECK(x_scale.device().index() == XQ.device().index());
TORCH_CHECK(w_scale.device().index() == XQ.device().index());

auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));

using ElementInputA = cutlass::float_e4m3_t;
using LayoutInputA = cutlass::layout::RowMajor;
constexpr int AlignmentInputA = 16 / sizeof(ElementInputA);

using ElementInputB = cutlass::float_e4m3_t;
using LayoutInputB = cutlass::layout::ColumnMajor;
constexpr int AlignmentInputB = 16 / sizeof(ElementInputB);

using ElementOutput = cutlass::bfloat16_t;
using LayoutOutput = cutlass::layout::ColumnMajor;
constexpr int AlignmentOutput = 16 / sizeof(ElementOutput);

using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that
// supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp;
using TileShape = cute::Shape<
cute::Int<TB_M>,
cute::Int<TB_N>,
cute::Int<TB_K>>; // Threadblock-level
// tile size
using ClusterShape = cute::Shape<
cute::Int<TBS_M>,
cute::Int<TBS_N>,
cute::Int<TBS_K>>; // Shape of the
// threadblocks in a
// cluster

using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
TileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementComputeEpilogue,
ElementOutput,
LayoutOutput,
AlignmentOutput,
ElementOutput,
LayoutOutput,
AlignmentOutput,
cutlass::epilogue::TmaWarpSpecializedCooperative>::CollectiveOp;

using MainLoopSchedule =
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaling;

using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementInputA,
LayoutInputA,
AlignmentInputA,
ElementInputB,
LayoutInputB,
AlignmentInputB,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainLoopSchedule>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int>,
CollectiveMainloop,
CollectiveEpilogue>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

using StrideInputA = typename Gemm::GemmKernel::StrideA;
using StrideInputB = typename Gemm::GemmKernel::StrideB;
using StrideOutput = typename Gemm::GemmKernel::StrideD;

StrideInputA stride_a = cutlass::make_cute_packed_stride(
StrideInputA{}, cute::make_shape(N, K, cute::Int<1>{}));
StrideInputB stride_b = cutlass::make_cute_packed_stride(
StrideInputB{}, cute::make_shape(M, K, cute::Int<1>{}));
StrideOutput stride_output = cutlass::make_cute_packed_stride(
StrideOutput{}, cute::make_shape(N, M, cute::Int<1>{}));

typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{N, M, K},
{reinterpret_cast<cutlass::float_e4m3_t*>(WQ.data_ptr()),
stride_a,
reinterpret_cast<cutlass::float_e4m3_t*>(XQ.data_ptr()),
stride_b,
w_scale.data_ptr<float>(),
x_scale.data_ptr<float>(),
static_cast<uint8_t>(block_n / TB_M),
static_cast<uint8_t>(block_m / TB_N),
static_cast<uint8_t>(block_k / TB_K)},
{{},
(cutlass::bfloat16_t*)Y.data_ptr<at::BFloat16>(),
stride_output,
(cutlass::bfloat16_t*)Y.data_ptr<at::BFloat16>(),
stride_output},
};

Gemm gemm;

// Using the arguments, query for extra workspace required for matrix
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);

// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

// Check the problem size is supported or not
cutlass::Status status = gemm.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot implement");
}

// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}

status = gemm(at::cuda::getCurrentCUDAStream());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error(
std::string("cutlass cannot run") +
cutlass::cutlassGetStatusString(status));
}
C10_CUDA_KERNEL_LAUNCH_CHECK();

return Y;
}
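One easy-to-miss detail of the launch above, on my reading of it: WQ is passed as operand A and XQ as operand B over an {N, M, K} problem, so the kernel writes a column-major (N, M) result D = WQ * XQ^T, which occupies memory exactly like the row-major (M, N) tensor Y = XQ * WQ^T that is returned. A tiny standalone check of that layout identity, illustrative only and not part of the commit:

#include <cassert>
#include <cstdint>

int main() {
  // Element (n, m) of a column-major N x M matrix with leading dimension N
  // sits at the same offset as element (m, n) of a row-major M x N matrix.
  const int64_t M = 3, N = 5;
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      const int64_t col_major = n + m * N; // (n, m) in the (N, M) output D
      const int64_t row_major = m * N + n; // (m, n) in Y
      assert(col_major == row_major);
    }
  }
  return 0;
}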

// FP8 blockwise Cutlass kernel dispatch.
at::Tensor dispatch_fp8_blockwise_kernel(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
int64_t block_m,
int64_t block_n,
int64_t block_k) {
KernelMode kernel = get_kernel_mode(XQ, WQ);
if (kernel == KernelMode::Small) {
return f8f8bf16_blockwise_impl<128, 128, 128, 2, 1, 1>(
XQ, WQ, x_scale, w_scale, block_m, block_n, block_k);
} else if (kernel == KernelMode::Large) {
return f8f8bf16_blockwise_impl<128, 128, 128, 2, 1, 1>(
XQ, WQ, x_scale, w_scale, block_m, block_n, block_k);
} else {
return f8f8bf16_blockwise_impl<128, 128, 128, 1, 2, 1>(
XQ, WQ, x_scale, w_scale, block_m, block_n, block_k);
}
}

at::Tensor f8f8bf16_blockwise(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
at::Tensor x_scale, // FP32
at::Tensor w_scale, // FP32
int64_t block_m,
int64_t block_n,
int64_t block_k) {
// Check datatypes.
TORCH_CHECK(
x_scale.dtype() == at::kFloat && w_scale.dtype() == at::kFloat,
"Scale tensors must be float32.");

return dispatch_fp8_blockwise_kernel(
XQ, WQ, x_scale, w_scale, block_m, block_n, block_k);
}
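Finally, a hedged host-side usage sketch of the new entry point, with shapes, dtypes, and block sizes chosen to satisfy the checks in f8f8bf16_blockwise_impl above. The wrapper blockwise_fp8_gemm_example, the all-ones scales, and the forward declaration are illustrative only and not part of this diff.

#include <ATen/ATen.h>
#include <cstdint>

namespace fbgemm_gpu {
// Declaration mirroring the entry point added in this diff.
at::Tensor f8f8bf16_blockwise(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    int64_t block_m,
    int64_t block_n,
    int64_t block_k);
} // namespace fbgemm_gpu

// Hypothetical caller; block sizes of 256 satisfy both the CK (== 256) and the
// CUTLASS (multiple of the 128-wide tiles) checks. Scales are all ones purely
// for illustration.
at::Tensor blockwise_fp8_gemm_example(int64_t M, int64_t N, int64_t K) {
  const int64_t block = 256;
  const auto ceil_div = [](int64_t a, int64_t b) { return (a + b - 1) / b; };
  const auto cuda = at::TensorOptions().device(at::kCUDA);

  // Quantized operands: row-major, K-contiguous, fp8 e4m3 (e4m3fnuz on ROCm).
  at::Tensor XQ =
      at::randn({M, K}, cuda.dtype(at::kFloat)).to(at::kFloat8_e4m3fn);
  at::Tensor WQ =
      at::randn({N, K}, cuda.dtype(at::kFloat)).to(at::kFloat8_e4m3fn);

  // One fp32 scale per (block x block) tile, laid out row-major over K blocks.
  at::Tensor x_scale =
      at::ones({ceil_div(M, block), ceil_div(K, block)}, cuda.dtype(at::kFloat));
  at::Tensor w_scale =
      at::ones({ceil_div(N, block), ceil_div(K, block)}, cuda.dtype(at::kFloat));

  // Returns a row-major (M, N) bf16 tensor.
  return fbgemm_gpu::f8f8bf16_blockwise(
      XQ, WQ, x_scale, w_scale, block, block, block);
}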

template <
int TB_M,
int TB_N,
@@ -2198,6 +2438,17 @@ at::Tensor f8f8bf16_rowwise(
throw std::runtime_error(
"CUDA version is older than 12.0"); // requires CUDA>=12
}
at::Tensor f8f8bf16_blockwise(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
at::Tensor x_scale,
at::Tensor w_scale,
int64_t block_m,
int64_t block_n,
int64_t block_k) {
throw std::runtime_error(
"CUDA version is older than 12.0"); // requires CUDA>=12
}
#endif

at::Tensor i8i8bf16(
