diff --git a/.github/workflows/build-wheels-linux.yml b/.github/workflows/build-wheels-linux.yml index fe2516b80..bca0fd5f1 100644 --- a/.github/workflows/build-wheels-linux.yml +++ b/.github/workflows/build-wheels-linux.yml @@ -5,6 +5,13 @@ on: push: branches: - nightly + - main + # Release candidate branch look like: v1.11.0-rc1 + - v[0-9]+.[0-9]+.[0-9]+-release+ + tags: + # Release candidate tag look like: v1.11.0-rc1 + - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ + - v[0-9]+.[0-9]+.[0-9]+ workflow_dispatch: jobs: diff --git a/fbgemm_gpu/setup.py b/fbgemm_gpu/setup.py index f18f78bfb..b0d7b56b3 100644 --- a/fbgemm_gpu/setup.py +++ b/fbgemm_gpu/setup.py @@ -24,7 +24,7 @@ from tabulate import tabulate -def generate_package_version(package_name: str): +def generate_package_version(package_name: str, version_variant: str): print("[SETUP.PY] Generating the package version ...") if "nightly" in package_name: @@ -47,7 +47,7 @@ def generate_package_version(package_name: str): # Remove the local version identifier, if any (e.g. 0.4.0rc0.post0+git.6a63116c.dirty => 0.4.0rc0.post0) # Then remove post0 (keep postN for N > 0) (e.g. 0.4.0rc0.post0 => 0.4.0rc0) version = re.sub(".post0$", "", gitversion.version_from_git().split("+")[0]) - + version = str(version) + version_variant print(f"[SETUP.PY] Setting the package version: {version}") return version @@ -279,14 +279,21 @@ def main(argv: List[str]) -> None: else: package_name = args.package_name - if not args.cpu_only: + if args.cpu_only: + version_variant = "+cpu" + else: set_cuda_environment_variables() + if torch.version.cuda is not None: + cuda_version = torch.version.cuda.split(".") + version_variant = "+cu" + str(cuda_version[0]) + str(cuda_version[1]) + else: + version_variant = "" # Repair command line args for setup. sys.argv = [sys.argv[0]] + unknown # Determine the package version - package_version = generate_package_version(args.package_name) + package_version = generate_package_version(args.package_name, version_variant) # Generate the version file FbgemmGpuInstaller.generate_version_file(package_version)