diff --git a/.github/scripts/fbgemm_gpu_build.bash b/.github/scripts/fbgemm_gpu_build.bash
index 26f5330589..4520756699 100644
--- a/.github/scripts/fbgemm_gpu_build.bash
+++ b/.github/scripts/fbgemm_gpu_build.bash
@@ -47,6 +47,11 @@ prepare_fbgemm_gpu_build () {
   # shellcheck disable=SC2086
   (exec_with_retries 3 conda run --no-capture-output ${env_prefix} python -m pip install -r requirements.txt) || return 1

+  # BUILD_VARIANT is provided by the GitHub workflow file
+  if [ "$BUILD_VARIANT" == "cuda" ]; then
+    (install_triton ${env_name}) || return 1
+  fi
+
   # shellcheck disable=SC2086
   (test_python_import_package "${env_name}" numpy) || return 1
   # shellcheck disable=SC2086
diff --git a/.github/scripts/fbgemm_gpu_install.bash b/.github/scripts/fbgemm_gpu_install.bash
index 44b584aad4..18b8d18c5d 100644
--- a/.github/scripts/fbgemm_gpu_install.bash
+++ b/.github/scripts/fbgemm_gpu_install.bash
@@ -33,6 +33,15 @@ __fbgemm_gpu_post_install_checks () {
     echo "[CHECK] package channel; the package may be broken at runtime!!!"
     echo "################################################################################"

+  # shellcheck disable=SC2086,SC2155
+  local fbgemm_gpu_packages=$(conda run ${env_prefix} python -c "import fbgemm_gpu; print(dir(fbgemm_gpu))")
+  local experimental_packages=$(conda run ${env_prefix} python -c "import fbgemm_gpu.experimental; print(dir(fbgemm_gpu.experimental))")
+  echo "################################################################################"
+  echo "[CHECK] FBGEMM_GPU Experimental Packages"
+  echo "[CHECK] fbgemm_gpu: ${fbgemm_gpu_packages}"
+  echo "[CHECK] fbgemm_gpu.experimental: ${experimental_packages}"
+  echo "################################################################################"
+
   echo "[INSTALL] Checking imports and symbols ..."
   (test_python_import_package "${env_name}" fbgemm_gpu) || return 1
   (test_python_import_package "${env_name}" fbgemm_gpu.split_embedding_codegen_lookup_invokers) || return 1
diff --git a/.github/scripts/fbgemm_gpu_test.bash b/.github/scripts/fbgemm_gpu_test.bash
index 00113fa1f5..9464c98023 100644
--- a/.github/scripts/fbgemm_gpu_test.bash
+++ b/.github/scripts/fbgemm_gpu_test.bash
@@ -210,13 +210,14 @@ test_all_fbgemm_gpu_modules () {
   local fbgemm_variant="$2"

   local target_directories=(
-    fbgemm_gpu/test
+    # fbgemm_gpu/test
   )

   if [ "$fbgemm_variant" == "cuda" ]; then
     target_directories+=(
-      fbgemm_gpu/experimental/example/test
-      fbgemm_gpu/experimental/gen_ai/test
+      # fbgemm_gpu/experimental/example/test
+      fbgemm_gpu/experimental/gemm/test
+      # fbgemm_gpu/experimental/gen_ai/test
     )
   fi

diff --git a/.github/scripts/setup_env.bash b/.github/scripts/setup_env.bash
index 838173cff7..cc26a12d3f 100755
--- a/.github/scripts/setup_env.bash
+++ b/.github/scripts/setup_env.bash
@@ -22,6 +22,8 @@
 # shellcheck disable=SC1091,SC2128
 . "$( dirname -- "$BASH_SOURCE"; )/utils_pytorch.bash"
 # shellcheck disable=SC1091,SC2128
+. "$( dirname -- "$BASH_SOURCE"; )/utils_triton.bash"
+# shellcheck disable=SC1091,SC2128
 . "$( dirname -- "$BASH_SOURCE"; )/fbgemm_build.bash"
 # shellcheck disable=SC1091,SC2128
 . "$( dirname -- "$BASH_SOURCE"; )/fbgemm_gpu_build.bash"
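Note: the post-install check added above probes the installed package by shelling into the Conda environment. A minimal standalone sketch of the same probe, assuming fbgemm_gpu is already installed in the active environment:

```python
# Minimal sketch of the probe run by __fbgemm_gpu_post_install_checks:
# import the base and experimental namespaces and dump their symbols.
import fbgemm_gpu
import fbgemm_gpu.experimental

print("fbgemm_gpu:", dir(fbgemm_gpu))
print("fbgemm_gpu.experimental:", dir(fbgemm_gpu.experimental))
```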
diff --git a/.github/scripts/utils_triton.bash b/.github/scripts/utils_triton.bash
new file mode 100644
index 0000000000..351f2cdec1
--- /dev/null
+++ b/.github/scripts/utils_triton.bash
@@ -0,0 +1,54 @@
+#!/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# shellcheck disable=SC1091,SC2128
+. "$( dirname -- "$BASH_SOURCE"; )/utils_base.bash"
+
+################################################################################
+# Triton Setup Functions
+################################################################################
+
+install_triton () {
+  local env_name="$1"
+  local triton_version="$2"
+  if [ "$env_name" == "" ]; then
+    echo "Usage: ${FUNCNAME[0]} ENV_NAME [TRITON_VERSION]"
+    echo "Example(s):"
+    echo "    ${FUNCNAME[0]} build_env        # Install the repo-default version of Triton"
+    echo "    ${FUNCNAME[0]} build_env 2.1    # Install a designated version of Triton"
+    return 1
+  else
+    echo "################################################################################"
+    echo "# Build + Install Triton"
+    echo "#"
+    echo "# [$(date --utc +%FT%T.%3NZ)] + ${FUNCNAME[0]} ${*}"
+    echo "################################################################################"
+    echo ""
+  fi
+
+  test_network_connection || return 1
+
+  # shellcheck disable=SC2155
+  local env_prefix=$(env_name_or_prefix "${env_name}")
+
+  echo "[BUILD] Checking out triton ..."
+  cd ../third_party/triton/python || return 1
+  if [ "$triton_version" != "" ]; then
+    (print_exec git checkout "${triton_version}") || return 1
+  fi
+
+  echo "[BUILD] Installing Triton ..."
+  # shellcheck disable=SC2086
+  (exec_with_retries 3 conda run --no-capture-output ${env_prefix} python -m pip install -e .) || return 1
+
+  # shellcheck disable=SC2086
+  (test_python_import_package "${env_name}" triton) || return 1
+
+  cd - || return 1
+  echo "[INSTALL] Successfully installed Triton ${triton_version}"
+}
diff --git a/.github/workflows/fbgemm_gpu_ci_cpu.yml b/.github/workflows/fbgemm_gpu_ci_cpu.yml
index e5fd8d0ada..a88a125e9b 100644
--- a/.github/workflows/fbgemm_gpu_ci_cpu.yml
+++ b/.github/workflows/fbgemm_gpu_ci_cpu.yml
@@ -57,6 +57,7 @@ jobs:
     env:
       PRELUDE: .github/scripts/setup_env.bash
       BUILD_ENV: build_binary
+      BUILD_VARIANT: cpu
     continue-on-error: true
     strategy:
       # Don't fast-fail all the other builds if one of the them fails
@@ -126,6 +127,7 @@ jobs:
     env:
       PRELUDE: .github/scripts/setup_env.bash
       BUILD_ENV: build_binary
+      BUILD_VARIANT: cpu
     strategy:
       fail-fast: false
       matrix:
diff --git a/.github/workflows/fbgemm_gpu_ci_cuda.yml b/.github/workflows/fbgemm_gpu_ci_cuda.yml
index e20a1fa8fb..1aca2fbd9b 100644
--- a/.github/workflows/fbgemm_gpu_ci_cuda.yml
+++ b/.github/workflows/fbgemm_gpu_ci_cuda.yml
@@ -56,6 +56,7 @@ jobs:
     env:
       PRELUDE: .github/scripts/setup_env.bash
       BUILD_ENV: build_binary
+      BUILD_VARIANT: cuda
     continue-on-error: true
     strategy:
       # Don't fast-fail all the other builds if one of the them fails
@@ -134,6 +135,7 @@ jobs:
     env:
       PRELUDE: .github/scripts/setup_env.bash
       BUILD_ENV: build_binary
+      BUILD_VARIANT: cuda
       ENFORCE_CUDA_DEVICE: 1
     strategy:
       fail-fast: false
@@ -181,7 +183,8 @@ jobs:
         run: . $PRELUDE; create_conda_environment $BUILD_ENV ${{ matrix.python-version }}

       - name: Install C/C++ Compilers for Updated LIBGCC
-        run: . $PRELUDE; install_cxx_compiler $BUILD_ENV ${{ matrix.compiler }}
+        # Install the Clang libraries to enable building and installing Triton
+        run: . $PRELUDE; install_cxx_compiler $BUILD_ENV clang

       - name: Install CUDA
         run: . $PRELUDE; install_cuda $BUILD_ENV ${{ matrix.cuda-version }}
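Note: install_triton above only verifies that `import triton` succeeds. A slightly stronger manual smoke test (a hedged sketch; the kernel and names here are illustrative, not part of this change) would JIT-compile and launch a trivial kernel:

```python
# Illustrative smoke test for an editable Triton install: compile and
# launch a trivial elementwise copy kernel, then verify the result.
import torch
import triton
import triton.language as tl

@triton.jit
def _copy_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)

x = torch.randn(1024, device="cuda")
y = torch.empty_like(x)
_copy_kernel[(triton.cdiv(x.numel(), 256),)](x, y, x.numel(), BLOCK=256)
assert torch.equal(x, y)
```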
diff --git a/.github/workflows/fbgemm_gpu_ci_rocm.yml b/.github/workflows/fbgemm_gpu_ci_rocm.yml
index 4e35f8cd56..8e14bb7703 100644
--- a/.github/workflows/fbgemm_gpu_ci_rocm.yml
+++ b/.github/workflows/fbgemm_gpu_ci_rocm.yml
@@ -56,6 +56,7 @@ jobs:
     env:
       PRELUDE: .github/scripts/setup_env.bash
       BUILD_ENV: build_binary
+      BUILD_VARIANT: rocm
     strategy:
       fail-fast: false
       matrix:
@@ -133,6 +134,7 @@ jobs:
     env:
       PRELUDE: .github/scripts/setup_env.bash
       BUILD_ENV: build_binary
+      BUILD_VARIANT: rocm
       ENFORCE_ROCM_DEVICE: 1
     strategy:
       fail-fast: false
diff --git a/.github/workflows/fbgemm_gpu_docs.yml b/.github/workflows/fbgemm_gpu_docs.yml
index cdffa2fb98..5654cae455 100644
--- a/.github/workflows/fbgemm_gpu_docs.yml
+++ b/.github/workflows/fbgemm_gpu_docs.yml
@@ -37,6 +37,7 @@ jobs:
     env:
       PRELUDE: .github/scripts/setup_env.bash
       BUILD_ENV: build_binary
+      BUILD_VARIANT: cpu
     strategy:
       fail-fast: false
       matrix:
@@ -71,7 +72,7 @@ jobs:
         run: . $PRELUDE; cd fbgemm_gpu/docs; install_docs_tools $BUILD_ENV

       - name: Install PyTorch-CPU Nightly
-        run: . $PRELUDE; install_pytorch_pip $BUILD_ENV nightly cpu
+        run: . $PRELUDE; install_pytorch_pip $BUILD_ENV nightly $BUILD_VARIANT

       - name: Collect PyTorch Environment Info
         if: ${{ success() || failure() }}
@@ -81,7 +82,7 @@ jobs:
         run: . $PRELUDE; cd fbgemm_gpu; prepare_fbgemm_gpu_build $BUILD_ENV

       - name: Build + Install FBGEMM_GPU (CPU version)
-        run: . $PRELUDE; cd fbgemm_gpu; build_fbgemm_gpu_install $BUILD_ENV cpu
+        run: . $PRELUDE; cd fbgemm_gpu; build_fbgemm_gpu_install $BUILD_ENV $BUILD_VARIANT

       - name: Build FBGEMM_GPU Documentation
         run: . $PRELUDE; cd fbgemm_gpu/docs; build_fbgemm_gpu_docs $BUILD_ENV
diff --git a/.github/workflows/fbgemm_gpu_pip.yml b/.github/workflows/fbgemm_gpu_pip.yml
index 2a708c9a34..3e457e7e29 100644
--- a/.github/workflows/fbgemm_gpu_pip.yml
+++ b/.github/workflows/fbgemm_gpu_pip.yml
@@ -56,6 +56,7 @@ jobs:
     env:
       PRELUDE: .github/scripts/setup_env.bash
       BUILD_ENV: test_install
+      BUILD_VARIANT: cpu
     strategy:
       fail-fast: false
       matrix:
@@ -111,6 +112,7 @@ jobs:
     env:
       PRELUDE: .github/scripts/setup_env.bash
       BUILD_ENV: test_install
+      BUILD_VARIANT: cuda
       ENFORCE_CUDA_DEVICE: 1
     strategy:
       fail-fast: false
@@ -174,6 +176,7 @@ jobs:
     env:
       PRELUDE: .github/scripts/setup_env.bash
       BUILD_ENV: test_install
+      BUILD_VARIANT: rocm
       ENFORCE_ROCM_DEVICE: 1
     strategy:
       fail-fast: false
diff --git a/.github/workflows/fbgemm_gpu_release_cpu.yml b/.github/workflows/fbgemm_gpu_release_cpu.yml
index a21a90eb0e..88f2e37708 100644
--- a/.github/workflows/fbgemm_gpu_release_cpu.yml
+++ b/.github/workflows/fbgemm_gpu_release_cpu.yml
@@ -54,6 +54,7 @@ jobs:
     env:
       PRELUDE: .github/scripts/setup_env.bash
       BUILD_ENV: build_binary
+      BUILD_VARIANT: cpu
     continue-on-error: true
     strategy:
       # Don't fast-fail all the other builds if one of the them fails
@@ -122,6 +123,7 @@ jobs:
     env:
       PRELUDE: .github/scripts/setup_env.bash
       BUILD_ENV: build_binary
+      BUILD_VARIANT: cpu
     strategy:
       fail-fast: false
       matrix:
diff --git a/.github/workflows/fbgemm_gpu_release_cuda.yml b/.github/workflows/fbgemm_gpu_release_cuda.yml
index d8de40bfa6..968c6f42c4 100644
--- a/.github/workflows/fbgemm_gpu_release_cuda.yml
+++ b/.github/workflows/fbgemm_gpu_release_cuda.yml
@@ -60,6 +60,7 @@ jobs:
     env:
       PRELUDE: .github/scripts/setup_env.bash
       BUILD_ENV: build_binary
+      BUILD_VARIANT: cuda
     continue-on-error: true
     strategy:
       # Don't fast-fail all the other builds if one of the them fails
@@ -131,6 +132,7 @@ jobs:
     env:
       PRELUDE: .github/scripts/setup_env.bash
       BUILD_ENV: build_binary
+      BUILD_VARIANT: cuda
       ENFORCE_CUDA_DEVICE: 1
     strategy:
       fail-fast: false
diff --git a/.gitmodules b/.gitmodules
index c077724184..e820fc6bca 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -13,3 +13,6 @@
 [submodule "third_party/cutlass"]
 	path = third_party/cutlass
 	url = https://github.com/NVIDIA/cutlass.git
+[submodule "third_party/triton"]
+	path = third_party/triton
+	url = https://github.com/openai/triton.git
diff --git a/cmake/modules/PyTorchSetup.cmake b/cmake/modules/PyTorchSetup.cmake
index a5b73eb6f3..0c29930e6d 100644
--- a/cmake/modules/PyTorchSetup.cmake
+++ b/cmake/modules/PyTorchSetup.cmake
@@ -14,7 +14,7 @@ include(${CMAKE_CURRENT_SOURCE_DIR}/../cmake/modules/Utilities.cmake)
 find_package(Torch REQUIRED)

 #
-# Toch Cuda Extensions are normally compiled with the flags below. However we
+# PyTorch CUDA Extensions are normally compiled with the flags below. However, we
 # disabled -D__CUDA_NO_HALF_CONVERSIONS__ here as it caused "error: no suitable
 # constructor exists to convert from "int" to "__half" errors in
 # gen_embedding_forward_quantized_split_[un]weighted_codegen_cuda.cu
diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt
index 01c2277178..1146b28664 100644
--- a/fbgemm_gpu/CMakeLists.txt
+++ b/fbgemm_gpu/CMakeLists.txt
@@ -96,6 +96,7 @@ endif()

 if(NOT FBGEMM_CPU_ONLY)
   add_subdirectory(experimental/example)
+  add_subdirectory(experimental/gemm)

   if(NOT USE_ROCM)
     # CUTLASS currently doesn't build on ROCm and CK hasnt yet been added:
diff --git a/fbgemm_gpu/experimental/gemm/CMakeLists.txt b/fbgemm_gpu/experimental/gemm/CMakeLists.txt
new file mode 100644
index 0000000000..ae0d236038
--- /dev/null
+++ b/fbgemm_gpu/experimental/gemm/CMakeLists.txt
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+################################################################################
+# Target Sources
+################################################################################
+
+set(experimental_triton_python_source_files
+    triton_gemm/__init__.py
+    triton_gemm/fp8_gemm.py)
+
+
+################################################################################
+# Install Python Files
+################################################################################
+
+install(FILES ${experimental_triton_python_source_files}
+        DESTINATION fbgemm_gpu/experimental/gemm/triton_gemm)
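Note: the install rule above determines the import path of the new module. Once the built package is installed, the kernels resolve as follows (a sketch, assuming a CUDA build with the experimental targets enabled; these are the same imports used by the relocated test and benchmark files):

```python
# Import path produced by the DESTINATION above.
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    matmul_fp8_block,
    matmul_fp8_row,
    quantize_fp8_block,
    quantize_fp8_row,
)
```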
diff --git a/fbgemm_gpu/experimental/gemm/test/__init__.py b/fbgemm_gpu/experimental/gemm/test/__init__.py
new file mode 100644
index 0000000000..07f329ba92
--- /dev/null
+++ b/fbgemm_gpu/experimental/gemm/test/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import Tuple
+
+import torch
+
+
+gpu_unavailable: Tuple[bool, str] = (
+    not torch.cuda.is_available() or torch.cuda.device_count() == 0,
+    "CUDA is not available or no GPUs detected",
+)
diff --git a/fbgemm_gpu/experimental/gemm/triton/fp8_gemm_benchmark.py b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_benchmark.py
similarity index 98%
rename from fbgemm_gpu/experimental/gemm/triton/fp8_gemm_benchmark.py
rename to fbgemm_gpu/experimental/gemm/test/fp8_gemm_benchmark.py
index b1256dcd8a..3ab1d36cf4 100644
--- a/fbgemm_gpu/experimental/gemm/triton/fp8_gemm_benchmark.py
+++ b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_benchmark.py
@@ -9,10 +9,9 @@
 from typing import Callable, Tuple

 import torch
-
 import triton

-from deeplearning.fbgemm.fbgemm_gpu.experimental.gemm.triton.fp8_gemm import (
+from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
     matmul_fp8_block,
     matmul_fp8_row,
     quantize_fp8_block,
diff --git a/fbgemm_gpu/experimental/gemm/triton/fp8_gemm_test.py b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py
similarity index 95%
rename from fbgemm_gpu/experimental/gemm/triton/fp8_gemm_test.py
rename to fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py
index ec2624a94b..89649366d6 100644
--- a/fbgemm_gpu/experimental/gemm/triton/fp8_gemm_test.py
+++ b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py
@@ -11,7 +11,7 @@

 import torch

-from deeplearning.fbgemm.fbgemm_gpu.experimental.gemm.triton.fp8_gemm import (
+from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
     matmul_fp8_block,
     matmul_fp8_row,
     quantize_fp8_block,
@@ -19,6 +19,11 @@
 )


+@unittest.skipIf(
+    not torch.cuda.is_available()
+    or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
+    "Skip when H100 is not available",
+)
 class TestFp8Matmul(unittest.TestCase):
     def setUp(self) -> None:
         torch.manual_seed(0)
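Note: the skip guard added to TestFp8Matmul keys off compute capability, since FP8 e4m3 matmuls need SM90-class (H100) hardware. Factored out as a helper, the same gate looks like this (a sketch; the helper name is ours, not the repo's):

```python
import torch

def has_sm90() -> bool:
    # True when an SM90-class (H100) GPU is available, mirroring the
    # @unittest.skipIf condition in the test above.
    if not torch.cuda.is_available():
        return False
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    return props.major >= 9
```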
diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py b/fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py
new file mode 100644
index 0000000000..8ba2244f14
--- /dev/null
+++ b/fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py
@@ -0,0 +1,19 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+try:
+    # pyre-ignore[21]
+    # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils
+    from fbgemm_gpu import open_source
+
+    # pyre-ignore[21]
+    # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils
+    from fbgemm_gpu.docs.version import __version__  # noqa: F401
+except Exception:
+    open_source: bool = False
diff --git a/fbgemm_gpu/experimental/gemm/triton/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
similarity index 98%
rename from fbgemm_gpu/experimental/gemm/triton/fp8_gemm.py
rename to fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
index 374d22522d..9bb15bb0c8 100644
--- a/fbgemm_gpu/experimental/gemm/triton/fp8_gemm.py
+++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -14,7 +14,7 @@
 import triton.language as tl  # @manual

 from torch._tensor import Tensor
-from triton import autotune, cdiv, Config, heuristics, jit  # @manual
+from triton import Config  # @manual
 from triton.ops.matmul_perf_model import (  # @manual
     early_config_prune,
     estimate_matmul_time,
@@ -182,7 +182,7 @@ def get_configs_io_bound() -> List[Config]:
 ] + get_configs_io_bound()


-@autotune(
+@triton.autotune(
     configs=MATMUL_CONFIGS,
     key=[
         "m_key",
@@ -195,12 +195,12 @@ def get_configs_io_bound() -> List[Config]:
         "top_k": 10,
     },
 )
-@heuristics(
+@triton.heuristics(
     {
         "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
     }
 )
-@jit
+@triton.jit
 def _kernel_matmul_fp8_row(
     A,
     B,
@@ -367,7 +367,7 @@ def matmul_fp8_row(
     ).to(dtype=c.dtype)

     def grid(META):
-        return (cdiv(M, META["BLOCK_M"]) * cdiv(N, META["BLOCK_N"]), META["SPLIT_K"])
+        return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"])

     _kernel_matmul_fp8_row[grid](
         a,
@@ -396,25 +396,25 @@ def grid(META):
     return c


-@autotune(
+@triton.autotune(
     configs=MATMUL_CONFIGS,
     key=[
         "m_key",
         "n_key",
         "k_key",
-    ],  # TODO caller side bin keys so similar shapes can use same autotune.
+    ],  # TODO caller side bin keys so similar shapes can use the same autotune config.
     prune_configs_by={
         "early_config_prune": early_config_prune,
         "perf_model": estimate_matmul_time,
         "top_k": 10,
     },
 )
-@heuristics(
+@triton.heuristics(
     {
         "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
     }
 )
-@jit
+@triton.jit
 def _kernel_matmul_fp8_block(
     A,
     B,
@@ -636,7 +636,7 @@ def matmul_fp8_block(  # noqa: E731:

     def grid(META):
         return (
-            cdiv(M, META["BLOCK_M"]) * cdiv(N, META["BLOCK_N"]),
+            triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
             META["SPLIT_K"],
         )

@@ -756,7 +756,7 @@ def prep_matmul(
     return M, N, K, m_key, n_key, k_key, c, dot_out_dtype_triton, device


-@autotune(
+@triton.autotune(
     configs=[
         Config({"BLOCK_SIZE": 512}),
         Config({"BLOCK_SIZE": 1024}),
@@ -766,7 +766,7 @@ def prep_matmul(
     ],
     key=["N"],
 )
-@jit
+@triton.jit
 def _kernel_quantize_fp8_row(
     A,
     A_scale,
@@ -892,7 +892,7 @@ def quantize_fp8_row(
     return a_fp8, a_scale


-@jit
+@triton.jit
 def _kernel_quantize_fp8_block(
     A,
     A_scale,
@@ -967,8 +967,8 @@ def quantize_fp8_block(
         "cpu"
     ), "Blockwise quantization not support on cpu, please use row-wise quantization instead."
     M, K = x.shape
-    grid_m = cdiv(M, block_m)
-    grid_k = cdiv(K, block_k)
+    grid_m = triton.cdiv(M, block_m)
+    grid_k = triton.cdiv(K, block_k)
     x_scale = torch.ones((grid_m, grid_k), device=x.device, dtype=torch.float32)
     x_fp8 = torch.empty((M, K), device=x.device, dtype=torch.float8_e4m3fn)
     x_fp8 = convert_fp8_type(x_fp8)
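Note: for context, the row-wise path exercised by the relocated test pairs quantize_fp8_row with matmul_fp8_row. A hedged end-to-end sketch follows; shapes, dtypes, and the argument order are inferred from the test and benchmark files rather than guaranteed by this diff, and SM90-class hardware is required:

```python
import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    matmul_fp8_row,
    quantize_fp8_row,
)

# Row-major A (M, K) and B (N, K): quantize each operand row-wise to fp8
# with per-row scales, then compute the scaled matmul on the Triton kernel.
a = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16)
b = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
a_fp8, a_scale = quantize_fp8_row(a)
b_fp8, b_scale = quantize_fp8_row(b)
c = matmul_fp8_row(a_fp8, b_fp8, a_scale, b_scale)  # approximates a @ b.T
```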
diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
index 13cfd902d9..0029c6a968 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -9,6 +9,8 @@
 #pragma once

 #include
+#include
+#include
 #include

 namespace fbgemm_gpu {
diff --git a/third_party/triton b/third_party/triton
new file mode 160000
index 0000000000..45fff310c8
--- /dev/null
+++ b/third_party/triton
@@ -0,0 +1 @@
+Subproject commit 45fff310c891f5a92d55445adf8cc9d29df5841e