Back out "Support FP8 scale calculation with scalar and cleanup"
Summary:
There is a garbage-output issue because the newly introduced Sm90ScalarBroadcast requires a scalar scale, which causes an issue under cudagraph.

Reverting the changes in D57367680 resolves the issue. 

Will follow up with more unit tests to cover similar issues, and will test E2E before making those changes again.

Differential Revision: D57521470
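
For context, here is a minimal usage sketch of the two FP8 GEMM entry points as they stand after this back-out, based on the operator schemas registered in the diff below. The tensor shapes, the use of quantize_fp8_per_tensor, and the way the two scales are combined are illustrative assumptions, not the exact production call sites.

```python
import torch

# Assumes the fbgemm_gpu experimental gen_ai ops are already loaded,
# e.g. via `import fbgemm_gpu.experimental.gen_ai`.
x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
w = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16)

# quantize_fp8_per_tensor now returns a Tensor[]: [quantized tensor, scale tensor].
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)

# f8f8bf16 (restored by this back-out) takes the scale as a Tensor, so the
# scale stays on device and can be captured by a CUDA graph.
y = torch.ops.fbgemm.f8f8bf16(xq, wq, x_scale * w_scale)

# f8f8bf16_tensorwise takes a host-side float scale instead; converting a
# device-resident scale to a Python float forces a device-to-host sync.
y_tw = torch.ops.fbgemm.f8f8bf16_tensorwise(xq, wq, float(x_scale * w_scale))
```

The signatures show the point of the back-out: the scale for f8f8bf16 is again a device tensor rather than a host scalar, avoiding the cudagraph issue described above.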
jiawenliu64 authored and facebook-github-bot committed May 18, 2024
1 parent d92fb75 commit 6fcdd4b
Showing 5 changed files with 241 additions and 135 deletions.
2 changes: 1 addition & 1 deletion fbgemm_gpu/experimental/gen_ai/bench/ck_fp8_bench.py
@@ -69,7 +69,7 @@ class CKMatmul(torch.nn.Module):
def forward(
self, a: torch.Tensor, b: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
out = torch.ops.fbgemm.f8f8bf16(a, b, scale)
out = torch.ops.fbgemm.f8f8bf16_tensorwise(a, b, scale)
return out


204 changes: 199 additions & 5 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu
@@ -601,9 +601,200 @@ template <
int TBS_M,
int TBS_N,
int TBS_K,
bool PONG,
bool FAST_ACCUM>
at::Tensor f8f8bf16_impl(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
at::Tensor scale) {
int M = XQ.size(0);
int N = WQ.size(0);
int K = XQ.size(1);

TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous());
TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());

auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));

using ElementInputA = cutlass::float_e4m3_t;
using LayoutInputA = cutlass::layout::RowMajor;
constexpr int AlignmentInputA =
128 /
cutlass::sizeof_bits<
ElementInputA>::value; // Memory access granularity/alignment of A
// matrix in units of elements (up to 16 bytes)

using ElementInputB = cutlass::float_e4m3_t;
using LayoutInputB = cutlass::layout::ColumnMajor;
constexpr int AlignmentInputB =
128 /
cutlass::sizeof_bits<
ElementInputB>::value; // Memory access granularity/alignment of B
// matrix in units of elements (up to 16 bytes)

using ElementOutput = cutlass::bfloat16_t;
using LayoutOutput = cutlass::layout::ColumnMajor;
constexpr int AlignmentOutput =
128 /
cutlass::sizeof_bits<
ElementOutput>::value; // Memory access granularity/alignment of C
// matrix in units of elements (up to 16 bytes)

using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that
// supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp;
using TileShape = cute::Shape<
cute::Int<TB_M>,
cute::Int<TB_N>,
cute::Int<TB_K>>; // Threadblock-level
// tile size
using ClusterShape = cute::Shape<
cute::Int<TBS_M>,
cute::Int<TBS_N>,
cute::Int<TBS_K>>; // Shape of the
// threadblocks in a
// cluster
using StageCountType =
cutlass::gemm::collective::StageCountAuto; // Stage count maximized
// based on the tile size
using KernelSchedule = cutlass::gemm::collective::
KernelScheduleAuto; // Kernel to launch based on the default setting in
// the Collective Builder

using MainLoopSchedule = cute::conditional_t<
FAST_ACCUM,
cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum,
cutlass::gemm::KernelTmaWarpSpecialized>;

using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementInputA,
LayoutInputA,
AlignmentInputA,
ElementInputB,
LayoutInputB,
AlignmentInputB,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAuto,
MainLoopSchedule>::CollectiveOp;

using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90,
cutlass::arch::OpClassTensorOp,
TileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementComputeEpilogue,
ElementOutput,
LayoutOutput,
AlignmentOutput,
ElementOutput,
LayoutOutput,
AlignmentOutput,
cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int>,
CollectiveMainloop,
CollectiveEpilogue>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

using StrideInputA = typename Gemm::GemmKernel::StrideA;
using StrideInputB = typename Gemm::GemmKernel::StrideB;
using StrideOutput = typename Gemm::GemmKernel::StrideC;

StrideInputA stride_a = cutlass::make_cute_packed_stride(
StrideInputA{}, cute::make_shape(M, K, cute::Int<1>{}));
StrideInputB stride_b = cutlass::make_cute_packed_stride(
StrideInputB{}, cute::make_shape(N, K, cute::Int<1>{}));
StrideOutput stride_output = cutlass::make_cute_packed_stride(
StrideOutput{}, cute::make_shape(N, M, cute::Int<1>{}));

typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{N, M, K},
{reinterpret_cast<ElementInputB*>(WQ.data_ptr()),
stride_b,
reinterpret_cast<ElementInputA*>(XQ.data_ptr()),
stride_a},
{{scale.data_ptr<float>(), 0},
(ElementOutput*)Y.data_ptr<at::BFloat16>(),
stride_output,
(ElementOutput*)Y.data_ptr<at::BFloat16>(),
stride_output}};
Gemm gemm;

// Using the arguments, query for extra workspace required for matrix
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);

// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

// Check the problem size is supported or not
cutlass::Status status = gemm.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot implement");
}

// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}

status = gemm(at::cuda::getCurrentCUDAStream());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error(
std::string("cutlass cannot run") +
cutlass::cutlassGetStatusString(status));
}
C10_CUDA_KERNEL_LAUNCH_CHECK();

return Y;
}

at::Tensor f8f8bf16(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
at::Tensor scale,
bool use_fast_accum) {
auto M = XQ.size(0);
// auto K = XQ.size(1);
// auto N = WQ.size(0);
if (use_fast_accum) {
if (M <= 128) {
return f8f8bf16_impl<64, 128, 128, 2, 1, 1, true>(XQ, WQ, scale);
} else {
return f8f8bf16_impl<128, 128, 128, 1, 2, 1, true>(XQ, WQ, scale);
}
} else {
if (M <= 128) {
return f8f8bf16_impl<64, 128, 128, 2, 1, 1, false>(XQ, WQ, scale);
} else {
return f8f8bf16_impl<128, 128, 128, 1, 2, 1, false>(XQ, WQ, scale);
}
}
}

template <
int TB_M,
int TB_N,
int TB_K,
int TBS_M,
int TBS_N,
int TBS_K,
bool PONG,
bool FAST_ACCUM>
at::Tensor f8f8bf16_tensorwise_impl(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
double scale) {
@@ -787,18 +978,21 @@ at::Tensor f8f8bf16_impl(
return Y;
}

at::Tensor f8f8bf16(
at::Tensor f8f8bf16_tensorwise(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
double scale,
bool use_fast_accum) {
KernelMode kernel = get_kernel_mode(XQ, WQ);
if (kernel == KernelMode::Small) {
return f8f8bf16_impl<64, 128, 128, 2, 1, 1, true, true>(XQ, WQ, scale);
return f8f8bf16_tensorwise_impl<64, 128, 128, 2, 1, 1, true, true>(
XQ, WQ, scale);
} else if (kernel == KernelMode::Large) {
return f8f8bf16_impl<128, 128, 128, 2, 1, 1, true, true>(XQ, WQ, scale);
return f8f8bf16_tensorwise_impl<128, 128, 128, 2, 1, 1, true, true>(
XQ, WQ, scale);
} else {
return f8f8bf16_impl<128, 128, 128, 1, 2, 1, false, true>(XQ, WQ, scale);
return f8f8bf16_tensorwise_impl<128, 128, 128, 1, 2, 1, false, true>(
XQ, WQ, scale);
}
}

60 changes: 28 additions & 32 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -34,6 +34,11 @@ at::Tensor silu_mul_quantize_i8(at::Tensor X1, at::Tensor X2, double scale);

// Cutlass kernel
at::Tensor f8f8bf16(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor scale,
bool use_fast_accum = true);
at::Tensor f8f8bf16_tensorwise(
at::Tensor XQ,
at::Tensor WQ,
double scale,
@@ -62,12 +67,7 @@ at::Tensor f8i4bf16_rowwise(
at::Tensor per_tensor_quantize_i8(at::Tensor X, double scale);
std::tuple<at::Tensor, at::Tensor> per_tensor_dynamic_quantize_i8(at::Tensor X);

std::tuple<at::Tensor, double> quantize_fp8_per_tensor(
at::Tensor input,
c10::optional<at::Tensor> bs, // batch size
c10::optional<at::Tensor> scale_ub); // scale upperbound

std::tuple<at::Tensor, at::Tensor> quantize_fp8_per_tensor_tensor_scale(
std::vector<at::Tensor> quantize_fp8_per_tensor(
at::Tensor input,
c10::optional<at::Tensor> bs, // batch size
c10::optional<at::Tensor> scale_ub); // scale upperbound
@@ -102,6 +102,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
// torch.ops.load_library, similar to below for quantize_fp8_per_tensor
m.def("i8i8bf16(Tensor XQ, Tensor WQ, float scale, int split_k=1) -> Tensor");

m.def(
"f8f8bf16(Tensor XQ, Tensor WQ, Tensor scale, bool use_fast_accum=True) -> Tensor");

m.def(
"f8f8bf16_rowwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor? bias=None, bool use_fast_accum=True) -> Tensor");

Expand All @@ -118,7 +121,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
#endif

m.def(
"f8f8bf16(Tensor XQ, Tensor WQ, float scale, bool use_fast_accum=True) -> Tensor");
"f8f8bf16_tensorwise(Tensor XQ, Tensor WQ, float scale, bool use_fast_accum=True) -> Tensor");
m.def("per_tensor_quantize_i8(Tensor X, float scale) -> Tensor");
m.impl("per_tensor_quantize_i8", per_tensor_quantize_i8);
m.def("per_tensor_dynamic_quantize_i8(Tensor X) -> (Tensor, Tensor)");
Expand All @@ -135,13 +138,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
// quantize_ops with
// torch.ops.load_library
m.def(
"quantize_fp8_per_tensor(Tensor input, Tensor? bs=None, Tensor? scale_ub=None) -> (Tensor, float)");
"quantize_fp8_per_tensor(Tensor input, Tensor? bs=None, Tensor? scale_ub=None) -> Tensor[]");
m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor);
m.def(
"quantize_fp8_per_tensor_tensor_scale(Tensor input, Tensor? bs=None, Tensor? scale_ub=None) -> (Tensor, Tensor)");
m.impl(
"quantize_fp8_per_tensor_tensor_scale",
quantize_fp8_per_tensor_tensor_scale);
m.def(
"quantize_fp8_per_row(Tensor input, Tensor? bs=None, Tensor? scale_ub=None, ScalarType? output_dtype=None) -> Tensor[]");
m.impl("quantize_fp8_per_row", quantize_fp8_per_row);
Expand All @@ -165,14 +163,12 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
}

TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
m.impl("f8f8bf16", f8f8bf16);
m.impl("f8f8bf16_tensorwise", f8f8bf16_tensorwise);
#ifndef USE_ROCM
m.impl("i8i8bf16", i8i8bf16);
m.impl("f8f8bf16_rowwise", f8f8bf16_rowwise);
m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor);
m.impl(
"quantize_fp8_per_tensor_tensor_scale",
quantize_fp8_per_tensor_tensor_scale);
m.impl("f8f8bf16", f8f8bf16);
m.impl("f8f8bf16_cublas", f8f8bf16_cublas);
#endif
}
@@ -201,22 +197,13 @@ at::Tensor f8f8bf16_rowwise_meta(
return Y;
}

std::tuple<at::Tensor, double> quantize_fp8_per_tensor_meta(
at::Tensor X,
c10::optional<at::Tensor> bs,
c10::optional<at::Tensor> scale_ub) {
auto Y = at::empty_like(X, X.options().dtype(at::kFloat8_e4m3fn));
auto scale = 0.0;
return std::tuple<at::Tensor, double>{Y, scale};
}

std::tuple<at::Tensor, at::Tensor> quantize_fp8_per_tensor_tensor_scale_meta(
std::vector<at::Tensor> quantize_fp8_per_tensor_meta(
at::Tensor X,
c10::optional<at::Tensor> bs,
c10::optional<at::Tensor> scale_ub) {
auto Y = at::empty_like(X, X.options().dtype(at::kFloat8_e4m3fn));
auto scale = at::empty({}, X.options().dtype(at::kBFloat16));
return std::tuple<at::Tensor, at::Tensor>{Y, scale};
return {Y, scale};
}

at::Tensor f8f8bf16_cublas_meta(
Expand All @@ -233,6 +220,17 @@ at::Tensor f8f8bf16_cublas_meta(
}

at::Tensor f8f8bf16_meta(
at::Tensor X,
at::Tensor W,
at::Tensor scale,
bool use_fast_accum = true) {
const at::SymInt M = X.sym_size(0);
const at::SymInt N = W.sym_size(0);
auto Y = at::empty_symint({M, N}, X.options().dtype(at::kBFloat16));
return Y;
}

at::Tensor f8f8bf16_tensorwise_meta(
at::Tensor X,
at::Tensor W,
double scale,
Expand All @@ -256,14 +254,12 @@ at::Tensor f8i4bf16_rowwise_meta(
}

TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
m.impl("f8f8bf16", f8f8bf16_meta);
m.impl("f8f8bf16_tensorwise", f8f8bf16_tensorwise_meta);
#ifndef USE_ROCM
m.impl("i8i8bf16", i8i8bf16_meta);
m.impl("f8f8bf16_rowwise", f8f8bf16_rowwise_meta);
m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor_meta);
m.impl(
"quantize_fp8_per_tensor_tensor_scale",
quantize_fp8_per_tensor_tensor_scale_meta);
m.impl("f8f8bf16", f8f8bf16_meta);
m.impl("f8f8bf16_cublas", f8f8bf16_cublas_meta);
m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise_meta);
#endif
(The diffs for the remaining two of the five changed files are not shown here.)