diff --git a/apps/language_models/langchain/gen.py b/apps/language_models/langchain/gen.py
index 3b96523559..3d5c36e745 100644
--- a/apps/language_models/langchain/gen.py
+++ b/apps/language_models/langchain/gen.py
@@ -109,51 +109,48 @@ def get_config(
     return_model=False,
     raise_exception=False,
 ):
-    from accelerate import init_empty_weights
+    from transformers import AutoConfig
 
-    with init_empty_weights():
-        from transformers import AutoConfig
-
-        try:
-            config = AutoConfig.from_pretrained(
-                base_model,
-                use_auth_token=use_auth_token,
-                trust_remote_code=trust_remote_code,
-                offload_folder=offload_folder,
-            )
-        except OSError as e:
-            if raise_exception:
-                raise
-            if "not a local folder and is not a valid model identifier listed on" in str(
-                e
-            ) or "404 Client Error" in str(
-                e
-            ):
-                # e.g. llama, gpjt, etc.
-                # e.g. HF TGI but not model on HF or private etc.
-                # HF TGI server only should really require prompt_type, not HF model state
-                return None, None
-            else:
-                raise
-        if triton_attn and "mpt-" in base_model.lower():
-            config.attn_config["attn_impl"] = "triton"
-        if long_sequence:
-            if "mpt-7b-storywriter" in base_model.lower():
-                config.update({"max_seq_len": 83968})
-            if "mosaicml/mpt-7b-chat" in base_model.lower():
-                config.update({"max_seq_len": 4096})
-            if "mpt-30b" in base_model.lower():
-                config.update({"max_seq_len": 2 * 8192})
-        if return_model and issubclass(
-            config.__class__, tuple(AutoModel._model_mapping.keys())
+    try:
+        config = AutoConfig.from_pretrained(
+            base_model,
+            use_auth_token=use_auth_token,
+            trust_remote_code=trust_remote_code,
+            offload_folder=offload_folder,
+        )
+    except OSError as e:
+        if raise_exception:
+            raise
+        if "not a local folder and is not a valid model identifier listed on" in str(
+            e
+        ) or "404 Client Error" in str(
+            e
         ):
-            model = AutoModel.from_config(
-                config,
-                trust_remote_code=trust_remote_code,
-            )
+            # e.g. llama, gpjt, etc.
+            # e.g. HF TGI but not model on HF or private etc.
+            # HF TGI server only should really require prompt_type, not HF model state
+            return None, None
         else:
-            # can't infer
-            model = None
+            raise
+    if triton_attn and "mpt-" in base_model.lower():
+        config.attn_config["attn_impl"] = "triton"
+    if long_sequence:
+        if "mpt-7b-storywriter" in base_model.lower():
+            config.update({"max_seq_len": 83968})
+        if "mosaicml/mpt-7b-chat" in base_model.lower():
+            config.update({"max_seq_len": 4096})
+        if "mpt-30b" in base_model.lower():
+            config.update({"max_seq_len": 2 * 8192})
+    if return_model and issubclass(
+        config.__class__, tuple(AutoModel._model_mapping.keys())
+    ):
+        model = AutoModel.from_config(
+            config,
+            trust_remote_code=trust_remote_code,
+        )
+    else:
+        # can't infer
+        model = None
 
     if "falcon" in base_model.lower():
         config.use_cache = False
@@ -177,22 +174,6 @@ def get_non_lora_model(
     """
     device_map = None
-    if model is not None:
-        # NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model
-        # NOTE: Some models require avoiding sharding some layers,
-        # then would pass no_split_module_classes and give list of those layers.
-        from accelerate import infer_auto_device_map
-
-        device_map = infer_auto_device_map(
-            model,
-            dtype=torch.float16 if load_half else torch.float32,
-        )
-        if hasattr(model, "model"):
-            device_map_model = infer_auto_device_map(
-                model.model,
-                dtype=torch.float16 if load_half else torch.float32,
-            )
-            device_map.update(device_map_model)
 
     n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
diff --git a/apps/language_models/langchain/gpt_langchain.py b/apps/language_models/langchain/gpt_langchain.py
index 6c2db5ec51..ab29d4410d 100644
--- a/apps/language_models/langchain/gpt_langchain.py
+++ b/apps/language_models/langchain/gpt_langchain.py
@@ -372,7 +372,7 @@ def get_embedding(
     from langchain.embeddings import HuggingFaceEmbeddings
 
     torch_dtype, context_class = get_dtype()
-    model_kwargs = dict(device=args.device)
+    model_kwargs = dict(device="cpu")
     if "instructor" in hf_embedding_model:
         encode_kwargs = {"normalize_embeddings": True}
         embedding = HuggingFaceInstructEmbeddings(
diff --git a/apps/language_models/langchain/langchain_requirements.txt b/apps/language_models/langchain/langchain_requirements.txt
index 78bd6e7562..fe1ddcff45 100644
--- a/apps/language_models/langchain/langchain_requirements.txt
+++ b/apps/language_models/langchain/langchain_requirements.txt
@@ -16,7 +16,7 @@ pandas==2.0.2
 matplotlib==3.7.1
 loralib==0.1.1
 bitsandbytes==0.39.0
-accelerate==0.20.3
+# accelerate==0.20.3
 peft==0.4.0
 # 4.31.0+ breaks load_in_8bit=True (https://github.com/huggingface/transformers/issues/25026)
 transformers==4.30.2
diff --git a/apps/stable_diffusion/web/index.py b/apps/stable_diffusion/web/index.py
index 9f15a380e9..4756d22859 100644
--- a/apps/stable_diffusion/web/index.py
+++ b/apps/stable_diffusion/web/index.py
@@ -115,8 +115,8 @@ def resource_path(relative_path):
     txt2img_sendto_inpaint,
     txt2img_sendto_outpaint,
     txt2img_sendto_upscaler,
-    # h2ogpt_upload,
-    # h2ogpt_web,
+    h2ogpt_upload,
+    h2ogpt_web,
     img2img_web,
     img2img_custom_model,
     img2img_hf_model_id,
@@ -248,10 +248,10 @@ def register_outputgallery_button(button, selectedid, inputs, outputs):
                 upscaler_status,
             ]
         )
-        # with gr.TabItem(label="DocuChat Upload", id=11):
-        #     h2ogpt_upload.render()
-        # with gr.TabItem(label="DocuChat(Experimental)", id=12):
-        #     h2ogpt_web.render()
+        with gr.TabItem(label="DocuChat Upload", id=11):
+            h2ogpt_upload.render()
+        with gr.TabItem(label="DocuChat(Experimental)", id=12):
+            h2ogpt_web.render()
 
         # send to buttons
         register_button_click(
diff --git a/apps/stable_diffusion/web/ui/__init__.py b/apps/stable_diffusion/web/ui/__init__.py
index f2c266bac1..08c42273df 100644
--- a/apps/stable_diffusion/web/ui/__init__.py
+++ b/apps/stable_diffusion/web/ui/__init__.py
@@ -79,6 +79,10 @@
     llm_chat_api,
 )
 from apps.stable_diffusion.web.ui.generate_config import model_config_web
+from apps.stable_diffusion.web.ui.h2ogpt import (
+    h2ogpt_upload,
+    h2ogpt_web,
+)
 from apps.stable_diffusion.web.ui.minigpt4_ui import minigpt4_web
 from apps.stable_diffusion.web.ui.outputgallery_ui import (
     outputgallery_web,
diff --git a/apps/stable_diffusion/web/ui/h2ogpt.py b/apps/stable_diffusion/web/ui/h2ogpt.py
index e39b4134b7..50b402b1d5 100644
--- a/apps/stable_diffusion/web/ui/h2ogpt.py
+++ b/apps/stable_diffusion/web/ui/h2ogpt.py
@@ -274,9 +274,7 @@ def chat(curr_system_message, history, device, precision):
 upload_path = None
 database = None
-database_directory = os.path.abspath(
-    "apps/language_models/langchain/db_path/"
-)
+database_directory = "db_dir_UserData"
 
 
 def read_path():
     global upload_path
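
Reviewer note, not part of the patch: the net effect in `get_config()` is that the config is now loaded eagerly through `transformers` alone, with no `accelerate` / `init_empty_weights` context (matching the commented-out `accelerate` pin in `langchain_requirements.txt`). A minimal sketch of the revised control flow, mirroring the `+` lines above; the model name and flag values are illustrative, not defaults from this repo:

```python
# Sketch only: get_config() now builds the AutoConfig directly, without
# accelerate's init_empty_weights context. Values below are illustrative.
from transformers import AutoConfig

base_model = "mosaicml/mpt-7b-chat"
triton_attn = False
long_sequence = True

try:
    config = AutoConfig.from_pretrained(base_model, trust_remote_code=True)
except OSError:
    # Unknown or remote-only model: the patched code returns (None, None)
    # here, so an HF TGI server only needs prompt_type, not HF model state.
    config = None

if config is not None:
    if triton_attn and "mpt-" in base_model.lower():
        config.attn_config["attn_impl"] = "triton"
    if long_sequence and "mosaicml/mpt-7b-chat" in base_model.lower():
        config.update({"max_seq_len": 4096})
```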