Remove accelerate from langchain #1743

Open · wants to merge 3 commits into base: main
97 changes: 39 additions & 58 deletions apps/language_models/langchain/gen.py
@@ -109,51 +109,48 @@ def get_config(
     return_model=False,
     raise_exception=False,
 ):
-    from accelerate import init_empty_weights
+    from transformers import AutoConfig
 
-    with init_empty_weights():
-        from transformers import AutoConfig
-
-        try:
-            config = AutoConfig.from_pretrained(
-                base_model,
-                use_auth_token=use_auth_token,
-                trust_remote_code=trust_remote_code,
-                offload_folder=offload_folder,
-            )
-        except OSError as e:
-            if raise_exception:
-                raise
-            if "not a local folder and is not a valid model identifier listed on" in str(
-                e
-            ) or "404 Client Error" in str(
-                e
-            ):
-                # e.g. llama, gpjt, etc.
-                # e.g. HF TGI but not model on HF or private etc.
-                # HF TGI server only should really require prompt_type, not HF model state
-                return None, None
-            else:
-                raise
-        if triton_attn and "mpt-" in base_model.lower():
-            config.attn_config["attn_impl"] = "triton"
-        if long_sequence:
-            if "mpt-7b-storywriter" in base_model.lower():
-                config.update({"max_seq_len": 83968})
-            if "mosaicml/mpt-7b-chat" in base_model.lower():
-                config.update({"max_seq_len": 4096})
-            if "mpt-30b" in base_model.lower():
-                config.update({"max_seq_len": 2 * 8192})
-        if return_model and issubclass(
-            config.__class__, tuple(AutoModel._model_mapping.keys())
-        ):
-            model = AutoModel.from_config(
-                config,
-                trust_remote_code=trust_remote_code,
-            )
-        else:
-            # can't infer
-            model = None
+    try:
+        config = AutoConfig.from_pretrained(
+            base_model,
+            use_auth_token=use_auth_token,
+            trust_remote_code=trust_remote_code,
+            offload_folder=offload_folder,
+        )
+    except OSError as e:
+        if raise_exception:
+            raise
+        if "not a local folder and is not a valid model identifier listed on" in str(
+            e
+        ) or "404 Client Error" in str(
+            e
+        ):
+            # e.g. llama, gpjt, etc.
+            # e.g. HF TGI but not model on HF or private etc.
+            # HF TGI server only should really require prompt_type, not HF model state
+            return None, None
+        else:
+            raise
+    if triton_attn and "mpt-" in base_model.lower():
+        config.attn_config["attn_impl"] = "triton"
+    if long_sequence:
+        if "mpt-7b-storywriter" in base_model.lower():
+            config.update({"max_seq_len": 83968})
+        if "mosaicml/mpt-7b-chat" in base_model.lower():
+            config.update({"max_seq_len": 4096})
+        if "mpt-30b" in base_model.lower():
+            config.update({"max_seq_len": 2 * 8192})
+    if return_model and issubclass(
+        config.__class__, tuple(AutoModel._model_mapping.keys())
+    ):
+        model = AutoModel.from_config(
+            config,
+            trust_remote_code=trust_remote_code,
+        )
+    else:
+        # can't infer
+        model = None
     if "falcon" in base_model.lower():
         config.use_cache = False
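Note on this hunk: AutoConfig.from_pretrained() only downloads and parses the model's JSON config and allocates no weights, so it never needed the init_empty_weights() wrapper. The trade-off is in AutoModel.from_config(): without accelerate the parameters become real CPU tensors rather than zero-memory "meta" tensors. A minimal sketch of the difference (the "gpt2" id is illustrative, not from this patch):

from transformers import AutoConfig, AutoModel

# Loading a config allocates no model weights, so no
# init_empty_weights() context is needed around it.
config = AutoConfig.from_pretrained("gpt2")

# Only from_config() materializes parameters. Without accelerate they are
# normal CPU tensors; inside init_empty_weights() they would have been
# placed on the zero-memory "meta" device instead.
model = AutoModel.from_config(config)
print(sum(p.numel() for p in model.parameters()))  # roughly 124M for gpt2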

@@ -177,22 +174,6 @@ def get_non_lora_model(
     """
 
     device_map = None
-    if model is not None:
-        # NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model
-        # NOTE: Some models require avoiding sharding some layers,
-        # then would pass no_split_module_classes and give list of those layers.
-        from accelerate import infer_auto_device_map
-
-        device_map = infer_auto_device_map(
-            model,
-            dtype=torch.float16 if load_half else torch.float32,
-        )
-        if hasattr(model, "model"):
-            device_map_model = infer_auto_device_map(
-                model.model,
-                dtype=torch.float16 if load_half else torch.float32,
-            )
-            device_map.update(device_map_model)
 
     n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
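Note: with this block gone, device_map stays None and the model is no longer sharded across devices by accelerate. A sketch of the simplest accelerate-free equivalent, whole-model placement on one device (place_model is a hypothetical helper, not part of the patch):

import torch

def place_model(model: torch.nn.Module, load_half: bool = True) -> torch.nn.Module:
    # accelerate's infer_auto_device_map() returned a per-module map such as
    # {"transformer.h.0": 0, "transformer.h.1": "cpu", ...} to shard large
    # models; with device_map=None the whole model lands on a single device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if load_half else torch.float32
    return model.to(device=device, dtype=dtype)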
2 changes: 1 addition & 1 deletion apps/language_models/langchain/gpt_langchain.py
@@ -372,7 +372,7 @@ def get_embedding(
     from langchain.embeddings import HuggingFaceEmbeddings
 
     torch_dtype, context_class = get_dtype()
-    model_kwargs = dict(device=args.device)
+    model_kwargs = dict(device="cpu")
     if "instructor" in hf_embedding_model:
         encode_kwargs = {"normalize_embeddings": True}
         embedding = HuggingFaceInstructEmbeddings(
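Note: hard-coding device="cpu" drops the args.device indirection for embeddings, so they no longer depend on accelerate-style device handling. A minimal usage sketch of the resulting call (the model name is illustrative):

from langchain.embeddings import HuggingFaceEmbeddings

embedding = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cpu"},  # sentence-transformers handles "cpu" natively
)
vector = embedding.embed_query("hello world")  # plain Python list of floats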
2 changes: 1 addition & 1 deletion apps/language_models/langchain/langchain_requirements.txt
@@ -16,7 +16,7 @@ pandas==2.0.2
 matplotlib==3.7.1
 loralib==0.1.1
 bitsandbytes==0.39.0
-accelerate==0.20.3
+# accelerate==0.20.3
 peft==0.4.0
 # 4.31.0+ breaks load_in_8bit=True (https://github.com/huggingface/transformers/issues/25026)
 transformers==4.30.2
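Note: if any remaining code path still imports accelerate, a guarded import keeps the now commented-out package optional instead of a hard failure at import time. A sketch of the pattern, not part of this patch:

try:
    from accelerate import infer_auto_device_map  # optional, may be absent now

    HAVE_ACCELERATE = True
except ImportError:
    infer_auto_device_map = None
    HAVE_ACCELERATE = False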
12 changes: 6 additions & 6 deletions apps/stable_diffusion/web/index.py
@@ -115,8 +115,8 @@ def resource_path(relative_path):
     txt2img_sendto_inpaint,
     txt2img_sendto_outpaint,
     txt2img_sendto_upscaler,
-    # h2ogpt_upload,
-    # h2ogpt_web,
+    h2ogpt_upload,
+    h2ogpt_web,
     img2img_web,
     img2img_custom_model,
     img2img_hf_model_id,
@@ -248,10 +248,10 @@ def register_outputgallery_button(button, selectedid, inputs, outputs):
                 upscaler_status,
             ]
         )
-        # with gr.TabItem(label="DocuChat Upload", id=11):
-        #     h2ogpt_upload.render()
-        # with gr.TabItem(label="DocuChat(Experimental)", id=12):
-        #     h2ogpt_web.render()
+        with gr.TabItem(label="DocuChat Upload", id=11):
+            h2ogpt_upload.render()
+        with gr.TabItem(label="DocuChat(Experimental)", id=12):
+            h2ogpt_web.render()
 
         # send to buttons
         register_button_click(
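Note: the re-enabled tabs rely on gradio's deferred rendering: a Blocks tree built in another module renders nothing at import time and is mounted with .render() inside the active Blocks context. A minimal sketch of the pattern (names are illustrative):

import gradio as gr

with gr.Blocks() as docuchat_web:  # built in its own module, like h2ogpt_web
    gr.Markdown("DocuChat")

with gr.Blocks() as studio:
    with gr.Tabs():
        with gr.TabItem(label="DocuChat(Experimental)", id=12):
            docuchat_web.render()  # mounts the pre-built UI inside this tab

# studio.launch()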
4 changes: 4 additions & 0 deletions apps/stable_diffusion/web/ui/__init__.py
@@ -79,6 +79,10 @@
     llm_chat_api,
 )
 from apps.stable_diffusion.web.ui.generate_config import model_config_web
+from apps.stable_diffusion.web.ui.h2ogpt import (
+    h2ogpt_upload,
+    h2ogpt_web,
+)
 from apps.stable_diffusion.web.ui.minigpt4_ui import minigpt4_web
 from apps.stable_diffusion.web.ui.outputgallery_ui import (
     outputgallery_web,
4 changes: 1 addition & 3 deletions apps/stable_diffusion/web/ui/h2ogpt.py
@@ -274,9 +274,7 @@ def chat(curr_system_message, history, device, precision):
 
 upload_path = None
 database = None
-database_directory = os.path.abspath(
-    "apps/language_models/langchain/db_path/"
-)
+database_directory = "db_dir_UserData"
 
 def read_path():
     global upload_path
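Note: unlike the old os.path.abspath(...) constant, the bare "db_dir_UserData" resolves against the process's current working directory, so the database lands wherever the app is launched from. A quick check (the printed path is an example):

import os

# Relative paths are resolved against os.getcwd() at use time.
print(os.path.abspath("db_dir_UserData"))
# e.g. /home/user/SHARK/db_dir_UserData when launched from the repo root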