Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gemm fp8 e4m3 #185

Open
wants to merge 31 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
bb89933
gemm fp8 e4m3
AndreSlavescu Aug 31, 2024
60f7ffd
update to benchmark
AndreSlavescu Aug 31, 2024
e11c22b
faster fwd performance with tl.multiple_of
AndreSlavescu Aug 31, 2024
e68f7f1
add stricter check for compute capability + exception handling
AndreSlavescu Aug 31, 2024
fafdfbe
Merge branch 'main' into matmulfp8
AndreSlavescu Aug 31, 2024
91bf3dd
perf improvement
AndreSlavescu Sep 3, 2024
9d467bf
remove discrete functional api
AndreSlavescu Sep 3, 2024
8b45800
make compute capability check a decorator
AndreSlavescu Sep 3, 2024
2319fc7
format
AndreSlavescu Sep 3, 2024
7418433
implement backward kernel as well
AndreSlavescu Sep 3, 2024
c39a7aa
add more benchmarks + diff utils
AndreSlavescu Sep 3, 2024
f8829e5
Merge branch 'main' of https://github.com/AndreSlavescu/Liger-Kernel …
AndreSlavescu Sep 4, 2024
032c4d9
update utils to include mma_v3 for H100
AndreSlavescu Sep 5, 2024
464cdd2
Merge branch 'main' into matmulfp8
lancerts Sep 6, 2024
c8dba40
update test.
AndreSlavescu Sep 7, 2024
ce2aee5
Merge branch 'matmulfp8' of https://github.com/AndreSlavescu/Liger-Ke…
AndreSlavescu Sep 7, 2024
cedf3de
Merge branch 'main' into matmulfp8
AndreSlavescu Sep 7, 2024
bb2f725
format
AndreSlavescu Sep 7, 2024
b3195da
Merge branch 'main' into matmulfp8
lancerts Sep 8, 2024
744642b
Merge branch 'main' into matmulfp8
lancerts Sep 8, 2024
7a8043a
Merge branch 'main' into matmulfp8
AndreSlavescu Sep 10, 2024
9e60b0a
compute types
AndreSlavescu Sep 12, 2024
98b7abf
modify benchmark to be up to date
AndreSlavescu Sep 12, 2024
a709616
format
AndreSlavescu Sep 12, 2024
c9cbc3a
Merge branch 'main' into matmulfp8
lancerts Sep 12, 2024
569b4eb
fix mem bounds
AndreSlavescu Sep 12, 2024
acc228b
Merge branch 'matmulfp8' of https://github.com/AndreSlavescu/Liger-Ke…
AndreSlavescu Sep 12, 2024
d31244a
docstring for fp8 gemm design
AndreSlavescu Sep 12, 2024
0f36098
format
AndreSlavescu Sep 12, 2024
edc4ebc
remove old benchmark format
AndreSlavescu Sep 12, 2024
618c858
Merge branch 'main' into matmulfp8
AndreSlavescu Sep 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 184 additions & 0 deletions benchmark/benchmark_gemm_split_k_fp8_e4m3.py
AndreSlavescu marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
import os

import torch
import triton
from utils import _print_speed_banner, _test_memory, get_current_file_directory

from liger_kernel.ops.experimental.gemm_split_k_fp8_e4m3 import (
LigerFP8GemmSplitKFunction,
)

# enable TensorFloat32 tensor cores for better performance in benchmark
torch.set_float32_matmul_precision("high")


