Commit: Remove accelerate from langchain
gpetters94 committed Aug 9, 2023
1 parent 96185c9 · commit 470dac3
Showing 2 changed files with 40 additions and 59 deletions.
apps/language_models/langchain/gen.py (39 additions, 58 deletions)
@@ -109,51 +109,48 @@ def get_config(
     return_model=False,
     raise_exception=False,
 ):
-    from accelerate import init_empty_weights
-    with init_empty_weights():
-        from transformers import AutoConfig
-
-        try:
-            config = AutoConfig.from_pretrained(
-                base_model,
-                use_auth_token=use_auth_token,
-                trust_remote_code=trust_remote_code,
-                offload_folder=offload_folder,
-            )
-        except OSError as e:
-            if raise_exception:
-                raise
-            if "not a local folder and is not a valid model identifier listed on" in str(
-                e
-            ) or "404 Client Error" in str(
-                e
-            ):
-                # e.g. llama, gpjt, etc.
-                # e.g. HF TGI but not model on HF or private etc.
-                # HF TGI server only should really require prompt_type, not HF model state
-                return None, None
-            else:
-                raise
-        if triton_attn and "mpt-" in base_model.lower():
-            config.attn_config["attn_impl"] = "triton"
-        if long_sequence:
-            if "mpt-7b-storywriter" in base_model.lower():
-                config.update({"max_seq_len": 83968})
-            if "mosaicml/mpt-7b-chat" in base_model.lower():
-                config.update({"max_seq_len": 4096})
-            if "mpt-30b" in base_model.lower():
-                config.update({"max_seq_len": 2 * 8192})
-        if return_model and issubclass(
-            config.__class__, tuple(AutoModel._model_mapping.keys())
-        ):
-            model = AutoModel.from_config(
-                config,
-                trust_remote_code=trust_remote_code,
-            )
-        else:
-            # can't infer
-            model = None
+    from transformers import AutoConfig
+
+    try:
+        config = AutoConfig.from_pretrained(
+            base_model,
+            use_auth_token=use_auth_token,
+            trust_remote_code=trust_remote_code,
+            offload_folder=offload_folder,
+        )
+    except OSError as e:
+        if raise_exception:
+            raise
+        if "not a local folder and is not a valid model identifier listed on" in str(
+            e
+        ) or "404 Client Error" in str(
+            e
+        ):
+            # e.g. llama, gpjt, etc.
+            # e.g. HF TGI but not model on HF or private etc.
+            # HF TGI server only should really require prompt_type, not HF model state
+            return None, None
+        else:
+            raise
+    if triton_attn and "mpt-" in base_model.lower():
+        config.attn_config["attn_impl"] = "triton"
+    if long_sequence:
+        if "mpt-7b-storywriter" in base_model.lower():
+            config.update({"max_seq_len": 83968})
+        if "mosaicml/mpt-7b-chat" in base_model.lower():
+            config.update({"max_seq_len": 4096})
+        if "mpt-30b" in base_model.lower():
+            config.update({"max_seq_len": 2 * 8192})
+    if return_model and issubclass(
+        config.__class__, tuple(AutoModel._model_mapping.keys())
+    ):
+        model = AutoModel.from_config(
+            config,
+            trust_remote_code=trust_remote_code,
+        )
+    else:
+        # can't infer
+        model = None
     if "falcon" in base_model.lower():
         config.use_cache = False

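Note: the hunk above drops accelerate's init_empty_weights() context manager around the config/model path. A practical consequence (my reading, not stated in the commit): when return_model=True, AutoModel.from_config now allocates real, randomly initialized weights in host memory instead of shape-only "meta" tensors. A minimal sketch of that difference, using "gpt2" purely as a stand-in model id:

    # Hypothetical illustration, not part of the commit.
    from transformers import AutoConfig, AutoModel

    config = AutoConfig.from_pretrained("gpt2")  # stand-in model id

    # After this commit: from_config allocates real (randomly initialized) weights.
    model = AutoModel.from_config(config)
    print(next(model.parameters()).device)  # cpu

    # Before this commit (requires accelerate): weights are shape-only meta tensors.
    # from accelerate import init_empty_weights
    # with init_empty_weights():
    #     meta_model = AutoModel.from_config(config)
    # print(next(meta_model.parameters()).device)  # meta
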
@@ -177,22 +174,6 @@ def get_non_lora_model(
     """
 
     device_map = None
-    if model is not None:
-        # NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model
-        # NOTE: Some models require avoiding sharding some layers,
-        # then would pass no_split_module_classes and give list of those layers.
-        from accelerate import infer_auto_device_map
-
-        device_map = infer_auto_device_map(
-            model,
-            dtype=torch.float16 if load_half else torch.float32,
-        )
-        if hasattr(model, "model"):
-            device_map_model = infer_auto_device_map(
-                model.model,
-                dtype=torch.float16 if load_half else torch.float32,
-            )
-            device_map.update(device_map_model)
 
     n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
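
Note: in get_non_lora_model the accelerate-based inference is removed, so device_map stays None and the model is no longer sharded by infer_auto_device_map; placement is left to the surrounding loading code. A hypothetical torch-only placement helper under that assumption (the function name is illustrative, not from the repo):

    import torch

    def place_model(model, load_half: bool = True):
        # Move the whole model to a single device instead of sharding it
        # across GPUs/CPU with accelerate's infer_auto_device_map.
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        if load_half and device != "cpu":
            model = model.half()
        return model.to(device)
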
apps/language_models/langchain/langchain_requirements.txt (1 addition, 1 deletion)
@@ -16,7 +16,7 @@ pandas==2.0.2
 matplotlib==3.7.1
 loralib==0.1.1
 bitsandbytes==0.39.0
-accelerate==0.20.3
+# accelerate==0.20.3
 peft==0.4.0
 # 4.31.0+ breaks load_in_8bit=True (https://github.com/huggingface/transformers/issues/25026)
 transformers==4.30.2
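
Note: with accelerate commented out of the requirements, transformers features that depend on it (for example device_map="auto" or load_in_8bit=True in from_pretrained) will raise if any remaining code path still requests them. A small, hypothetical guard (not from the repo) that keys those options off whether accelerate is importable:

    import importlib.util

    have_accelerate = importlib.util.find_spec("accelerate") is not None

    load_kwargs = {}
    if have_accelerate:
        # Only request accelerate-backed sharding/offload when it is installed.
        load_kwargs["device_map"] = "auto"
    # model = AutoModelForCausalLM.from_pretrained(base_model, **load_kwargs)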
