Skip to content

Commit

Permalink
Back out "Support FP8 scale calculation with scalar and cleanup" (#2604)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2604

There is garbage output issue because the new introduced Sm90ScalarBroadcast requires scalar scale that causes an issue under cudagraph.

Reverting the changes in D57367680 resolves the issue.

Will follow up on adding more unittests to cover similar issues and will test E2E before make those changes.

Reviewed By: jianyuh

Differential Revision: D57521470

fbshipit-source-id: 685054db8248ff6ef6e029451c6563ba8f37ba29
  • Loading branch information
jiawenliu64 authored and facebook-github-bot committed May 18, 2024
1 parent 19a91b9 commit 79aba2c
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 143 deletions.
2 changes: 1 addition & 1 deletion fbgemm_gpu/experimental/gen_ai/bench/ck_fp8_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class CKMatmul(torch.nn.Module):
def forward(
self, a: torch.Tensor, b: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
out = torch.ops.fbgemm.f8f8bf16(a, b, scale)
out = torch.ops.fbgemm.f8f8bf16_tensorwise(a, b, scale)
return out


Expand Down
16 changes: 8 additions & 8 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions.hip
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ template <
int MPER_WAVE,
int NPER_WAVE,
bool PADDING = false>
at::Tensor f8f8bf16_impl(at::Tensor XQ, at::Tensor WQ, double scale) {
at::Tensor f8f8bf16_tensorwise_impl(at::Tensor XQ, at::Tensor WQ, double scale) {
// Get input information.
int M = XQ.size(0);
int N = WQ.size(0);
Expand Down Expand Up @@ -212,24 +212,24 @@ std::tuple<KernelMode, bool> get_kernel_mode(at::Tensor XQ, at::Tensor WQ) {
}

at::Tensor
f8f8bf16(at::Tensor XQ, at::Tensor WQ, double scale, bool use_fast_accum) {
f8f8bf16_tensorwise(at::Tensor XQ, at::Tensor WQ, double scale, bool use_fast_accum) {
TORCH_CHECK(use_fast_accum, "AMD does not support disabling use_fast_accum");
auto [kernel, pad] = get_kernel_mode(XQ, WQ);
if (pad) {
if (kernel == KernelMode::Small) {
return f8f8bf16_impl<64, 32, 64, 64, 1, 2, true>(XQ, WQ, scale);
return f8f8bf16_tensorwise_impl<64, 32, 64, 64, 1, 2, true>(XQ, WQ, scale);
} else if (kernel == KernelMode::Large) {
return f8f8bf16_impl<256, 256, 128, 64, 4, 2, true>(XQ, WQ, scale);
return f8f8bf16_tensorwise_impl<256, 256, 128, 64, 4, 2, true>(XQ, WQ, scale);
} else {
return f8f8bf16_impl<256, 128, 128, 64, 2, 2, true>(XQ, WQ, scale);
return f8f8bf16_tensorwise_impl<256, 128, 128, 64, 2, 2, true>(XQ, WQ, scale);
}
} else {
if (kernel == KernelMode::Small) {
return f8f8bf16_impl<64, 32, 64, 64, 1, 2>(XQ, WQ, scale);
return f8f8bf16_tensorwise_impl<64, 32, 64, 64, 1, 2>(XQ, WQ, scale);
} else if (kernel == KernelMode::Large) {
return f8f8bf16_impl<256, 256, 128, 64, 4, 2>(XQ, WQ, scale);
return f8f8bf16_tensorwise_impl<256, 256, 128, 64, 4, 2>(XQ, WQ, scale);
} else {
return f8f8bf16_impl<256, 128, 128, 64, 2, 2>(XQ, WQ, scale);
return f8f8bf16_tensorwise_impl<256, 128, 128, 64, 2, 2>(XQ, WQ, scale);
}
}
}
Expand Down
204 changes: 199 additions & 5 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu
Original file line number Diff line number Diff line change
Expand Up @@ -601,9 +601,200 @@ template <
int TBS_M,
int TBS_N,
int TBS_K,
bool PONG,
bool FAST_ACCUM>
at::Tensor f8f8bf16_impl(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
at::Tensor scale) {
int M = XQ.size(0);
int N = WQ.size(0);
int K = XQ.size(1);

TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous());
TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());

auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));

using ElementInputA = cutlass::float_e4m3_t;
using LayoutInputA = cutlass::layout::RowMajor;
constexpr int AlignmentInputA =
128 /
cutlass::sizeof_bits<
ElementInputA>::value; // Memory access granularity/alignment of A
// matrix in units of elements (up to 16 bytes)

using ElementInputB = cutlass::float_e4m3_t;
using LayoutInputB = cutlass::layout::ColumnMajor;
constexpr int AlignmentInputB =
128 /
cutlass::sizeof_bits<
ElementInputB>::value; // Memory access granularity/alignment of B
// matrix in units of elements (up to 16 bytes)

using ElementOutput = cutlass::bfloat16_t;
using LayoutOutput = cutlass::layout::ColumnMajor;
constexpr int AlignmentOutput =
128 /
cutlass::sizeof_bits<
ElementOutput>::value; // Memory access granularity/alignment of C
// matrix in units of elements (up to 16 bytes)

using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that
// supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp;
using TileShape = cute::Shape<
cute::Int<TB_M>,
cute::Int<TB_N>,
cute::Int<TB_K>>; // Threadblock-level
// tile size
using ClusterShape = cute::Shape<
cute::Int<TBS_M>,
cute::Int<TBS_N>,
cute::Int<TBS_K>>; // Shape of the
// threadblocks in a
// cluster
using StageCountType =
cutlass::gemm::collective::StageCountAuto; // Stage count maximized
// based on the tile size
using KernelSchedule = cutlass::gemm::collective::
KernelScheduleAuto; // Kernel to launch based on the default setting in
// the Collective Builder

using MainLoopSchedule = cute::conditional_t<
FAST_ACCUM,
cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum,
cutlass::gemm::KernelTmaWarpSpecialized>;

using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementInputA,
LayoutInputA,
AlignmentInputA,
ElementInputB,
LayoutInputB,
AlignmentInputB,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAuto,
MainLoopSchedule>::CollectiveOp;

using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90,
cutlass::arch::OpClassTensorOp,
TileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementComputeEpilogue,
ElementOutput,
LayoutOutput,
AlignmentOutput,
ElementOutput,
LayoutOutput,
AlignmentOutput,
cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int>,
CollectiveMainloop,
CollectiveEpilogue>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

using StrideInputA = typename Gemm::GemmKernel::StrideA;
using StrideInputB = typename Gemm::GemmKernel::StrideB;
using StrideOutput = typename Gemm::GemmKernel::StrideC;

StrideInputA stride_a = cutlass::make_cute_packed_stride(
StrideInputA{}, cute::make_shape(M, K, cute::Int<1>{}));
StrideInputB stride_b = cutlass::make_cute_packed_stride(
StrideInputB{}, cute::make_shape(N, K, cute::Int<1>{}));
StrideOutput stride_output = cutlass::make_cute_packed_stride(
StrideOutput{}, cute::make_shape(N, M, cute::Int<1>{}));

typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{N, M, K},
{reinterpret_cast<ElementInputB*>(WQ.data_ptr()),
stride_b,
reinterpret_cast<ElementInputA*>(XQ.data_ptr()),
stride_a},
{{scale.data_ptr<float>(), 0},
(ElementOutput*)Y.data_ptr<at::BFloat16>(),
stride_output,
(ElementOutput*)Y.data_ptr<at::BFloat16>(),
stride_output}};
Gemm gemm;

// Using the arguments, query for extra workspace required for matrix
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);

// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

// Check the problem size is supported or not
cutlass::Status status = gemm.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot implement");
}

// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}

status = gemm(at::cuda::getCurrentCUDAStream());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error(
std::string("cutlass cannot run") +
cutlass::cutlassGetStatusString(status));
}
C10_CUDA_KERNEL_LAUNCH_CHECK();

return Y;
}

at::Tensor f8f8bf16(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
at::Tensor scale,
bool use_fast_accum) {
auto M = XQ.size(0);
// auto K = XQ.size(1);
// auto N = WQ.size(0);
if (use_fast_accum) {
if (M <= 128) {
return f8f8bf16_impl<64, 128, 128, 2, 1, 1, true>(XQ, WQ, scale);
} else {
return f8f8bf16_impl<128, 128, 128, 1, 2, 1, true>(XQ, WQ, scale);
}
} else {
if (M <= 128) {
return f8f8bf16_impl<64, 128, 128, 2, 1, 1, false>(XQ, WQ, scale);
} else {
return f8f8bf16_impl<128, 128, 128, 1, 2, 1, false>(XQ, WQ, scale);
}
}
}

template <
int TB_M,
int TB_N,
int TB_K,
int TBS_M,
int TBS_N,
int TBS_K,
bool PONG,
bool FAST_ACCUM>
at::Tensor f8f8bf16_tensorwise_impl(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
double scale) {
Expand Down Expand Up @@ -787,18 +978,21 @@ at::Tensor f8f8bf16_impl(
return Y;
}

at::Tensor f8f8bf16(
at::Tensor f8f8bf16_tensorwise(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
double scale,
bool use_fast_accum) {
KernelMode kernel = get_kernel_mode(XQ, WQ);
if (kernel == KernelMode::Small) {
return f8f8bf16_impl<64, 128, 128, 2, 1, 1, true, true>(XQ, WQ, scale);
return f8f8bf16_tensorwise_impl<64, 128, 128, 2, 1, 1, true, true>(
XQ, WQ, scale);
} else if (kernel == KernelMode::Large) {
return f8f8bf16_impl<128, 128, 128, 2, 1, 1, true, true>(XQ, WQ, scale);
return f8f8bf16_tensorwise_impl<128, 128, 128, 2, 1, 1, true, true>(
XQ, WQ, scale);
} else {
return f8f8bf16_impl<128, 128, 128, 1, 2, 1, false, true>(XQ, WQ, scale);
return f8f8bf16_tensorwise_impl<128, 128, 128, 1, 2, 1, false, true>(
XQ, WQ, scale);
}
}

Expand Down
Loading

0 comments on commit 79aba2c

Please sign in to comment.