Skip to content

Commit

Permalink
Update pytorch-triton version (pytorch#2775)
Browse files Browse the repository at this point in the history
Summary:
- Update pytorch-triton version to follow PyTorch pytorch/pytorch#126098

Pull Request resolved: pytorch#2775

Reviewed By: brad-mengchi

Differential Revision: D58959021

Pulled By: q10

fbshipit-source-id: 575d42dc07020778ef9f201ea1a234fb92be27f2
  • Loading branch information
q10 authored and facebook-github-bot committed Jun 24, 2024
1 parent cdad003 commit d38bca1
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 11 deletions.
30 changes: 23 additions & 7 deletions .github/scripts/utils_triton.bash
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# Triton Setup Functions
################################################################################

install_triton_gitmodule () {
install_triton_git_repo () {
local env_name="$1"
local triton_version="$2"
if [ "$env_name" == "" ]; then
Expand All @@ -26,7 +26,7 @@ install_triton_gitmodule () {
return 1
else
echo "################################################################################"
echo "# Build + Install Triton (gitmodule)"
echo "# Build + Install Triton (git repo)"
echo "#"
echo "# [$(date --utc +%FT%T.%3NZ)] + ${FUNCNAME[0]} ${*}"
echo "################################################################################"
Expand All @@ -39,20 +39,23 @@ install_triton_gitmodule () {
local env_prefix=$(env_name_or_prefix "${env_name}")

echo "[BUILD] Checking out triton ..."
cd ../third_party/triton/python || return 1
cd ~/ || return 1
git clone https://github.com/triton-lang/triton.git || return 1
cd triton/python || return 1

if [ "$triton_version" != "" ]; then
(print_exec git checkout "${triton_version}") || return 1
(print_exec git reset --hard "${triton_version}") || return 1
fi

echo "[BUILD] Installing Triton from gitmodule ..."
echo "[BUILD] Installing Triton from git repo ..."
# shellcheck disable=SC2086
(exec_with_retries 3 conda run --no-capture-output ${env_prefix} python -m pip install -e .) || return 1

# shellcheck disable=SC2086
(test_python_import_package "${env_name}" triton) || return 1

cd - || return 1
echo "[INSTALL] Successfully installed Triton ${triton_version} from gitmodule"
echo "[INSTALL] Successfully installed Triton ${triton_version} from git repo"
}

install_triton_pip () {
Expand All @@ -72,8 +75,21 @@ install_triton_pip () {
fi

echo "[BUILD] Installing Triton from PIP ..."
# NOTE: Install pytorch-triton from nightly with commit SHA that follows the
# version tracked in FB internal (/third-party/tp2/triton/2.0.x/METADATA.bzl)
# as close as possible.
#
# https://download.pytorch.org/whl/nightly/pytorch-triton/
# https://download.pytorch.org/whl/nightly/pytorch-triton-rocm/
#
# This needs to be manually updated to follow PyTorch's triton pinning
# updates:
#
# https://github.com/pytorch/pytorch/blob/main/.ci/docker/ci_commit_pins/triton.txt
# https://github.com/pytorch/pytorch/pull/126098
#
# shellcheck disable=SC2086
install_from_pytorch_pip "${env_name}" pytorch-triton nightly/3.0.0+45fff310c8 || return 1
install_from_pytorch_pip "${env_name}" pytorch-triton nightly/3.0.0+dedb7bdf33 || return 1

# shellcheck disable=SC2086
(test_python_import_package "${env_name}" triton) || return 1
Expand Down
3 changes: 0 additions & 3 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,3 @@
[submodule "third_party/cutlass"]
path = third_party/cutlass
url = https://github.com/NVIDIA/cutlass.git
[submodule "third_party/triton"]
path = third_party/triton
url = https://github.com/openai/triton.git
23 changes: 23 additions & 0 deletions fbgemm_gpu/docs/src/fbgemm_gpu-development/BuildInstructions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,29 @@ For the CUDA variant of PyTorch, verify that at the minimum ``cuda_cmake_macros.
conda_prefix=$(conda run -n ${env_name} printenv CONDA_PREFIX)
find "${conda_prefix}" -name cuda_cmake_macros.h
Install PyTorch-Triton
~~~~~~~~~~~~~~~~~~~~~~

This section is only applicable to building the experimental FBGEMM_GPU
Triton-GEMM module. Triton should be installed via the ``pytorch-triton``,
which generally comes installing ``torch``, but can also be installed manually:

.. code:: sh
# pytorch-triton repos:
# https://download.pytorch.org/whl/nightly/pytorch-triton/
# https://download.pytorch.org/whl/nightly/pytorch-triton-rocm/
# The version SHA should follow the one pinned in PyTorch
# https://github.com/pytorch/pytorch/blob/main/.ci/docker/ci_commit_pins/triton.txt
conda run -n ${env_name} pip install --pre pytorch-triton==3.0.0+dedb7bdf33 --index-url https://download.pytorch.org/whl/nightly/
Verify the PyTorch-Triton installation with an ``import`` test:

.. code:: sh
# Ensure that the package loads properly
conda run -n ${env_name} python -c "import triton"
Build the FBGEMM_GPU Package
----------------------------
Expand Down
1 change: 0 additions & 1 deletion third_party/triton
Submodule triton deleted from 45fff3

0 comments on commit d38bca1

Please sign in to comment.