Rahul quant merged #10341

Draft: wants to merge 31 commits into base: main
Changes from all commits (31 commits)
da31648
semi_structured for fp16 and bf16 and int8
ilmarkov Oct 1, 2024
e655f94
Fix A100 int8 tests
ilmarkov Oct 2, 2024
5fc3c1c
Add fp8 cusparseLt
ilmarkov Oct 9, 2024
9cf36d6
wip
ilmarkov Oct 9, 2024
ad09e79
Fix signatures
ilmarkov Oct 9, 2024
e75eabc
Fix compilation and tests
ilmarkov Oct 13, 2024
0306390
Update for older platforms
ilmarkov Oct 15, 2024
1021acb
Add benchmarks
ilmarkov Oct 16, 2024
19ce358
Fix typo
ilmarkov Oct 23, 2024
959408c
Added scaled_mm for fp8.
ilmarkov Oct 24, 2024
117b87b
Add docstrings
ilmarkov Oct 28, 2024
2c7e68e
Update for torch 2.5
ilmarkov Oct 30, 2024
922f4f8
Add handling contiguous dense input for int8 and fp8
ilmarkov Oct 30, 2024
beca038
Add fp8 cusparseLt
ilmarkov Oct 9, 2024
5d9cd25
Fix compilation and tests
ilmarkov Oct 13, 2024
39ad9d4
Add caching of cusparseLT meta
ilmarkov Oct 23, 2024
520eb62
Cached cusparseLt
ilmarkov Oct 25, 2024
20956e6
Fix destroy function
ilmarkov Oct 25, 2024
87c8088
Prepare for reproduce
ilmarkov Oct 25, 2024
4ea58b1
Fix cusparseLt caching
ilmarkov Oct 30, 2024
f0551ef
Make cached version default function
ilmarkov Nov 5, 2024
d7476e8
Fixes and polishing after rebase
ilmarkov Nov 6, 2024
681ea5e
add sparse 2:4 weight loading suport
dsikka Oct 23, 2024
ecf878f
Some more changes!
rahul-tuli Oct 29, 2024
80952dc
Cleanup
rahul-tuli Oct 31, 2024
8462c9d
get uncompressed to work; update gemm to use contiguous; use alex's u…
dsikka Nov 1, 2024
0a3e506
patch
dsikka Nov 4, 2024
2e28972
use our decompressor
dsikka Nov 4, 2024
28f0abb
Some more work
rahul-tuli Nov 6, 2024
c7a97a8
Use new scaled_T function
rahul-tuli Nov 7, 2024
8700516
Remove q_input conversion to non-contiguous
rahul-tuli Nov 7, 2024
9 changes: 9 additions & 0 deletions CMakeLists.txt
@@ -196,6 +196,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/quantization/fp8_semi_structured/cusparseLt.cpp"
"csrc/torch_bindings.cpp")

if(VLLM_GPU_LANG STREQUAL "CUDA")
@@ -398,6 +399,14 @@ define_gpu_extension_target(
# Setting this variable sidesteps the issue by calling the driver directly.
target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)

# If cuSparseLt is not installed we skip 2:4 optimizations
include(CheckIncludeFileCXX)
CHECK_INCLUDE_FILE_CXX("cusparseLt.h" HAVE_CUSPARSELT)
message(STATUS "Result of include cusparseLt ${HAVE_CUSPARSELT}")
if(HAVE_CUSPARSELT)
  target_compile_definitions(_C PRIVATE VLLM_CUSPARSELT_ENABLED=1)
endif()

#
# _moe_C extension
#
251 changes: 251 additions & 0 deletions benchmarks/cusparseLt_benchmarks/benchmark_24.py
@@ -0,0 +1,251 @@
import argparse
import copy
import itertools
from typing import Callable, Iterable, List, Tuple

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES

from vllm.model_executor.layers.sparsity.utils.cusparse_2_4_utils import (
compress_to_torch_sparse_semi_structured_mat, dense_matmul, get_random_mat,
is_semi_structured_supported, semi_structured_sparse_dense_gemm)
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_BATCH_SIZES = [32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]


# helpers
def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
a = get_random_mat(m, k, dtype)
b = get_random_mat(n, k, dtype).t()
return a, b


# bench
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
**kwargs) -> TMeasurement:
min_run_time = 1

globals = {
"args": args,
"kwargs": kwargs,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(*args, **kwargs)",
globals=globals,
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)


def bench(m: int, k: int, n: int, label: str, sub_label: str,
use_fp8: bool) -> Iterable[TMeasurement]:
a, b = make_rand_tensors(torch.float16, m, n, k)

timers = []
# pytorch float16
timers.append(
bench_fn(label, sub_label, "pytorch_fp16_fp16_matmul", torch.mm,
a.to(dtype=torch.float16), b.to(dtype=torch.float16)))

# pytorch bf16
timers.append(
bench_fn(label, sub_label, "pytorch_bf16_bf16_matmul", torch.mm,
a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda")))

# cusparseLt fp16
timers.append(
bench_fn(label, sub_label, "cusparseLt_fp16_fp16_2_4",
semi_structured_sparse_dense_gemm,
compress_to_torch_sparse_semi_structured_mat(a), b))

timers.append(
bench_fn(label,
sub_label,
"cusparseLt_fp16_fp16_2_4_noncached",
semi_structured_sparse_dense_gemm,
compress_to_torch_sparse_semi_structured_mat(a),
b,
cached=False))

# cusparseLt bf16
a, b = make_rand_tensors(torch.bfloat16, m, n, k)
a_compressed = compress_to_torch_sparse_semi_structured_mat(a.to(dtype=torch.bfloat16))

timers.append(
bench_fn(label, sub_label, "cusparseLt_bf16_bf16_2_4",
semi_structured_sparse_dense_gemm, a_compressed, b))

timers.append(
bench_fn(label,
sub_label,
"cusparseLt_bf16_bf16_2_4_noncached",
semi_structured_sparse_dense_gemm,
a_compressed,
b,
cached=False))

a, b = make_rand_tensors(torch.int8, m, n, k)
# cutlass i8
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_matmul", dense_matmul, a, b,
torch.int8))

# cusparseLt i8
a_compressed = compress_to_torch_sparse_semi_structured_mat(a)
# warmup
semi_structured_sparse_dense_gemm(a_compressed, b)
timers.append(
bench_fn(label, sub_label, "cusparseLt_i8_i8_2_4",
semi_structured_sparse_dense_gemm, a_compressed, b))

timers.append(
bench_fn(label,
sub_label,
"cusparseLt_i8_i8_2_4_noncached",
semi_structured_sparse_dense_gemm,
a_compressed,
b,
cached=False))

if use_fp8:
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
# cutlass fp8
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_matmul-w-scales",
dense_matmul, a, b, torch.float8_e4m3fn))

# cusparseLt fp8
a_compressed = compress_to_torch_sparse_semi_structured_mat(a)

# warmup
semi_structured_sparse_dense_gemm(a_compressed, b)

timers.append(
bench_fn(label, sub_label, "cusparseLt_fp8_fp8_2_4",
semi_structured_sparse_dense_gemm, a_compressed, b))

timers.append(
bench_fn(label,
sub_label,
"cusparseLt_fp8_fp8_2_4_noncached",
semi_structured_sparse_dense_gemm,
a_compressed,
b,
cached=False))

return timers


# runner
def print_timers(timers: Iterable[TMeasurement]):
compare = TBenchmark.Compare(timers)
compare.print()


def run(MKNs: Iterable[Tuple[int, int, int]],
use_fp8: bool) -> Iterable[TMeasurement]:
results = []
# MKNs = [(1024, 8192, 14336)]
# MKNs = [(2048, 8192, 14336)]
# MKNs = [(2048, 8192, 14336), (2048, 8192, 14336)]
# MKNs = [(32, 11008, 4096)]
# MKNs = [(2048, 11008, 14336)]

for m, k, n in MKNs:
timers = bench(m, k, n, "gemm", f"MKN=({m}x{k}x{n})", use_fp8)
print_timers(timers)
results.extend(timers)

return results


def make_output(data: Iterable[TMeasurement],
MKNs: Iterable[Tuple[int, int, int]],
base_description: str,
timestamp=None):
print(f"== All Results {base_description} ====")
print_timers(data)


def run_model_bench(args):
if not is_semi_structured_supported():
raise ValueError("Device does not support semi-structured sparsity")

print("Benchmarking models:")
for i, model in enumerate(args.models):
print(f"[{i}] {model}")

def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
KNs = []
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
KNs.append(KN)
return KNs

model_bench_data = []
models_tps = list(itertools.product(args.models, args.tp_sizes))
for model, tp_size in models_tps:
Ms = args.batch_sizes
KNs = model_shapes(model, tp_size)
MKNs = []
for m in Ms:
assert m % 32 == 0, "Batch size has to be a multiple of 32"
for k, n in KNs:
if k % 32 or n % 32:
continue
MKNs.append((m, k, n))

data = run(MKNs, args.use_fp8)
model_bench_data.append(data)

# Print all results
for data, model_tp in zip(model_bench_data, models_tps):
model, tp_size = model_tp
print(f"== Results cuSparseLt {model}-TP{tp_size} ====")
print_timers(data)


if __name__ == '__main__':

parser = FlexibleArgumentParser(
description="""
Benchmark cuSparseLt 2:4 GEMMs.

To run dimensions from a model:
python3 ./benchmarks/cusparseLt_benchmarks/benchmark_24.py --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1


Output:
- printed torch.utils.benchmark comparison tables of the Measurements for the pytorch and cusparseLt implementations of the various GEMMs.
""", # noqa: E501
formatter_class=argparse.RawTextHelpFormatter)

parser.add_argument("--models",
nargs="+",
type=str,
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys())
parser.add_argument("--tp-sizes",
nargs="+",
type=int,
default=DEFAULT_TP_SIZES)
parser.add_argument("--batch-sizes",
nargs="+",
type=int,
default=DEFAULT_BATCH_SIZES)
parser.add_argument(
'--use-fp8',
action='store_true',
help='Add benchmarking fp8 matmul (on supporting fp8 platforms)')

args = parser.parse_args()
run_model_bench(args)
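
A minimal usage sketch of the cusparse_2_4_utils helpers this benchmark exercises, assuming the module added in this PR is importable and the device passes is_semi_structured_supported() (the shapes are illustrative):

import torch

from vllm.model_executor.layers.sparsity.utils.cusparse_2_4_utils import (
    compress_to_torch_sparse_semi_structured_mat, get_random_mat,
    is_semi_structured_supported, semi_structured_sparse_dense_gemm)

if is_semi_structured_supported():
    m, k, n = 64, 4096, 4096
    # The sparse operand is (m, k); the dense operand is built as (n, k) and
    # transposed, mirroring make_rand_tensors() above.
    a = get_random_mat(m, k, torch.float16)
    b = get_random_mat(n, k, torch.float16).t()
    a_compressed = compress_to_torch_sparse_semi_structured_mat(a)
    # The default path reuses cached cusparseLt metadata; cached=False forces
    # the uncached path that the *_noncached timers above measure.
    out = semi_structured_sparse_dense_gemm(a_compressed, b)
    out_noncached = semi_structured_sparse_dense_gemm(a_compressed, b, cached=False)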
43 changes: 43 additions & 0 deletions benchmarks/cusparseLt_benchmarks/weight_shapes.py
@@ -0,0 +1,43 @@
# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
# - TP1 : K = 14336, N = 4096
# - TP2 : K = 7168, N = 4096
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
# - TP1 : K = 4096, N = 6144
# - TP4 : K = 4096, N = 1536

# TP1 shapes
WEIGHT_SHAPES = {
"mistralai/Mistral-7B-v0.1": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-2-7b-hf": [
([4096, 12288], 1),
([4096, 4096], 0),
([4096, 22016], 1),
([11008, 4096], 0),
],
"meta-llama/Llama-3-8b": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-2-13b-hf": [
([5120, 15360], 1),
([5120, 5120], 0),
([5120, 27648], 1),
([13824, 5120], 0),
],
"meta-llama/Llama-2-70b-hf": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 57344], 1),
([28672, 8192], 0),
],
}
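
A short sketch of how these ([K, N], TP_SPLIT_DIM) entries are consumed, mirroring model_shapes() in benchmark_24.py; the helper name shapes_for_tp is illustrative:

import copy

from weight_shapes import WEIGHT_SHAPES

def shapes_for_tp(model_name: str, tp_size: int):
    kns = []
    for kn, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
        # Divide the split dimension (K if 0, N if 1) across tensor-parallel ranks.
        kn[tp_split_dim] //= tp_size
        kns.append(kn)
    return kns

# For example, ([14336, 4096], 0) becomes [7168, 4096] at TP2, matching the
# header comment above.
print(shapes_for_tp("mistralai/Mistral-7B-v0.1", 2))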
17 changes: 17 additions & 0 deletions csrc/ops.h
@@ -216,3 +216,20 @@ std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets);
#endif

#ifndef USE_ROCM
torch::Tensor cslt_compress_fp8_semi_structured(const torch::Tensor& input);

torch::Tensor cslt_mm_semi_structured(
const torch::Tensor& compressed_A, const torch::Tensor& dense_B,
const c10::optional<double>& scale_opt,
const c10::optional<torch::Tensor>& bias_opt);

torch::Tensor cslt_mm_fp8_semi_structured2(
const torch::Tensor& compressed_A, const torch::Tensor& dense_B,
const c10::optional<double>& scale_opt,
const c10::optional<torch::Tensor>& bias_opt);

void cslt_clear_cache();

#endif
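
These declarations give the op-level call pattern for the new cusparseLt kernels. A hypothetical Python sketch, assuming the functions end up registered through csrc/torch_bindings.cpp under torch.ops._C (that namespace is an assumption, not shown in this diff):

from typing import Optional

import torch

def sparse_24_fp8_gemm(a_fp8: torch.Tensor, b: torch.Tensor,
                       scale: Optional[float] = None,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Hypothetical binding names mirroring the declarations above: compress the
    # 2:4-sparse operand once, then reuse it across GEMMs.
    a_compressed = torch.ops._C.cslt_compress_fp8_semi_structured(a_fp8)
    return torch.ops._C.cslt_mm_semi_structured(a_compressed, b, scale, bias)

# torch.ops._C.cslt_clear_cache()  # releases cached cusparseLt handles/metadata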