@triton.testing.perf_report(
[
triton.testing.Benchmark(
x_names=["m", "k", "n"],
x_vals=[
(64, 64, 64),
(256, 256, 256),
(512, 512, 512),
(1024, 1024, 1024),
(64, 128, 64),
(256, 512, 256),
(512, 1024, 512),
(1024, 2048, 1024),
],
xlabel="Matrix Size (m x k x n)",
line_arg="provider",
line_vals=["liger", "torch", "torch_compile"],
line_names=["Liger", "PyTorch", "Torch Compile"],
styles=[("blue", "solid"), ("orange", "solid"), ("green", "solid")],
ylabel="time (ms)",
plot_name="gemm-split-k-fp8-fwd-speed-benchmark",
args={"mode": "forward", "dtype": torch.float32},
),
triton.testing.Benchmark(
x_names=["m", "k", "n"],
x_vals=[
(64, 64, 64),
(256, 256, 256),
(512, 512, 512),
(1024, 1024, 1024),
(64, 128, 64),
(256, 512, 256),
(512, 1024, 512),
(1024, 2048, 1024),
],
xlabel="Matrix Size (m x k x n)",
line_arg="provider",
line_vals=["liger", "torch", "torch_compile"],
line_names=["Liger", "PyTorch", "Torch Compile"],
styles=[("blue", "solid"), ("orange", "solid"), ("green", "solid")],
ylabel="time (ms)",
plot_name="gemm-split-k-fp8-full-speed-benchmark",
args={"mode": "full", "dtype": torch.float32},
),
]
)
def bench_speed_gemm_split_k_fp8(m, k, n, provider, mode, dtype, device="cuda"):
a_fp8 = torch.randn((m, k), device=device, dtype=dtype).to(torch.float8_e4m3fn)
b_fp8 = torch.randn((k, n), device=device, dtype=dtype).to(torch.float8_e4m3fn)

a_float = a_fp8.float().requires_grad_()
b_float = b_fp8.float().requires_grad_()

def fwd_liger():
return LigerFP8GemmSplitKFunction.apply(a_fp8, b_fp8)

def fwd_torch():
return torch.matmul(a_float, b_float)
Copy link
Collaborator

@qingquansong qingquansong Sep 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comparing the speed/memory the fp8 kernel with torch matmul on fp32 is not a quite fair comparison. A better comparison would be to compare with torch._scaled_mm with fp8 matmul such as the example here: https://gist.github.com/malfet/7874d96b99670c3da83cbb779ab770c6


fwd_torch_compiled = torch.compile(fwd_torch)

if provider == "liger":
fwd_fn = fwd_liger
elif provider == "torch":
fwd_fn = fwd_torch
else:
fwd_fn = fwd_torch_compiled

quantiles = [0.5, 0.2, 0.8]

if mode == "forward":
ms, min_ms, max_ms = triton.testing.do_bench(fwd_fn, quantiles=quantiles)
elif mode == "full":

def full():
y = fwd_fn()
if provider == "liger":
# compute manually gradients for Liger to avoid: "ufunc_add_CUDA" not implemented for 'Float8_e4m3fn'
dc = torch.ones_like(y, dtype=torch.float8_e4m3fn)
LigerFP8GemmSplitKFunction.apply(dc, b_fp8.t())
LigerFP8GemmSplitKFunction.apply(a_fp8.t(), dc)
else:
torch.sum(y).backward()

ms, min_ms, max_ms = triton.testing.do_bench(full, quantiles=quantiles)

return ms, min_ms, max_ms


def benchmark_speed_gemm_split_k_fp8_wrapper():
_print_speed_banner()

curr_dir = get_current_file_directory()
dir_name = "gemm_split_k_fp8_speed"
output_dir = os.path.join(curr_dir, dir_name)
os.makedirs(output_dir, exist_ok=True)

bench_speed_gemm_split_k_fp8.run(save_path=output_dir, print_data=True)


@triton.testing.perf_report(
[
triton.testing.Benchmark(
x_names=["m", "k", "n"],
x_vals=[
(64, 64, 64),
(256, 256, 256),
(512, 512, 512),
(1024, 1024, 1024),
(64, 128, 64),
(256, 512, 256),
(512, 1024, 512),
(1024, 2048, 1024),
],
xlabel="Matrix Size (m x k x n)",
line_arg="provider",
line_vals=["liger", "torch", "torch_compile"],
line_names=["Liger", "PyTorch", "Torch Compile"],
styles=[("blue", "solid"), ("orange", "solid"), ("green", "solid")],
ylabel="GPU memory usage (MB)",
plot_name="gemm-split-k-fp8-memory-benchmark",
args={"dtype": torch.float32},
)
]
)
def bench_memory_gemm_split_k_fp8(m, k, n, provider, dtype, device="cuda"):
a_fp8 = torch.randn((m, k), device=device, dtype=dtype).to(torch.float8_e4m3fn)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto. Let's try to create bf16 input and compare the speed/memory of torch._scaled_mm v.s. the fp8 kernel and then compare the joint time of quant + dequant + matmul (with fp8 scaled factor) Thanks!

