diff --git a/apps/language_models/langchain/gen.py b/apps/language_models/langchain/gen.py index 3b96523559..3d5c36e745 100644 --- a/apps/language_models/langchain/gen.py +++ b/apps/language_models/langchain/gen.py @@ -109,51 +109,48 @@ def get_config( return_model=False, raise_exception=False, ): - from accelerate import init_empty_weights + from transformers import AutoConfig - with init_empty_weights(): - from transformers import AutoConfig - - try: - config = AutoConfig.from_pretrained( - base_model, - use_auth_token=use_auth_token, - trust_remote_code=trust_remote_code, - offload_folder=offload_folder, - ) - except OSError as e: - if raise_exception: - raise - if "not a local folder and is not a valid model identifier listed on" in str( - e - ) or "404 Client Error" in str( - e - ): - # e.g. llama, gpjt, etc. - # e.g. HF TGI but not model on HF or private etc. - # HF TGI server only should really require prompt_type, not HF model state - return None, None - else: - raise - if triton_attn and "mpt-" in base_model.lower(): - config.attn_config["attn_impl"] = "triton" - if long_sequence: - if "mpt-7b-storywriter" in base_model.lower(): - config.update({"max_seq_len": 83968}) - if "mosaicml/mpt-7b-chat" in base_model.lower(): - config.update({"max_seq_len": 4096}) - if "mpt-30b" in base_model.lower(): - config.update({"max_seq_len": 2 * 8192}) - if return_model and issubclass( - config.__class__, tuple(AutoModel._model_mapping.keys()) + try: + config = AutoConfig.from_pretrained( + base_model, + use_auth_token=use_auth_token, + trust_remote_code=trust_remote_code, + offload_folder=offload_folder, + ) + except OSError as e: + if raise_exception: + raise + if "not a local folder and is not a valid model identifier listed on" in str( + e + ) or "404 Client Error" in str( + e ): - model = AutoModel.from_config( - config, - trust_remote_code=trust_remote_code, - ) + # e.g. llama, gpjt, etc. + # e.g. HF TGI but not model on HF or private etc. + # HF TGI server only should really require prompt_type, not HF model state + return None, None else: - # can't infer - model = None + raise + if triton_attn and "mpt-" in base_model.lower(): + config.attn_config["attn_impl"] = "triton" + if long_sequence: + if "mpt-7b-storywriter" in base_model.lower(): + config.update({"max_seq_len": 83968}) + if "mosaicml/mpt-7b-chat" in base_model.lower(): + config.update({"max_seq_len": 4096}) + if "mpt-30b" in base_model.lower(): + config.update({"max_seq_len": 2 * 8192}) + if return_model and issubclass( + config.__class__, tuple(AutoModel._model_mapping.keys()) + ): + model = AutoModel.from_config( + config, + trust_remote_code=trust_remote_code, + ) + else: + # can't infer + model = None if "falcon" in base_model.lower(): config.use_cache = False @@ -177,22 +174,6 @@ def get_non_lora_model( """ device_map = None - if model is not None: - # NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model - # NOTE: Some models require avoiding sharding some layers, - # then would pass no_split_module_classes and give list of those layers. - from accelerate import infer_auto_device_map - - device_map = infer_auto_device_map( - model, - dtype=torch.float16 if load_half else torch.float32, - ) - if hasattr(model, "model"): - device_map_model = infer_auto_device_map( - model.model, - dtype=torch.float16 if load_half else torch.float32, - ) - device_map.update(device_map_model) n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0 diff --git a/apps/language_models/langchain/langchain_requirements.txt b/apps/language_models/langchain/langchain_requirements.txt index 78bd6e7562..fe1ddcff45 100644 --- a/apps/language_models/langchain/langchain_requirements.txt +++ b/apps/language_models/langchain/langchain_requirements.txt @@ -16,7 +16,7 @@ pandas==2.0.2 matplotlib==3.7.1 loralib==0.1.1 bitsandbytes==0.39.0 -accelerate==0.20.3 +# accelerate==0.20.3 peft==0.4.0 # 4.31.0+ breaks load_in_8bit=True (https://github.com/huggingface/transformers/issues/25026) transformers==4.30.2