Skip to content

Commit

Permalink
Try cibuildwheel.
Browse files Browse the repository at this point in the history
  • Loading branch information
nshepperd committed Mar 18, 2024
1 parent f4ff626 commit e475f56
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 67 deletions.
70 changes: 12 additions & 58 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,27 +40,17 @@ jobs:
strategy:
fail-fast: false
matrix:
# Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
os: [ubuntu-20.04]
python-version: ['3.9', '3.10', '3.11', '3.12']
jax-version: ['0.4.24']
cuda-version: ['11.8.0', '12.3.1']
python-version: ['cp39', 'cp310', 'cp311', 'cp312']
cuda-version: ['11.8', '12.3']

steps:
- name: Checkout
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Set CUDA and PyTorch versions
run: |
echo "MATRIX_CUDA_MAJOR=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV
echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
echo "MATRIX_JAX_VERSION=$(echo ${{ matrix.jax-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
- name: Free up disk space
if: ${{ runner.os == 'Linux' }}
Expand All @@ -77,53 +67,17 @@ jobs:
with:
swap-size-gb: 10

- name: Install CUDA ${{ matrix.cuda-version }}
if: ${{ matrix.cuda-version != 'cpu' }}
uses: Jimver/[email protected]
id: cuda-toolkit
with:
cuda: ${{ matrix.cuda-version }}
linux-local-args: '["--toolkit"]'
# default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
# method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }}
method: 'network'
# We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions,
# not just nvcc
# sub-packages: '["nvcc"]'

- name: Install Jax ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
run: |
pip install --upgrade pip
pip install --upgrade "jax[cuda${MATRIX_CUDA_MAJOR}_local] == ${{ matrix.jax-version }}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
shell:
bash

- name: Build wheel
run: |
# We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
# https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
# However this still fails so I'm using a newer version of setuptools
pip install setuptools==68.0.0
# setuptools-cuda-cpp on pypi has a bug that breaks ninja
pip install git+https://github.com/nshepperd/setuptools-cuda-cpp
pip install ninja packaging wheel pybind11
export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
# Set MAX_JOBS to allocate 8GB per job, which should be enough to build comfortably
free -h
export MAX_JOBS=3
echo "Building with ${MAX_JOBS} jobs"
python setup.py bdist_wheel --dist-dir=dist
tmpname=cu${MATRIX_CUDA_VERSION}jax${{ matrix.jax-version }}
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
shell:
bash
- name: Build wheels
uses: pypa/[email protected]
env:
CIBW_BUILD: ${{ matrix.python-version }}-manylinux_x86_64
CIBW_MANYLINUX_X86_64_IMAGE: manylinux2014_x86_64_cuda_${{ matrix.cuda-version }}

- name: Log Built Wheels
run: |
ls dist
ls wheelhouse
wheel_name=$(basename wheelhouse/*.whl)
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
- name: Get the tag version
id: extract_branch
Expand All @@ -144,15 +98,15 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.get_current_release.outputs.upload_url }}
asset_path: ./dist/${{env.wheel_name}}
asset_path: ./wheelhouse/${{env.wheel_name}}
asset_name: ${{env.wheel_name}}
asset_content_type: application/*

- name: Upload Artifact
uses: actions/upload-artifact@v4
with:
name: ${{env.wheel_name}}
path: ./dist/${{env.wheel_name}}
path: ./wheelhouse/${{env.wheel_name}}

publish_package:
name: Publish package
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
[build-system]
requires = ["setuptools", "wheel", "setuptools-cuda-cpp", "packaging", "pybind11"]
requires = ["setuptools", "wheel", "setuptools-cuda-cpp @ git+https://github.com/nshepperd/setuptools-cuda-cpp", "packaging", "pybind11"]
32 changes: 24 additions & 8 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from setuptools import setup, find_packages
from setuptools_cuda_cpp import CUDAExtension, BuildExtension, fix_dll
# from setuptools_cuda.inspections import find_cuda_home
import pybind11

import subprocess
Expand Down Expand Up @@ -53,14 +52,31 @@ def get_platform():
else:
raise ValueError("Unsupported platform: {}".format(sys.platform))


def get_cuda_bare_metal_version(cuda_dir):
def locate_cuda():
if 'sdist' in sys.argv:
return None
cuda_dir = os.environ.get("CUDA_HOME", None)
if cuda_dir is None:
if os.path.exists("/usr/local/cuda"):
cuda_dir = "/usr/local/cuda"
os.environ["CUDA_HOME"] = cuda_dir
elif os.path.exists("/opt/cuda"):
cuda_dir = "/opt/cuda"
os.environ["CUDA_HOME"] = cuda_dir
else:
raise RuntimeError("CUDA_HOME not set and no CUDA installation found")
return cuda_dir


def get_cuda_version():
cuda_dir = locate_cuda()
if cuda_dir is None:
return ""
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
bare_metal_version = parse(output[release_idx].split(",")[0])

return raw_output, bare_metal_version
version = output[release_idx].split(",")[0].split('.')[0] # should be 11 or 12
return f'+cu{version}'


def append_nvcc_threads(nvcc_extra_args):
Expand Down Expand Up @@ -180,9 +196,9 @@ def get_package_version():
public_version = ast.literal_eval(version_match.group(1))
local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
if local_version:
return f"{public_version}+{local_version}"
return f"{public_version}+{local_version}{get_cuda_version()}"
else:
return str(public_version)
return f"{public_version}{get_cuda_version()}"


class NinjaBuildExtension(BuildExtension):
Expand Down

0 comments on commit e475f56

Please sign in to comment.