2024-09-19 nightly release (46e309d)
pytorchbot committed Sep 19, 2024
1 parent 79ef064 commit 51f62a1
Showing 20 changed files with 569 additions and 172 deletions.
3 changes: 2 additions & 1 deletion fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py
@@ -14,7 +14,8 @@
import torch
from torch import Tensor

logging.basicConfig(level=logging.DEBUG)
logger: logging.Logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

try:
# pyre-ignore[21]
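The same logging change recurs across the benchmark scripts in this commit: the module-level logging.basicConfig(level=logging.DEBUG) call is replaced by fetching the root logger and setting its level directly, presumably so the scripts stop installing a stream handler as a side effect. A minimal sketch of the pattern (the logger.debug call at the end is hypothetical usage, not part of the commit):

```python
import logging

# Before: logging.basicConfig(level=logging.DEBUG)
# basicConfig() also attaches a StreamHandler to the root logger if none exists.

# After: fetch the root logger and adjust only its level; handler setup is
# left to whoever runs the script.
logger: logging.Logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

logger.debug("benchmark starting")  # hypothetical usage
```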
3 changes: 2 additions & 1 deletion fbgemm_gpu/bench/jagged_tensor_benchmark.py
@@ -16,7 +16,8 @@
import torch
from torch.profiler import profile

logging.basicConfig(level=logging.DEBUG)
logger: logging.Logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
open_source: bool = getattr(fbgemm_gpu, "open_source", False)
3 changes: 3 additions & 0 deletions fbgemm_gpu/bench/merge_embeddings_benchmark.py
@@ -32,6 +32,9 @@
# pyre-fixme[21]: Could not find name `ProfilerActivity` in `torch.profiler`.
from torch.profiler import profile, ProfilerActivity

logger: logging.Logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
open_source: bool = getattr(fbgemm_gpu, "open_source", False)

4 changes: 2 additions & 2 deletions fbgemm_gpu/bench/quantize_ops_benchmark.py
@@ -22,8 +22,8 @@
# pyre-ignore[21]
from torch.profiler import profile, ProfilerActivity


logging.basicConfig(level=logging.DEBUG)
logger: logging.Logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
open_source: bool = getattr(fbgemm_gpu, "open_source", False)
3 changes: 2 additions & 1 deletion fbgemm_gpu/bench/sparse_ops_benchmark.py
@@ -20,7 +20,8 @@

from torch.profiler import profile

logging.basicConfig(level=logging.DEBUG)
logger: logging.Logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
open_source: bool = getattr(fbgemm_gpu, "open_source", False)
3 changes: 2 additions & 1 deletion fbgemm_gpu/bench/split_embeddings_cache_benchmark.py
@@ -26,7 +26,8 @@

from torch import nn, Tensor

logging.basicConfig(level=logging.DEBUG)
logger: logging.Logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

try:
# pyre-ignore[21]
3 changes: 3 additions & 0 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
@@ -48,6 +48,9 @@
from torch import Tensor
from torch.profiler import profile

logger: logging.Logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

haveAIBench = False
try:
from aibench_observer.utils.observer import emitMetric
5 changes: 2 additions & 3 deletions fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py
@@ -40,14 +40,13 @@
from torch.autograd.profiler import record_function
from torch.profiler import profile

logging.basicConfig(level=logging.DEBUG)
logger: logging.Logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

load_torch_module(
"//deeplearning/fbgemm/fbgemm_gpu:ssd_split_table_batched_embeddings",
)

logging.basicConfig(level=logging.DEBUG)


@click.group()
def cli() -> None:
3 changes: 2 additions & 1 deletion fbgemm_gpu/bench/stride_gemm_benchmark.py
@@ -13,7 +13,8 @@
import torch
from fbgemm_gpu.bench.bench_utils import benchmark_torch_function

logging.basicConfig(level=logging.DEBUG)
logger: logging.Logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

try:
# pyre-ignore[21]
2 changes: 1 addition & 1 deletion fbgemm_gpu/docs/requirements.txt
@@ -14,7 +14,7 @@ sphinx<7

breathe
bs4
docutils
docutils<0.20,>=0.18.1
lxml
myst-parser
sphinx-lint
7 changes: 5 additions & 2 deletions fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp
@@ -133,7 +133,9 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
at::Tensor cache_V,
at::Tensor kv_seqlen,
std::optional<at::Tensor> qparam_k,
std::optional<at::Tensor> qparam_v);
std::optional<at::Tensor> qparam_v,
std::optional<at::Tensor> block_tables,
int64_t page_size);

at::Tensor mqa_attn(
at::Tensor XQ,
@@ -162,7 +164,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"dequantize_int4_cache(Tensor cache_K, Tensor cache_V, Tensor kv_seqlen, int? num_groups=1) -> (Tensor, Tensor)");
m.impl("dequantize_int4_cache", dequantize_int4_cache);
m.def(
"dequantize_fp8_cache(Tensor cache_K, Tensor cache_V, Tensor kv_seqlen, Tensor? qparam_k=None, Tensor? qparam_v=None) -> (Tensor, Tensor)");
"dequantize_fp8_cache(Tensor cache_K, Tensor cache_V, Tensor kv_seqlen, Tensor? qparam_k=None, Tensor? qparam_v=None, Tensor? block_tables=None, int page_size=" STRING(
DEFAULT_PAGE_SIZE) ") -> (Tensor, Tensor)");
m.impl("dequantize_fp8_cache", dequantize_fp8_cache);
}

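For context on the schema change above: the operator can now be called from Python with the two extra paged-attention arguments. A hedged usage sketch follows; the tensor shapes and dtypes are inferred from the kernel comments in kv_cache.cu, the concrete sizes (including page_size=64) are made-up illustration values rather than defaults from this commit, and the snippet assumes the fbgemm_gpu extension that registers the op has already been loaded.

```python
import torch

# Assumed layouts, following the comments in kv_cache.cu:
#   cache_K / cache_V : uint8  [B_KV][MAX_T][N_KVH][D_H + 4]  (4 qparam bytes per row
#                                                              when no qparam tensors)
#   kv_seqlen         : int32  [B]
#   block_tables      : int32  [B][pages_per_sequence]         (paged layout only)
B, N_KVH, D_H = 2, 8, 128        # hypothetical sizes
page_size = 64                   # hypothetical; the schema default is DEFAULT_PAGE_SIZE

# Vanilla layout: B_KV == B, one K/V row per (batch, position).
cache_K = torch.zeros(B, 256, N_KVH, D_H + 4, dtype=torch.uint8, device="cuda")
cache_V = torch.zeros_like(cache_K)
kv_seqlen = torch.tensor([100, 37], dtype=torch.int32, device="cuda")
k_dq, v_dq = torch.ops.fbgemm.dequantize_fp8_cache(cache_K, cache_V, kv_seqlen)

# Paged layout: B_KV == 1 and positions are addressed through block_tables.
max_pages = 8
cache_K_paged = torch.zeros(1, max_pages * page_size, N_KVH, D_H + 4,
                            dtype=torch.uint8, device="cuda")
cache_V_paged = torch.zeros_like(cache_K_paged)
block_tables = torch.arange(8, dtype=torch.int32, device="cuda").reshape(B, 4)
k_dq, v_dq = torch.ops.fbgemm.dequantize_fp8_cache(
    cache_K_paged, cache_V_paged, kv_seqlen,
    block_tables=block_tables, page_size=page_size)
```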
204 changes: 180 additions & 24 deletions fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu
@@ -795,11 +795,27 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_(
} else {
__half2* qparam_row = nullptr;
auto T = cache_K.size(1);
auto idx = b * (T * N_KVH) + (size_t)cache_loc_t * N_KVH + h;
if (qkv == QKV::K) {
qparam_row = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
if (block_tables == nullptr) {
auto idx = b * (T * N_KVH) + (size_t)cache_loc_t * N_KVH + h;
if (qkv == QKV::K) {
qparam_row = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
} else {
qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
}
} else {
qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
// This is duplicate computation with get_dst_row above.
// TODO: Maybe clean up and merge later.
int page_logical_idx = cache_loc_t / page_size;
int page_offset = cache_loc_t % page_size;
int page_physical_idx =
block_tables[b * block_tables_b_stride + page_logical_idx];
int physical_t = page_physical_idx * page_size + page_offset;
auto idx = physical_t * N_KVH + h;
if (qkv == QKV::K) {
qparam_row = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
} else {
qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
}
}
quantize_fp8_kv(dst, dst_row_q, qparam_row);
}
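The paged branch added above translates a logical token position within a sequence into a physical slot in the global page pool before computing the qparam row index. A small Python rendering of the same index arithmetic (variable names mirror the CUDA code; the concrete numbers are only illustrative):

```python
def to_physical(cache_loc_t: int, b: int, block_tables, page_size: int) -> int:
    # Same arithmetic as the kernel:
    #   page_logical_idx = cache_loc_t / page_size
    #   page_offset      = cache_loc_t % page_size
    #   physical_t       = block_tables[b][page_logical_idx] * page_size + page_offset
    page_logical_idx = cache_loc_t // page_size
    page_offset = cache_loc_t % page_size
    page_physical_idx = block_tables[b][page_logical_idx]
    return page_physical_idx * page_size + page_offset

# Illustrative values: batch element 0 owns physical pages [3, 7] of 64 positions each.
block_tables = [[3, 7]]
physical_t = to_physical(70, 0, block_tables, page_size=64)   # page 1, offset 6
assert physical_t == 7 * 64 + 6                               # == 454
# The qparam row is then indexed as physical_t * N_KVH + h, as in the kernel.
```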
@@ -1477,16 +1493,113 @@ __global__ void dequantize_fp8_cache_kernel(
*reinterpret_cast<uint2*>(&kv_dq.vals[2]);
}
}

// Cloned from dequantize_fp8_cache_kernel because
// branching inside the original kernel runs into
// "too many resources requested for launch" which
// necessitates decreasing the number of warps per block,
// which might have performance implications. Also we
// might have more diverging behaviors for paged kernel
// as noted in the comment below so we will keep a separate
// kernel for now.
__global__ void dequantize_fp8_cache_kernel_paged(
// This code currently represents FP8 version not int4
at::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
cache_K, // [1][MAX_PAGE * PAGE_SIZE][N_KVH][D_H]
at::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
cache_V, // [1][MAX_PAGE * PAGE_SIZE][N_KVH][D_H // G]
at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> kv_seqlen,
at::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
cache_K_dq, // [1][MAX_T][N_KVH][D_H]
at::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
cache_V_dq, // [1][MAX_T][N_KVH][D_H]
int32_t* qparam_k_ptr,
int32_t* qparam_v_ptr,
int32_t* block_tables,
int32_t block_tables_b_stride,
int32_t page_size) {
auto N_KVH = cache_K.size(2);
auto MAX_T = cache_K.size(1);
auto D_H = cache_K_dq.size(3);
auto D_H_q = cache_K.size(3);
CUDA_KERNEL_ASSERT(D_H == 128);

auto b = blockIdx.x;
// only need to dequantize this far.
auto max_t = kv_seqlen[b];

// one warp per T/H
for (auto t_h = threadIdx.y + blockIdx.y * blockDim.y; t_h < max_t * N_KVH;
t_h += blockDim.y * gridDim.y) {
auto h = t_h % N_KVH;
auto t = t_h / N_KVH;

int page_logical_idx = t / page_size;
int page_offset = t % page_size;
int page_physical_idx =
block_tables[b * block_tables_b_stride + page_logical_idx];
int physical_t = page_physical_idx * page_size + page_offset;

uint8_t* row_k = &cache_K[0][physical_t][h][0];
uint8_t* row_v = &cache_V[0][physical_t][h][0];

bfx8 kv_dq;
uint8_t qparam_offset_bytes;
__half2* qparam_k_src;
__half2* qparam_v_src;
if (qparam_k_ptr) {
// read from standalone qparam tensor
qparam_offset_bytes = 0;
auto idx = physical_t * N_KVH + h;
qparam_k_src = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
qparam_v_src = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
} else {
// read from first row
qparam_offset_bytes = 4;
qparam_k_src = reinterpret_cast<__half2*>(&row_k[0]);
qparam_v_src = reinterpret_cast<__half2*>(&row_v[0]);
}
// Assert the quantized row dim is as expected
CUDA_KERNEL_ASSERT(D_H_q - D_H == qparam_offset_bytes);
if (4 * threadIdx.x >= D_H) {
continue;
}
// each thread reads 4 x 8 bits

uint64_t kq = *reinterpret_cast<uint32_t*>(
&row_k[threadIdx.x * 4 + qparam_offset_bytes]);
uint64_t vq = *reinterpret_cast<uint32_t*>(
&row_v[threadIdx.x * 4 + qparam_offset_bytes]);

uint64_t packed = kq | (vq << 32);

kv_dq = dequantize_packed_fp8(packed, *qparam_k_src, *qparam_v_src);

// now, write our outputs
auto* row_k_dq = &cache_K_dq[0][physical_t][h][0];
auto* row_v_dq = &cache_V_dq[0][physical_t][h][0];
// each thread writes 4 elements of type bf16
*reinterpret_cast<uint2*>(&row_k_dq[4 * threadIdx.x]) =
*reinterpret_cast<uint2*>(&kv_dq.vals[0]);
*reinterpret_cast<uint2*>(&row_v_dq[4 * threadIdx.x]) =
*reinterpret_cast<uint2*>(&kv_dq.vals[2]);
}
}
std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
at::Tensor cache_K,
at::Tensor cache_V,
at::Tensor kv_seqlen,
std::optional<at::Tensor> qparam_k,
std::optional<at::Tensor> qparam_v) {
std::optional<at::Tensor> qparam_v,
std::optional<at::Tensor> block_tables,
int64_t page_size) {
TORCH_CHECK(cache_K.is_cuda());
TORCH_CHECK(cache_V.is_cuda());
TORCH_CHECK(kv_seqlen.is_cuda());
auto B = cache_K.size(0);
auto B = kv_seqlen.size(0);
// vanilla: B_KV = B, paged: B_KV = 1
auto B_KV = cache_K.size(0);
// vanilla: MAX_T = MAX_T, paged: MAX_T = MAX_PAGE * PAGE_SIZE
auto MAX_T = cache_K.size(1);
auto N_KVH = cache_K.size(2);
auto D_HQ = cache_K.size(3);
@@ -1500,31 +1613,72 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
}
auto D_H = (D_HQ - fp8_qparam_offset);

auto cache_K_dq =
at::empty({B, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));
auto cache_V_dq =
at::empty({B, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));
// TODO:
// The below allocates Tensors that have the same shape as cache_K and cache_V
// to store their dequantize results. For paged KV cache, this can be a bit
// inefficient because it has the shape of [1 x (MAX_PAGES * PAGE_SIZE) x
// N_KVH x D_H] to accommodate pages globally across batch instances, and
// if we have very large MAX_PAGES then we are essentially allocating a very
// huge Tensor here. The benefit is that the following users of this
// dequantized results can reuse the existing block_tables to access their
// elements. If we want to be more efficient, there are two possible
// approaches: (1) Allocate a shorter Tensor here and store the dequantize
// results in a more compact manner, but that requires creating a new
// block_tables here and making sure the following users all use the
// correct block_tables. (2) From outside, keep a persistent buffer that has a
// matching shape with the original paged KV and feed the same buffer
// into this function at every layer to reuse it and prevent allocation.
auto cache_K_dq = at::empty(
{B_KV, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));
auto cache_V_dq = at::empty(
{B_KV, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));

if (B == 0) {
return {cache_K_dq, cache_V_dq};
}

int32_t* block_tables_ptr = nullptr;
int32_t block_tables_b_stride = 0;
if (block_tables.has_value()) {
block_tables_ptr = static_cast<int32_t*>(block_tables.value().data_ptr());
block_tables_b_stride = block_tables.value().stride(0);
}

constexpr int32_t kMaxBlocks = 256;
dim3 blocks(B, std::max<int32_t>(1, kMaxBlocks / B));
dim3 threads(kThreadsPerWarp, kWarpsPerBlock);
dequantize_fp8_cache_kernel<<<
blocks,
threads,
0,
at::cuda::getCurrentCUDAStream()>>>(
cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
kv_seqlen.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
cache_K_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
cache_V_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
qparam_k_ptr,
qparam_v_ptr);
C10_CUDA_KERNEL_LAUNCH_CHECK();
if (block_tables_ptr == nullptr) {
dequantize_fp8_cache_kernel<<<
blocks,
threads,
0,
at::cuda::getCurrentCUDAStream()>>>(
cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
kv_seqlen.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
cache_K_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
cache_V_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
qparam_k_ptr,
qparam_v_ptr);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
dequantize_fp8_cache_kernel_paged<<<
blocks,
threads,
0,
at::cuda::getCurrentCUDAStream()>>>(
cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
kv_seqlen.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
cache_K_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
cache_V_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
qparam_k_ptr,
qparam_v_ptr,
block_tables_ptr,
block_tables_b_stride,
page_size);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
return {cache_K_dq, cache_V_dq};
}
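To make the allocation concern in the TODO above concrete: in the paged case the dequantized buffers are sized over the entire page pool (MAX_PAGES * PAGE_SIZE rows), independent of how many tokens are actually live. A back-of-the-envelope sketch with made-up sizes:

```python
# Hypothetical paged-pool configuration (illustrative numbers, not from this commit):
max_pages, page_size, n_kvh, d_h = 16384, 64, 8, 128
bf16_bytes = 2

# cache_K_dq and cache_V_dq are each [1, max_pages * page_size, n_kvh, d_h] in bf16.
bytes_per_buffer = max_pages * page_size * n_kvh * d_h * bf16_bytes
print(bytes_per_buffer / 2**30)   # 2.0 GiB per buffer -> 4 GiB for K and V combined,
                                  # even if only a small fraction of pages is in use
```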
@@ -1606,7 +1760,9 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
at::Tensor cache_V,
at::Tensor kv_seqlen,
std::optional<at::Tensor> qparam_k,
std::optional<at::Tensor> qparam_v) {
std::optional<at::Tensor> qparam_v,
std::optional<at::Tensor> block_tables,
int64_t page_size) {
throw std::runtime_error(
"CUDA version is older than 12.0"); // requires CUDA>=12
}