Skip to content

Commit

Permalink
Build cuda12 version with Hopper support.
Browse files Browse the repository at this point in the history
  • Loading branch information
nshepperd committed May 24, 2024
1 parent 047b437 commit 4b704af
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 38 deletions.
63 changes: 37 additions & 26 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ jobs:
- name: Checkout
uses: actions/checkout@v3

- name: Set up python
uses: actions/setup-python@v4
with:
python-version: '3.10'

- name: Set CUDA and PyTorch versions
run: |
echo "MATRIX_CUDA_MAJOR=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV
Expand Down Expand Up @@ -76,6 +81,7 @@ jobs:

- name: Log Built Wheels
run: |
python3 set_tag_in_wheels.py "+cu$MATRIX_CUDA_MAJOR" wheelhouse/*.whl
ls wheelhouse
wheel_name=$(basename wheelhouse/*.whl)
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
Expand Down Expand Up @@ -109,36 +115,41 @@ jobs:
name: ${{env.wheel_name}}
path: ./wheelhouse/${{env.wheel_name}}

# publish_package:
# name: Publish package
# needs: [build_wheels]
publish_package:
name: Publish package
needs: [build_wheels]

# runs-on: ubuntu-latest
# permissions:
# id-token: write
runs-on: ubuntu-latest
permissions:
id-token: write

# steps:
# - uses: actions/checkout@v3
steps:
- uses: actions/checkout@v3

# - uses: actions/setup-python@v4
# with:
# python-version: '3.10'
- uses: actions/setup-python@v4
with:
python-version: '3.10'

# - name: Install dependencies
# run: |
# pip install setuptools==68.0.0
# pip install git+https://github.com/nshepperd/setuptools-cuda-cpp
# pip install ninja packaging wheel pybind11
- name: Install dependencies
run: |
pip install setuptools==68.0.0
pip install git+https://github.com/nshepperd/setuptools-cuda-cpp
pip install ninja packaging wheel pybind11
# - name: Build core package
# run: |
# CUDA_HOME=/ python setup.py sdist --dist-dir=dist
- name: Build core package
run: |
CUDA_HOME=/ python setup.py sdist --dist-dir=dist
# - name: Retrieve release distributions
# uses: actions/download-artifact@v4
# with:
# path: dist/
# merge-multiple: true
- name: Retrieve release distributions
uses: actions/download-artifact@v4
with:
path: dist/
merge-multiple: true
pattern: '*+cu12*.whl'

- name: Remove version tag for pypi
run: |
python3 set_tag_in_wheels.py "" dist/*+cu12*.whl
# - name: Publish release distributions to PyPI
# uses: pypa/gh-action-pypi-publish@release/v1
- name: Publish release distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
15 changes: 15 additions & 0 deletions set_tag_in_wheels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import os
import sys

cuda_ver = sys.argv[1]
wheels = sys.argv[2:]
for wheel in wheels:
dirname = os.path.dirname(wheel)
basename = os.path.basename(wheel)
parts = basename.split('-')
if len(parts) != 5:
continue
version = parts[1].split('+')[0]
parts[1] = f'{version}{cuda_ver}'
basename = '-'.join(parts)
os.rename(wheel, f'{dirname}/{basename}')
19 changes: 8 additions & 11 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ def get_cuda_version():
output = raw_output.split()
release_idx = output.index("release") + 1
version = output[release_idx].split(",")[0].split('.')[0] # should be 11 or 12
return f'+cu{version}'

return version

def append_nvcc_threads(nvcc_extra_args):
return nvcc_extra_args + ["--threads", "4"]
Expand All @@ -92,6 +91,8 @@ def append_nvcc_threads(nvcc_extra_args):

SKIP_CUDA_BUILD = False

cuda12 = get_cuda_version() != "11"

# CUDA_HOME = find_cuda_home()
if not SKIP_CUDA_BUILD:

Expand All @@ -100,10 +101,9 @@ def append_nvcc_threads(nvcc_extra_args):
# cc_flag.append("arch=compute_75,code=sm_75")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
# if CUDA_HOME is not None:
# if bare_metal_version >= Version("11.8"):
# cc_flag.append("-gencode")
# cc_flag.append("arch=compute_90,code=sm_90")
if cuda12:
cc_flag.append("-gencode")
cc_flag.append("arch=compute_90,code=sm_90")
ext_modules.append(
CUDAExtension(
name="flash_attn_jax_lib.flash_api",
Expand Down Expand Up @@ -171,6 +171,7 @@ def append_nvcc_threads(nvcc_extra_args):
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-DFLASHATTENTION_DISABLE_DROPOUT=1",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
Expand All @@ -194,11 +195,7 @@ def get_package_version():
with open(Path(this_dir) / "src" / "flash_attn_jax" / "__init__.py", "r") as f:
version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
public_version = ast.literal_eval(version_match.group(1))
local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
if local_version:
return f"{public_version}+{local_version}{get_cuda_version()}"
else:
return f"{public_version}{get_cuda_version()}"
return public_version


class NinjaBuildExtension(BuildExtension):
Expand Down
2 changes: 1 addition & 1 deletion src/flash_attn_jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .flash import flash_mha
__version__ = 'v0.2.1'
__version__ = 'v0.2.2'

0 comments on commit 4b704af

Please sign in to comment.