b_fp8 = torch.randn((k, n), device=device, dtype=dtype).to(torch.float8_e4m3fn)

a_float = a_fp8.float().requires_grad_()
b_float = b_fp8.float().requires_grad_()

def full_liger():
y = LigerFP8GemmSplitKFunction.apply(a_fp8, b_fp8)
# compute manually gradients for Liger to avoid: "ufunc_add_CUDA" not implemented for 'Float8_e4m3fn'
dc = torch.ones_like(y, dtype=torch.float8_e4m3fn)
LigerFP8GemmSplitKFunction.apply(dc, b_fp8.t())
LigerFP8GemmSplitKFunction.apply(a_fp8.t(), dc)

def full_torch():
y = torch.matmul(a_float, b_float)
torch.sum(y).backward()

full_torch_compiled = torch.compile(full_torch)

if provider == "liger":
full_fn = full_liger
elif provider == "torch":
full_fn = full_torch
else:
full_fn = full_torch_compiled

mem = _test_memory(full_fn)
return mem / 2**20


def benchmark_memory_gemm_split_k_fp8_wrapper():
_print_speed_banner()

curr_dir = get_current_file_directory()
dir_name = "gemm_split_k_fp8_memory"
output_dir = os.path.join(curr_dir, dir_name)
os.makedirs(output_dir, exist_ok=True)

bench_memory_gemm_split_k_fp8.run(save_path=output_dir, print_data=True)


if __name__ == "__main__":
benchmark_speed_gemm_split_k_fp8_wrapper()
benchmark_memory_gemm_split_k_fp8_wrapper()
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
m,k,n,Liger,PyTorch,Torch Compile
64.000000,64.000000,64.000000,16.312500,16.380664,16.365039
256.000000,256.000000,256.000000,17.250000,18.325977,18.075977
512.000000,512.000000,512.000000,20.250000,24.550977,23.550977
1024.000000,1024.000000,1024.000000,32.250000,49.450977,45.450977
64.000000,128.000000,64.000000,16.363281,16.479102,16.463477
256.000000,512.000000,256.000000,18.062500,19.900977,19.650977
512.000000,1024.000000,512.000000,23.500000,30.850977,29.850977
1024.000000,2048.000000,1024.000000,45.250000,74.650977,70.650977
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions benchmark/gemm_split_k_fp8_memory/results.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
<html><body>
<image src="gemm-split-k-fp8-memory-benchmark.png"/>
</body></html>
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
m,k,n,Liger,PyTorch,Torch Compile
64.000000,64.000000,64.000000,0.131072,0.150432,0.824320
256.000000,256.000000,256.000000,0.376832,0.429056,0.659328
512.000000,512.000000,512.000000,0.214016,0.405504,0.675872
1024.000000,1024.000000,1024.000000,0.302592,0.425984,0.985088
64.000000,128.000000,64.000000,0.368640,0.342528,0.586736
256.000000,512.000000,256.000000,0.241696,0.360512,0.943616
512.000000,1024.000000,512.000000,0.400384,0.513024,0.991744
1024.000000,2048.000000,1024.000000,0.272384,0.315392,0.983568
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are there multiple lines for each configuration? should be 1 line per case?

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
m,k,n,Liger,PyTorch,Torch Compile
64.000000,64.000000,64.000000,0.011264,0.008192,0.007168
256.000000,256.000000,256.000000,0.013312,0.009216,0.009216
512.000000,512.000000,512.000000,0.016384,0.016384,0.016384
1024.000000,1024.000000,1024.000000,0.035840,0.038912,0.038912
64.000000,128.000000,64.000000,0.011264,0.007168,0.007168
256.000000,512.000000,256.000000,0.016384,0.013280,0.012288
512.000000,1024.000000,512.000000,0.022528,0.018944,0.019360
1024.000000,2048.000000,1024.000000,0.048128,0.062464,0.062464
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 4 additions & 0 deletions benchmark/gemm_split_k_fp8_speed/results.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
<html><body>
<image src="gemm-split-k-fp8-fwd-speed-benchmark.png"/>
<image src="gemm-split-k-fp8-full-speed-benchmark.png"/>
</body></html>
Loading
Loading