Integrate CUTLASS build into FBGEMM_GPU OSS (pytorch#2537)
Summary:
- Integrate CUTLASS build support into FBGEMM_GPU OSS
- Add an example CUTLASS-based operator for testing

Pull Request resolved: pytorch#2537

Reviewed By: jianyuh

Differential Revision: D56735138

Pulled By: q10

fbshipit-source-id: 9a190121383995af460f22db5cae80223ea79881
q10 authored and facebook-github-bot committed Apr 30, 2024
1 parent c21546c commit ca4e84b
Showing 12 changed files with 195 additions and 22 deletions.
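
The headline addition is a CUTLASS-backed single-precision GEMM exposed to Python as torch.ops.fbgemm.sgemm_float. Below is a minimal usage sketch, assuming a CUDA build of fbgemm_gpu with a visible GPU; it mirrors the sgemm_float_test.py introduced later in this commit.

import torch
from fbgemm_gpu.experimental.example import utils

alpha, beta = 3.14, 2.71
A = torch.rand(4, 3, dtype=torch.float, device="cuda")
B = torch.rand(3, 5, dtype=torch.float, device="cuda")
C = torch.rand(4, 5, dtype=torch.float, device="cuda")

# D = alpha * (A @ B) + beta * C, computed by the CUTLASS SGEMM kernel
D = utils.sgemm(alpha, A, B, beta, C)

# torch.addmm computes beta * C + alpha * (A @ B), the same contraction
torch.testing.assert_close(D, torch.addmm(C, A, B, beta=beta, alpha=alpha))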
2 changes: 1 addition & 1 deletion .github/scripts/fbgemm_gpu_test.bash
@@ -211,11 +211,11 @@ test_all_fbgemm_gpu_modules () {

local target_directories=(
fbgemm_gpu/test
fbgemm_gpu/experimental/example/test
)

if [ "$fbgemm_variant" == "cuda" ]; then
target_directories+=(
fbgemm_gpu/experimental/example/test
fbgemm_gpu/experimental/gen_ai/test
)
fi
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -268,9 +268,9 @@ if(NOT TARGET asmjit)
target_compile_options(asmjit PRIVATE "-Wno-sign-conversion")
endif()

target_compile_options_if_supported(asmjit -Wno-deprecated-enum-enum-conversion)
target_compile_options_if_supported(asmjit -Wno-deprecated-anon-enum-enum-conversion)
target_compile_options_if_supported(asmjit -Wno-error=deprecated-enum-enum-conversion)
target_compile_options_if_supported(asmjit -Wno-deprecated-enum-enum-conversion)

if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 13.0.0)
# See https://github.com/pytorch/pytorch/issues/74352, https://github.com/pytorch/FBGEMM/issues/1173
4 changes: 1 addition & 3 deletions fbgemm_gpu/CMakeLists.txt
@@ -90,9 +90,7 @@ include(FbgemmGpu.cmake)
# Build Experimental Modules
################################################################################

add_subdirectory(experimental/example)


if(NOT FBGEMM_CPU_ONLY)
add_subdirectory(experimental/example)
add_subdirectory(experimental/gen_ai)
endif()
41 changes: 27 additions & 14 deletions fbgemm_gpu/FbgemmGpu.cmake
@@ -13,13 +13,36 @@ include(${CMAKEMODULES}/Utilities.cmake)
set(CMAKE_CODEGEN_DIR ${CMAKE_CURRENT_SOURCE_DIR}/codegen)


################################################################################
# Source Includes
################################################################################

set(fbgemm_sources_include_directories
# FBGEMM
${FBGEMM}/include
# FBGEMM_GPU
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/include
${CMAKE_CURRENT_SOURCE_DIR}/../include
# Third-party
${THIRDPARTY}/asmjit/src
${THIRDPARTY}/cpuinfo/include
${THIRDPARTY}/cutlass/include)


################################################################################
# Third Party Sources
################################################################################

file(GLOB_RECURSE asmjit_sources
"${CMAKE_CURRENT_SOURCE_DIR}/../third_party/asmjit/src/asmjit/*/*.cpp")

set(third_party_include_directories
${THIRDPARTY}/asmjit/src
${THIRDPARTY}/cpuinfo/include
${THIRDPARTY}/cutlass/include)


################################################################################
# Optimizer Group Definitions
################################################################################
@@ -256,26 +279,23 @@ endif()

set_source_files_properties(${gen_cpu_source_files}
PROPERTIES INCLUDE_DIRECTORIES
"${CMAKE_CURRENT_SOURCE_DIR};${CMAKE_CURRENT_SOURCE_DIR}/include;${CMAKE_CURRENT_SOURCE_DIR}/../include;${THIRDPARTY}/asmjit/src"
)
"${fbgemm_sources_include_directories}")

set_source_files_properties(${gen_gpu_host_source_files}
PROPERTIES INCLUDE_DIRECTORIES
"${CMAKE_CURRENT_SOURCE_DIR};${CMAKE_CURRENT_SOURCE_DIR}/include;${CMAKE_CURRENT_SOURCE_DIR}/../include"
)
"${fbgemm_sources_include_directories}")

set_source_files_properties(${gen_gpu_kernel_source_files}
PROPERTIES INCLUDE_DIRECTORIES
"${CMAKE_CURRENT_SOURCE_DIR};${CMAKE_CURRENT_SOURCE_DIR}/include")
"${fbgemm_sources_include_directories}")

set_source_files_properties(${gen_gpu_kernel_source_files}
PROPERTIES COMPILE_OPTIONS
"${TORCH_CUDA_OPTIONS}")

set_source_files_properties(${gen_defused_optim_source_files}
PROPERTIES INCLUDE_DIRECTORIES
"${CMAKE_CURRENT_SOURCE_DIR};${CMAKE_CURRENT_SOURCE_DIR}/include;${CMAKE_CURRENT_SOURCE_DIR}/../include"
)
"${fbgemm_sources_include_directories}")

if(NOT FBGEMM_CPU_ONLY)
set(fbgemm_gpu_sources_gen
@@ -340,13 +360,6 @@ else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNO_AVX512=1")
endif()

set(fbgemm_sources_include_directories
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/include
${FBGEMM}/include
${THIRDPARTY}/asmjit/src
${THIRDPARTY}/cpuinfo/include)

set_source_files_properties(${fbgemm_sources}
PROPERTIES INCLUDE_DIRECTORIES
"${fbgemm_sources_include_directories}")
5 changes: 5 additions & 0 deletions fbgemm_gpu/experimental/example/CMakeLists.txt
@@ -11,8 +11,13 @@ include(${CMAKEMODULES}/Utilities.cmake)
################################################################################

set(experimental_example_cpp_source_files
src/cutlass_sgemm_nn.cu
src/example_ops.cpp)

set_source_files_properties(${experimental_example_cpp_source_files}
PROPERTIES INCLUDE_DIRECTORIES
"${fbgemm_sources_include_directories}")

set(experimental_example_python_source_files
example/__init__.py
example/utils.py)
2 changes: 1 addition & 1 deletion fbgemm_gpu/experimental/example/example/__init__.py
@@ -29,5 +29,5 @@
)
else:
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu/experimental/example:example_ops_cpu"
"//deeplearning/fbgemm/fbgemm_gpu/experimental/example:example_ops_cuda"
)
6 changes: 6 additions & 0 deletions fbgemm_gpu/experimental/example/example/utils.py
@@ -12,3 +12,9 @@

def add_tensors(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return torch.ops.fbgemm.add_tensors_float(a, b)


def sgemm(
alpha: float, TA: torch.Tensor, TB: torch.Tensor, beta: float, TC: torch.Tensor
) -> torch.Tensor:
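"""Compute D = alpha * (TA @ TB) + beta * TC via the CUTLASS SGEMM
operator registered by the CUDA build. Returns a new tensor; TC is not
modified in place."""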
return torch.ops.fbgemm.sgemm_float(alpha, TA, TB, beta, TC)
98 changes: 98 additions & 0 deletions fbgemm_gpu/experimental/example/src/cutlass_sgemm_nn.cu
@@ -0,0 +1,98 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <cutlass/gemm/device/gemm.h>
#include <torch/library.h>

namespace fbgemm_gpu::experimental {

at::Tensor sgemm_float_cuda(
const double alpha_,
const at::Tensor& TA,
const at::Tensor& TB,
const double beta_,
const at::Tensor& TC) {
TORCH_CHECK_EQ(TA.dim(), 2);
TORCH_CHECK_EQ(TB.dim(), 2);
TORCH_CHECK_EQ(TC.dim(), 2);

const auto M = static_cast<int>(TA.size(0));
const auto K = static_cast<int>(TA.size(1));
const auto N = static_cast<int>(TB.size(1));

TORCH_CHECK_EQ(TB.size(0), K);
TORCH_CHECK_EQ(TC.size(0), M);
TORCH_CHECK_EQ(TC.size(1), N);

// Compute leading dimensions for each matrix
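// For row-major storage, the leading dimension is the stride between
// consecutive rows, i.e. the number of columns of each matrix.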
const auto lda = K;
const auto ldb = N;
const auto ldc = N;

const auto* A = TA.data_ptr<float>();
const auto* B = TB.data_ptr<float>();
const auto* C = TC.data_ptr<float>();

const auto alpha = static_cast<float>(alpha_);
const auto beta = static_cast<float>(beta_);

// Create result tensor
auto TD = at::zeros({M, N}, TC.options());
auto* D = TD.data_ptr<float>();

// PyTorch tensors are stored in row-major format
using Layout = cutlass::layout::RowMajor;

// Type definition for a single-precision CUTLASS GEMM with row-major
// input matrices and the default 128x128x8 threadblock tile size
using CutlassGemm = cutlass::gemm::device::Gemm<
float, // Data-type of A matrix
Layout, // Layout of A matrix
float, // Data-type of B matrix
Layout, // Layout of B matrix
float, // Data-type of C matrix
Layout>; // Layout of C matrix

// Construct the CUTLASS GEMM arguments object
CutlassGemm::Arguments args(
{M, N, K}, // GEMM problem dimensions
{A, lda}, // Tensor-ref for source matrix A
{B, ldb}, // Tensor-ref for source matrix B
{C, ldc}, // Tensor-ref for source matrix C
{D, ldc}, // Tensor-ref for destination matrix D (may be different memory
// than source C matrix)
{alpha, beta}); // Scalars used in the epilogue

// Create and launch the CUTLASS GEMM kernel
// D = alpha * A x B + beta * C
const auto status = CutlassGemm()(args);

if (status != cutlass::Status::kSuccess) {
throw std::runtime_error(
std::string("CUTLASS GEMM kernel failed: ") +
cutlassGetStatusString(status));
}

return TD;
}

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"sgemm_float(float alpha, Tensor TA, Tensor TB, float beta, Tensor TC) -> Tensor");
}

TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
m.impl(
"sgemm_float",
torch::dispatch(
c10::DispatchKey::CUDA,
TORCH_FN(fbgemm_gpu::experimental::sgemm_float_cuda)));
}

} // namespace fbgemm_gpu::experimental
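
A caveat on the kernel above: it reads raw data_ptr<float>() values and hard-codes cutlass::layout::RowMajor, so it assumes contiguous row-major inputs; a transposed view such as A.t() shares storage in column-major order and would be silently misread. A defensive calling pattern is sketched below; sgemm_checked is a hypothetical helper, not part of this commit.

import torch
from fbgemm_gpu.experimental.example import utils

def sgemm_checked(alpha, A, B, beta, C):
    # The CUDA kernel expects contiguous row-major float32 CUDA tensors,
    # so validate dtype/device and normalize any non-contiguous views.
    for t in (A, B, C):
        assert t.is_cuda and t.dtype == torch.float32
    return utils.sgemm(
        alpha, A.contiguous(), B.contiguous(), beta, C.contiguous()
    )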
17 changes: 17 additions & 0 deletions fbgemm_gpu/experimental/example/test/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Tuple

import torch


gpu_unavailable: Tuple[bool, str] = (
not torch.cuda.is_available() or torch.cuda.device_count() == 0,
"CUDA is not available or no GPUs detected",
)
2 changes: 1 addition & 1 deletion fbgemm_gpu/experimental/example/test/{example_test.py → add_tensors_float_test.py}
@@ -12,7 +12,7 @@
from fbgemm_gpu.experimental.example import utils


class ExampleTest(unittest.TestCase):
class AddTensorsFloatTest(unittest.TestCase):
def test_add_tensors_float(self) -> None:
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
33 changes: 33 additions & 0 deletions fbgemm_gpu/experimental/example/test/sgemm_float_test.py
@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import torch
from fbgemm_gpu.experimental.example import utils

from . import gpu_unavailable


class SgemmFloatTest(unittest.TestCase):
@unittest.skipIf(*gpu_unavailable)
def test_sgemm_float(self) -> None:
alpha = 3.14
beta = 2.71

A = torch.rand(4, 3, dtype=torch.float, device="cuda")
B = torch.rand(3, 5, dtype=torch.float, device="cuda")
C = torch.rand(4, 5, dtype=torch.float, device="cuda")
D = utils.sgemm(alpha, A, B, beta, C)

expected = torch.add(alpha * torch.matmul(A, B), beta * C)
torch.testing.assert_close(D.cpu(), expected.cpu())


if __name__ == "__main__":
unittest.main()
5 changes: 4 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/sparse_ops.py
@@ -49,7 +49,10 @@
from torch import SymInt, Tensor


if hasattr(torch.library, "impl_abstract"):
if hasattr(torch.library, "register_fake"):
# pyre-ignore[9]
impl_abstract = torch.library.register_fake
elif hasattr(torch.library, "impl_abstract"):
impl_abstract = torch.library.impl_abstract
else:
# pyre-ignore
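
The shim above prefers torch.library.register_fake, the name newer PyTorch releases use for this API, and falls back to the older torch.library.impl_abstract, so the rest of sparse_ops.py can register fake (meta) implementations under a single name. The sketch below shows how the selected alias is typically applied; the operator name is illustrative only and is not taken from this file.

# Register a shape-only "fake" kernel so tracing and torch.compile can
# propagate shapes without launching real work. "fbgemm::sgemm_float" is
# used here purely as an illustrative target.
@impl_abstract("fbgemm::sgemm_float")
def sgemm_float_meta(
    alpha: float,
    TA: torch.Tensor,
    TB: torch.Tensor,
    beta: float,
    TC: torch.Tensor,
) -> torch.Tensor:
    # D = alpha * (TA @ TB) + beta * TC has shape (M, N)
    return TA.new_empty((TA.size(0), TB.size(1)))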
