Skip to content

Commit

Permalink
Add support for triton-based GEMM in OSS (pytorch#2570)
Browse files Browse the repository at this point in the history
Summary:
- Add support for triton-based GEMM in OSS

Pull Request resolved: pytorch#2570

Differential Revision: D57419173

Pulled By: q10
  • Loading branch information
q10 authored and facebook-github-bot committed May 17, 2024
1 parent 1c0344f commit 0f9672d
Show file tree
Hide file tree
Showing 23 changed files with 182 additions and 25 deletions.
5 changes: 5 additions & 0 deletions .github/scripts/fbgemm_gpu_build.bash
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ prepare_fbgemm_gpu_build () {
# shellcheck disable=SC2086
(exec_with_retries 3 conda run --no-capture-output ${env_prefix} python -m pip install -r requirements.txt) || return 1

# BUILD_VARIANT is provided by the github workflow file
if [ "$BUILD_VARIANT" == "cuda" ]; then
(install_triton ${env_name}) || return 1
fi

# shellcheck disable=SC2086
(test_python_import_package "${env_name}" numpy) || return 1
# shellcheck disable=SC2086
Expand Down
9 changes: 9 additions & 0 deletions .github/scripts/fbgemm_gpu_install.bash
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ __fbgemm_gpu_post_install_checks () {
echo "[CHECK] package channel; the package may be broken at runtime!!!"
echo "################################################################################"

# shellcheck disable=SC2086,SC2155
local fbgemm_gpu_packages=$(conda run ${env_prefix} python -c "import fbgemm_gpu; print(dir(fbgemm_gpu))")
local experimental_packages=$(conda run ${env_prefix} python -c "import fbgemm_gpu.experimental; print(dir(fbgemm_gpu.experimental))")
echo "################################################################################"
echo "[CHECK] FBGEMM_GPU Experimental Packages"
echo "[CHECK] fbgemm_gpu: ${fbgemm_gpu_packages}"
echo "[CHECK] fbgemm_gpu.experimental: ${experimental_packages}"
echo "################################################################################"

echo "[INSTALL] Checking imports and symbols ..."
(test_python_import_package "${env_name}" fbgemm_gpu) || return 1
(test_python_import_package "${env_name}" fbgemm_gpu.split_embedding_codegen_lookup_invokers) || return 1
Expand Down
7 changes: 4 additions & 3 deletions .github/scripts/fbgemm_gpu_test.bash
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,14 @@ test_all_fbgemm_gpu_modules () {
local fbgemm_variant="$2"

local target_directories=(
fbgemm_gpu/test
# fbgemm_gpu/test
)

if [ "$fbgemm_variant" == "cuda" ]; then
target_directories+=(
fbgemm_gpu/experimental/example/test
fbgemm_gpu/experimental/gen_ai/test
# fbgemm_gpu/experimental/example/test
fbgemm_gpu/experimental/gemm/test
# fbgemm_gpu/experimental/gen_ai/test
)
fi

Expand Down
2 changes: 2 additions & 0 deletions .github/scripts/setup_env.bash
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
# shellcheck disable=SC1091,SC2128
. "$( dirname -- "$BASH_SOURCE"; )/utils_pytorch.bash"
# shellcheck disable=SC1091,SC2128
. "$( dirname -- "$BASH_SOURCE"; )/utils_triton.bash"
# shellcheck disable=SC1091,SC2128
. "$( dirname -- "$BASH_SOURCE"; )/fbgemm_build.bash"
# shellcheck disable=SC1091,SC2128
. "$( dirname -- "$BASH_SOURCE"; )/fbgemm_gpu_build.bash"
Expand Down
54 changes: 54 additions & 0 deletions .github/scripts/utils_triton.bash
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# shellcheck disable=SC1091,SC2128
. "$( dirname -- "$BASH_SOURCE"; )/utils_base.bash"

################################################################################
# Triton Setup Functions
################################################################################

################################################################################
# Check out and pip-install Triton (editable mode) from the third_party
# submodule into the given Conda environment.
#
# Arguments:
#   $1 - Conda environment name (required)
#   $2 - Triton version (git ref) to check out; empty = repo-default submodule rev
# Returns:
#   0 on success, 1 on any failure.
# Note: assumes it is invoked from a directory where
# ../third_party/triton/python resolves (i.e. from fbgemm_gpu/) — TODO confirm
# against the callers in fbgemm_gpu_build.bash.
################################################################################
install_triton () {
  local env_name="$1"
  local triton_version="$2"
  if [ "$env_name" == "" ]; then
    echo "Usage: ${FUNCNAME[0]} ENV_NAME [TRITON_VERSION]"
    echo "Example(s):"
    echo "    ${FUNCNAME[0]} build_env          # Install the repo-default version of Triton"
    echo "    ${FUNCNAME[0]} build_env 2.1      # Install a designated version of Triton"
    return 1
  else
    echo "################################################################################"
    echo "# Build + Install Triton"
    echo "#"
    echo "# [$(date --utc +%FT%T.%3NZ)] + ${FUNCNAME[0]} ${*}"
    echo "################################################################################"
    echo ""
  fi

  test_network_connection || return 1

  # shellcheck disable=SC2155
  local env_prefix=$(env_name_or_prefix "${env_name}")

  # Run the checkout + install inside a subshell so the caller's working
  # directory is preserved on ALL exit paths.  The previous `cd` / `cd -`
  # pairing skipped the `cd -` whenever an intermediate step failed,
  # leaving the caller stranded inside the triton source tree.
  (
    echo "[BUILD] Checking out triton ..."
    cd ../third_party/triton/python || exit 1
    if [ "$triton_version" != "" ]; then
      (print_exec git checkout "${triton_version}") || exit 1
    fi

    echo "[BUILD] Installing Triton ..."
    # env_prefix is intentionally unquoted: it expands to multiple words
    # (e.g. "-n env_name" or "-p /path").
    # shellcheck disable=SC2086
    (exec_with_retries 3 conda run --no-capture-output ${env_prefix} python -m pip install -e .) || exit 1
  ) || return 1

  # Verify the package is importable from the target environment.
  # shellcheck disable=SC2086
  (test_python_import_package "${env_name}" triton) || return 1

  echo "[INSTALL] Successfully installed Triton ${triton_version}"
}
2 changes: 2 additions & 0 deletions .github/workflows/fbgemm_gpu_ci_cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ jobs:
env:
PRELUDE: .github/scripts/setup_env.bash
BUILD_ENV: build_binary
BUILD_VARIANT: cpu
continue-on-error: true
strategy:
# Don't fast-fail all the other builds if one of them fails
Expand Down Expand Up @@ -126,6 +127,7 @@ jobs:
env:
PRELUDE: .github/scripts/setup_env.bash
BUILD_ENV: build_binary
BUILD_VARIANT: cpu
strategy:
fail-fast: false
matrix:
Expand Down
5 changes: 4 additions & 1 deletion .github/workflows/fbgemm_gpu_ci_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ jobs:
env:
PRELUDE: .github/scripts/setup_env.bash
BUILD_ENV: build_binary
BUILD_VARIANT: cuda
continue-on-error: true
strategy:
# Don't fast-fail all the other builds if one of them fails
Expand Down Expand Up @@ -134,6 +135,7 @@ jobs:
env:
PRELUDE: .github/scripts/setup_env.bash
BUILD_ENV: build_binary
BUILD_VARIANT: cuda
ENFORCE_CUDA_DEVICE: 1
strategy:
fail-fast: false
Expand Down Expand Up @@ -181,7 +183,8 @@ jobs:
run: . $PRELUDE; create_conda_environment $BUILD_ENV ${{ matrix.python-version }}

- name: Install C/C++ Compilers for Updated LIBGCC
run: . $PRELUDE; install_cxx_compiler $BUILD_ENV ${{ matrix.compiler }}
# Install clang libraries to enable building and install triton
run: . $PRELUDE; install_cxx_compiler $BUILD_ENV clang

- name: Install CUDA
run: . $PRELUDE; install_cuda $BUILD_ENV ${{ matrix.cuda-version }}
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/fbgemm_gpu_ci_rocm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ jobs:
env:
PRELUDE: .github/scripts/setup_env.bash
BUILD_ENV: build_binary
BUILD_VARIANT: rocm
strategy:
fail-fast: false
matrix:
Expand Down Expand Up @@ -133,6 +134,7 @@ jobs:
env:
PRELUDE: .github/scripts/setup_env.bash
BUILD_ENV: build_binary
BUILD_VARIANT: rocm
ENFORCE_ROCM_DEVICE: 1
strategy:
fail-fast: false
Expand Down
5 changes: 3 additions & 2 deletions .github/workflows/fbgemm_gpu_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ jobs:
env:
PRELUDE: .github/scripts/setup_env.bash
BUILD_ENV: build_binary
BUILD_VARIANT: cpu
strategy:
fail-fast: false
matrix:
Expand Down Expand Up @@ -71,7 +72,7 @@ jobs:
run: . $PRELUDE; cd fbgemm_gpu/docs; install_docs_tools $BUILD_ENV

- name: Install PyTorch-CPU Nightly
run: . $PRELUDE; install_pytorch_pip $BUILD_ENV nightly cpu
run: . $PRELUDE; install_pytorch_pip $BUILD_ENV nightly $BUILD_VARIANT

- name: Collect PyTorch Environment Info
if: ${{ success() || failure() }}
Expand All @@ -81,7 +82,7 @@ jobs:
run: . $PRELUDE; cd fbgemm_gpu; prepare_fbgemm_gpu_build $BUILD_ENV

- name: Build + Install FBGEMM_GPU (CPU version)
run: . $PRELUDE; cd fbgemm_gpu; build_fbgemm_gpu_install $BUILD_ENV cpu
run: . $PRELUDE; cd fbgemm_gpu; build_fbgemm_gpu_install $BUILD_ENV $BUILD_VARIANT

- name: Build FBGEMM_GPU Documentation
run: . $PRELUDE; cd fbgemm_gpu/docs; build_fbgemm_gpu_docs $BUILD_ENV
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/fbgemm_gpu_pip.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ jobs:
env:
PRELUDE: .github/scripts/setup_env.bash
BUILD_ENV: test_install
BUILD_VARIANT: cpu
strategy:
fail-fast: false
matrix:
Expand Down Expand Up @@ -111,6 +112,7 @@ jobs:
env:
PRELUDE: .github/scripts/setup_env.bash
BUILD_ENV: test_install
BUILD_VARIANT: cuda
ENFORCE_CUDA_DEVICE: 1
strategy:
fail-fast: false
Expand Down Expand Up @@ -174,6 +176,7 @@ jobs:
env:
PRELUDE: .github/scripts/setup_env.bash
BUILD_ENV: test_install
BUILD_VARIANT: rocm
ENFORCE_ROCM_DEVICE: 1
strategy:
fail-fast: false
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/fbgemm_gpu_release_cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ jobs:
env:
PRELUDE: .github/scripts/setup_env.bash
BUILD_ENV: build_binary
BUILD_VARIANT: cpu
continue-on-error: true
strategy:
# Don't fast-fail all the other builds if one of them fails
Expand Down Expand Up @@ -122,6 +123,7 @@ jobs:
env:
PRELUDE: .github/scripts/setup_env.bash
BUILD_ENV: build_binary
BUILD_VARIANT: cpu
strategy:
fail-fast: false
matrix:
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/fbgemm_gpu_release_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ jobs:
env:
PRELUDE: .github/scripts/setup_env.bash
BUILD_ENV: build_binary
BUILD_VARIANT: cuda
continue-on-error: true
strategy:
# Don't fast-fail all the other builds if one of them fails
Expand Down Expand Up @@ -131,6 +132,7 @@ jobs:
env:
PRELUDE: .github/scripts/setup_env.bash
BUILD_ENV: build_binary
BUILD_VARIANT: cuda
ENFORCE_CUDA_DEVICE: 1
strategy:
fail-fast: false
Expand Down
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@
[submodule "third_party/cutlass"]
path = third_party/cutlass
url = https://github.com/NVIDIA/cutlass.git
[submodule "third_party/triton"]
path = third_party/triton
url = https://github.com/openai/triton.git
2 changes: 1 addition & 1 deletion cmake/modules/PyTorchSetup.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ include(${CMAKE_CURRENT_SOURCE_DIR}/../cmake/modules/Utilities.cmake)
find_package(Torch REQUIRED)

#
# Toch Cuda Extensions are normally compiled with the flags below. However we
# PyTorch CUDA Extensions are normally compiled with the flags below. However we
# disabled -D__CUDA_NO_HALF_CONVERSIONS__ here as it caused "error: no suitable
# constructor exists to convert from "int" to "__half" errors in
# gen_embedding_forward_quantized_split_[un]weighted_codegen_cuda.cu
Expand Down
1 change: 1 addition & 0 deletions fbgemm_gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ endif()

if(NOT FBGEMM_CPU_ONLY)
add_subdirectory(experimental/example)
add_subdirectory(experimental/gemm)

if(NOT USE_ROCM)
# CUTLASS currently doesn't build on ROCm and CK hasn't yet been added:
Expand Down
21 changes: 21 additions & 0 deletions fbgemm_gpu/experimental/gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

################################################################################
# Target Sources
################################################################################

# Pure-Python sources for the Triton-based GEMM package; there are no compiled
# targets in this directory — the files are simply installed into the wheel.
set(experimental_triton_python_source_files
    triton_gemm/__init__.py
    triton_gemm/fp8_gemm.py)


################################################################################
# Install Python Files
################################################################################

# Install the listed files under the fbgemm_gpu.experimental.gemm.triton_gemm
# package path so they are importable from the installed package.
install(FILES ${experimental_triton_python_source_files}
        DESTINATION fbgemm_gpu/experimental/gemm/triton_gemm)
17 changes: 17 additions & 0 deletions fbgemm_gpu/experimental/gemm/test/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Tuple

import torch

# True when there is no usable CUDA runtime or no visible CUDA device.
_cuda_missing: bool = not torch.cuda.is_available() or torch.cuda.device_count() == 0

# Skip-condition tuple in the shape unittest.skipIf expects: (condition, reason).
gpu_unavailable: Tuple[bool, str] = (
    _cuda_missing,
    "CUDA is not available or no GPUs detected",
)
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
from typing import Callable, Tuple

import torch

import triton

from deeplearning.fbgemm.fbgemm_gpu.experimental.gemm.triton.fp8_gemm import (
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
matmul_fp8_block,
matmul_fp8_row,
quantize_fp8_block,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,19 @@

import torch

from deeplearning.fbgemm.fbgemm_gpu.experimental.gemm.triton.fp8_gemm import (
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
matmul_fp8_block,
matmul_fp8_row,
quantize_fp8_block,
quantize_fp8_row,
)


@unittest.skipIf(
not torch.cuda.is_available()
or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
"Skip when H100 is not available",
)
class TestFp8Matmul(unittest.TestCase):
def setUp(self) -> None:
torch.manual_seed(0)
Expand Down
19 changes: 19 additions & 0 deletions fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

# Package initializer for fbgemm_gpu.experimental.gemm.triton_gemm.
#
# NOTE(review): this appears to mirror the OSS-vs-internal detection pattern
# used elsewhere in fbgemm_gpu — in the open-source package layout the
# `fbgemm_gpu.open_source` import succeeds and `__version__` is re-exported;
# in other layouts the import fails and `open_source` falls back to False.
# Confirm against fbgemm_gpu/__init__.py in the OSS build.
try:
    # pyre-ignore[21]
    # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils
    from fbgemm_gpu import open_source

    # pyre-ignore[21]
    # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils
    from fbgemm_gpu.docs.version import __version__  # noqa: F401
except Exception:
    # Broad except is deliberate: any failure above (missing module, missing
    # attribute) means we are not in the open-source package layout.
    open_source: bool = False
Loading

0 comments on commit 0f9672d

Please sign in to comment.