Skip to content

Commit

Permalink
Add separate binaries for linux cu102 (#927)
Browse files Browse the repository at this point in the history
  • Loading branch information
zou3519 authored Jun 30, 2022
1 parent 014d6f2 commit 36641d7
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 16 deletions.
7 changes: 4 additions & 3 deletions .github/workflows/wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,23 @@ jobs:

build-wheel-linux:
runs-on: ubuntu-18.04
container: pytorch/manylinux-cpu
strategy:
matrix:
python_version: [["3.7", "cp37-cp37m"], ["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"]]
cuda_support: [["", "--extra-index-url https://download.pytorch.org/whl/cpu", "\"['cpu', 'cu113', 'cu116']\"", "cpu"], ["+cu102", "", "\"['cu102']\"", "cuda102"]]
container: pytorch/manylinux-${{ matrix.cuda_support[3] }}
steps:
- name: Checkout functorch
uses: actions/checkout@v2
- name: Install PyTorch 1.12 RC
run: |
export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH"
python3 -mpip install torch==1.12 --extra-index-url https://download.pytorch.org/whl/cpu
python3 -mpip install torch==1.12 ${{ matrix.cuda_support[1] }}
- name: Build wheel
run: |
export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH"
python3 -mpip install wheel
python3 setup.py bdist_wheel
VERSION_TAG=${{ matrix.cuda_support[0] }} PYTORCH_CUDA_RESTRICTIONS=${{ matrix.cuda_support[2] }} python3 setup.py bdist_wheel
# NB: wheels have the linux_x86_64 tag so we rename to manylinux1
# find . -name 'dist/*whl' -exec bash -c ' mv $0 ${0/linux/manylinux1}' {} \;
# pytorch/pytorch binaries are also manylinux_2_17 compliant but they
Expand Down
38 changes: 25 additions & 13 deletions functorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,31 @@
# LICENSE file in the root directory of this source tree.
import torch

if torch.version.cuda == '10.2':
raise RuntimeError(
"We've detected an installation of PyTorch 1.12 with CUDA 10.2 support. "
"The official functorch 0.2.0 binaries are not compatible with the "
"PyTorch 1.12 CUDA 10.2 binaries. "
"Please install PyTorch with support for a different version of CUDA "
"(either cpu-only, 11.3, or 11.6; see pytorch.org for instructions) or "
"file an issue on GitHub to discuss more.")

try:
from .version import __version__ # noqa: F401
from .version import pytorch_cuda_restrictions # noqa: F401

if pytorch_cuda_restrictions is not None:
if torch.version.cuda is None:
torch_cuda_version = 'cpu'
verbose_torch_cuda_version = f'cpuonly'
else:
torch_cuda_version = torch.version.cuda
verbose_torch_cuda_version = f'CUDA {torch.version.cuda}'

if torch_cuda_version not in pytorch_cuda_restrictions:
raise RuntimeError(
f"We've detected an installation of PyTorch 1.12 with {verbose_torch_cuda_version} support. "
"This functorch 0.2.0 binary is not compatible with the PyTorch installation. "
"Please see our install page for suggestions on how to resolve this: "
"https://pytorch.org/functorch/stable/install.html")

# don't leak variables
del torch_cuda_version
del verbose_torch_cuda_version
del pytorch_cuda_restrictions
except ImportError:
pass

from . import _C

Expand Down Expand Up @@ -43,7 +59,3 @@
FunctionalModuleWithBuffers,
)

try:
from .version import __version__ # noqa: F401
except ImportError:
pass
8 changes: 8 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,20 @@
# elif sha != 'Unknown':
# version += '+' + sha[:7]

if os.getenv('VERSION_TAG'):
version_tag = os.getenv('VERSION_TAG')
version = f'{version}{version_tag}'

pytorch_cuda_restrictions = None
if os.getenv('PYTORCH_CUDA_RESTRICTIONS'):
pytorch_cuda_restrictions = os.getenv('PYTORCH_CUDA_RESTRICTIONS')

def write_version_file():
version_path = os.path.join(cwd, 'functorch', 'version.py')
with open(version_path, 'w') as f:
f.write("__version__ = '{}'\n".format(version))
f.write("git_version = {}\n".format(repr(sha)))
f.write("pytorch_cuda_restrictions = {}\n".format(pytorch_cuda_restrictions))


# pytorch_dep = 'torch'
Expand Down

0 comments on commit 36641d7

Please sign in to comment.