Fix CUTLASS builds (pytorch#2725)
Summary:
- Enable sm_90a compilation so that CUTLASS builds run correctly on H100

Pull Request resolved: pytorch#2725

Reviewed By: jianyuh

Differential Revision: D58485566

Pulled By: q10

fbshipit-source-id: 4adfe2c9fa2b2003bae09ce5c6afd92bb9959421
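
For context, the 9.0a entries added below tell the build to emit code for the sm_90a architecture variant, which is what exposes the Hopper-specific instructions the CUTLASS kernels need. A rough sketch (not part of this commit; the .cu file name is only a placeholder) of the nvcc -gencode flags that an arch list of "7.0;8.0;9.0;9.0a" effectively corresponds to:

# Sketch only: nvcc targets implied by TORCH_CUDA_ARCH_LIST="7.0;8.0;9.0;9.0a"
# sm_90a is the Hopper-specific target required by the CUTLASS Hopper kernels
nvcc -c example_kernel.cu -o example_kernel.o \
    -gencode arch=compute_70,code=sm_70 \
    -gencode arch=compute_80,code=sm_80 \
    -gencode arch=compute_90,code=sm_90 \
    -gencode arch=compute_90a,code=sm_90a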
q10 authored and facebook-github-bot committed Jun 12, 2024
1 parent 35fed06 commit 771a065
Showing 4 changed files with 17 additions and 9 deletions.
19 changes: 13 additions & 6 deletions .github/scripts/fbgemm_gpu_build.bash
@@ -193,13 +193,20 @@ __configure_fbgemm_gpu_build_cuda () {
local arch_list="${TORCH_CUDA_ARCH_LIST}"

else
# Build only CUDA 7.0, 8.0, and 9.0 (i.e. V100, A100, H100) because of 100 MB binary size limits from PyPI.
echo "[BUILD] Using the default CUDA targets ..."
# For cuda version 12.1, enable sm 9.0
# To keep binary sizes to minimum, build only against the CUDA architectures
# that the latest PyTorch supports:
# 7.0 (V100), 8.0 (A100), and 9.0,9.0a (H100)
cuda_version_nvcc=$(conda run -n "${env_name}" nvcc --version)
echo "$cuda_version_nvcc"
if [[ $cuda_version_nvcc == *"V12.1"* ]]; then
local arch_list="7.0;8.0;9.0"
echo "[BUILD] Using the default architectures for CUDA $cuda_version_nvcc ..."

if [[ $cuda_version_nvcc == *"V12.1"* ]] || [[ $cuda_version_nvcc == *"V12.4"* ]]; then
# sm_90 and sm_90a are only available for CUDA 12.1+
# NOTE: CUTLASS kernels for Hopper require sm_90a to be enabled
# See:
# https://github.com/NVIDIA/nvbench/discussions/129
# https://github.com/vllm-project/vllm/blob/main/CMakeLists.txt#L187
# https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp#L224
local arch_list="7.0;8.0;9.0;9.0a"
else
local arch_list="7.0;8.0"
fi
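
With 9.0a in the arch list, the CUTLASS Hopper kernels compile into the extension. As a hedged sanity check (not part of this change, and the .so path is only an example), cuobjdump can list the embedded cubins to confirm the sm_90a target is present:

# Example check, assuming the built library sits at ./fbgemm_gpu_py.so
cuobjdump --list-elf ./fbgemm_gpu_py.so | grep sm_90a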
1 change: 1 addition & 0 deletions .github/scripts/fbgemm_gpu_test.bash
@@ -216,6 +216,7 @@ __determine_test_directories () {
for test_dir in "${target_directories[@]}"; do
echo "$test_dir"
done
echo ""
}

test_all_fbgemm_gpu_modules () {
4 changes: 2 additions & 2 deletions .github/scripts/nova_dir.bash
@@ -16,5 +16,5 @@ export BUILD_FROM_NOVA=1
## Overwrite existing ENV VAR in Nova
if [[ "$CONDA_ENV" != "" ]]; then export CONDA_RUN="conda run --no-capture-output -p ${CONDA_ENV}" && echo "$CONDA_RUN"; fi
if [[ "$CU_VERSION" == "cu118" ]]; then export TORCH_CUDA_ARCH_LIST='7.0;8.0' && echo "$TORCH_CUDA_ARCH_LIST"; fi
if [[ "$CU_VERSION" == "cu121" ]]; then export TORCH_CUDA_ARCH_LIST='7.0;8.0;9.0' && echo "$TORCH_CUDA_ARCH_LIST"; fi
if [[ "$CU_VERSION" == "cu124" ]]; then export TORCH_CUDA_ARCH_LIST='8.0;9.0' && echo "$TORCH_CUDA_ARCH_LIST"; fi
if [[ "$CU_VERSION" == "cu121" ]]; then export TORCH_CUDA_ARCH_LIST='7.0;8.0;9.0;9.0a' && echo "$TORCH_CUDA_ARCH_LIST"; fi
if [[ "$CU_VERSION" == "cu124" ]]; then export TORCH_CUDA_ARCH_LIST='8.0;9.0;9.0a' && echo "$TORCH_CUDA_ARCH_LIST"; fi
2 changes: 1 addition & 1 deletion .github/scripts/utils_pip.bash
@@ -208,7 +208,7 @@ install_from_pytorch_pip () {
# Ensure that the package build is of the correct variant
# This test usually applies to the nightly builds
# shellcheck disable=SC2086
if conda run ${env_prefix} pip list "${package_name}" | grep "${package_name}" | grep "${package_variant}"; then
if conda run ${env_prefix} pip list | grep "${package_name}" | grep "${package_variant}"; then
echo "[CHECK] The installed package [${package_name}, ${package_channel}/${package_version:-LATEST}] is the correct variant (${package_variant})"
else
echo "[CHECK] The installed package [${package_name}, ${package_channel}/${package_version:-LATEST}] appears to be an incorrect variant as it is missing references to ${package_variant}!"
