From 2276c6e1905ad5fa2316b62edd7cd0fff0b4ab41 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Fri, 17 May 2024 12:43:03 +0200 Subject: [PATCH] FIX BOFT setting env vars breaks C++ compilation (#1739) Resolves #1738 --- src/peft/tuners/boft/layer.py | 59 ++++++++++++++++++++++++++++------- 1 file changed, 47 insertions(+), 12 deletions(-) diff --git a/src/peft/tuners/boft/layer.py b/src/peft/tuners/boft/layer.py index 7473d32e17..192a2350e0 100644 --- a/src/peft/tuners/boft/layer.py +++ b/src/peft/tuners/boft/layer.py @@ -20,6 +20,7 @@ import math import os import warnings +from contextlib import contextmanager from typing import Any, Optional, Union import torch @@ -31,13 +32,46 @@ from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge -os.environ["CC"] = "gcc" -os.environ["CXX"] = "gcc" -curr_dir = os.path.dirname(__file__) - _FBD_CUDA = None +# this function is a 1:1 copy from accelerate +@contextmanager +def patch_environment(**kwargs): + """ + A context manager that will add each keyword argument passed to `os.environ` and remove them when exiting. + + Will convert the values in `kwargs` to strings and upper-case all the keys. + + Example: + + ```python + >>> import os + >>> from accelerate.utils import patch_environment + + >>> with patch_environment(FOO="bar"): + ... print(os.environ["FOO"]) # prints "bar" + >>> print(os.environ["FOO"]) # raises KeyError + ``` + """ + existing_vars = {} + for key, value in kwargs.items(): + key = key.upper() + if key in os.environ: + existing_vars[key] = os.environ[key] + os.environ[key] = str(value) + + yield + + for key in kwargs: + key = key.upper() + if key in existing_vars: + # restore previous value + os.environ[key] = existing_vars[key] + else: + os.environ.pop(key, None) + + def get_fbd_cuda(): global _FBD_CUDA @@ -47,14 +81,15 @@ def get_fbd_cuda(): curr_dir = os.path.dirname(__file__) # need ninja to build the extension try: - fbd_cuda = load( - name="fbd_cuda", - sources=[f"{curr_dir}/fbd/fbd_cuda.cpp", f"{curr_dir}/fbd/fbd_cuda_kernel.cu"], - verbose=True, - # build_directory='/tmp/' # for debugging - ) - # extra_cuda_cflags = ['-std=c++14', '-ccbin=$$(which gcc-7)']) # cuda10.2 is not compatible with gcc9. Specify gcc 7 - import fbd_cuda + with patch_environment(CC="gcc", CXX="gcc"): + fbd_cuda = load( + name="fbd_cuda", + sources=[f"{curr_dir}/fbd/fbd_cuda.cpp", f"{curr_dir}/fbd/fbd_cuda_kernel.cu"], + verbose=True, + # build_directory='/tmp/' # for debugging + ) + # extra_cuda_cflags = ['-std=c++14', '-ccbin=$$(which gcc-7)']) # cuda10.2 is not compatible with gcc9. Specify gcc 7 + import fbd_cuda except Exception as e: warnings.warn(f"Failed to load the CUDA extension: {e}, check if ninja is available.") warnings.warn("Setting boft_n_butterfly_factor to 1 to speed up the finetuning process.")