Skip to content

Commit

Permalink
Disable multi-turn conversation for some models (assert single-turn input in their chat methods)
Browse files Browse the repository at this point in the history
  • Loading branch information
jankinf authored and Aries-iai committed Jul 15, 2024
1 parent 730e033 commit 2d0a1d0
Show file tree
Hide file tree
Showing 12 changed files with 21 additions and 24 deletions.
1 change: 1 addition & 0 deletions mmte/models/instructblip_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def chat(self, messages: List, **generation_kwargs):
img_list = []
multimodal = False
# TODO: if system message provided.
assert len(messages) == 1, 'Only support one-turn conversation currently'
for message in messages:
if message["role"] in ["system", "user", "assistant"]:
if message["role"] == "user":
Expand Down
1 change: 1 addition & 0 deletions mmte/models/internlm_xcomposer_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self, model_id: str, device: str="cuda:0"):

@torch.no_grad()
def chat(self, messages: List, **generation_kwargs):
assert len(messages) == 1, 'Only support one-turn conversation currently'
for message in messages:
if message["role"] in ["system", "user", "assistant"]:
if message["role"] == "user":
Expand Down
1 change: 1 addition & 0 deletions mmte/models/internvl_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(self, model_id: str, device: str="cuda:0"):
@torch.no_grad()
def chat(self, messages: List, **generation_kwargs):
# TODO: if system message provided.
assert len(messages) == 1, 'Only support one-turn conversation currently'
for message in messages:
if message["role"] in ["system", "user", "assistant"]:
if message["role"] == "user":
Expand Down
1 change: 1 addition & 0 deletions mmte/models/llava_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(self, model_id: str, device: str="cuda:0"):
@torch.no_grad()
def chat(self, messages: List, **generation_kwargs):
# TODO: if system message provided.
assert len(messages) == 1, 'Only support one-turn conversation currently'
for message in messages:
if message["role"] in ["system", "user", "assistant"]:
if message["role"] == "user":
Expand Down
1 change: 1 addition & 0 deletions mmte/models/llava_rlhf_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __init__(self, model_id: str, device: str="cuda:0"):
@torch.no_grad()
def chat(self, messages: List, **generation_kwargs):
# TODO: if system message provided.
assert len(messages) == 1, 'Only support one-turn conversation currently'
for message in messages:
if message["role"] in ["system", "user", "assistant"]:
if message["role"] == "user":
Expand Down
1 change: 1 addition & 0 deletions mmte/models/lrv_instruction_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def chat(self, messages: List, **generation_kwargs):
conversation = self.CONV_DICT[self.model_id].copy()
img_list = []
# TODO: if system message provided.
assert len(messages) == 1, 'Only support one-turn conversation currently'
for message in messages:
if message["role"] in ["system", "user", "assistant"]:
if message["role"] == "user":
Expand Down
1 change: 1 addition & 0 deletions mmte/models/minigpt4_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def chat(self, messages: List, **generation_kwargs):
conversation = self.CONV_DICT[self.model_id].copy()
img_list = []
# TODO: if system message provided.
assert len(messages) == 1, 'Only support one-turn conversation currently'
for message in messages:
if message["role"] in ["system", "user", "assistant"]:
if message["role"] == "user":
Expand Down
1 change: 1 addition & 0 deletions mmte/models/mplug_owl2_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, model_id: str, device: str="cuda:0"):

@torch.no_grad()
def chat(self, messages: List, **generation_kwargs):
assert len(messages) == 1, 'Only support one-turn conversation currently'
for message in messages:
if message["role"] in ["system", "user", "assistant"]:
if message["role"] == "user":
Expand Down
1 change: 1 addition & 0 deletions mmte/models/mplug_owl_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(self, model_id: str, device: str="cuda:0"):

@torch.no_grad()
def chat(self, messages: List, **generation_kwargs):
assert len(messages) == 1, 'Only support one-turn conversation currently'
for message in messages:
if message["role"] in ["system", "user", "assistant"]:
if message["role"] == "user":
Expand Down
1 change: 1 addition & 0 deletions mmte/models/otter_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(self, model_id: str, device: str="cuda:0"):

@torch.no_grad()
def chat(self, messages: List, **generation_kwargs):
assert len(messages) == 1, 'Only support one-turn conversation currently'
for message in messages:
if message["role"] in ["system", "user", "assistant"]:
if message["role"] == "user":
Expand Down
1 change: 1 addition & 0 deletions mmte/models/qwen_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(self, model_id: str, device: str="cuda:0"):

@torch.no_grad()
def chat(self, messages: List, **generation_kwargs):
assert len(messages) == 1, 'Only support one-turn conversation currently'
for message in messages:
if message["role"] in ["system", "user", "assistant"]:
if message["role"] == "user":
Expand Down
34 changes: 10 additions & 24 deletions mmte/models/sharegpt4v_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def __init__(self, model_id: str, device: str="cuda:0"):
config = self.MODEL_CONFIG[self.model_id]
self.config = OmegaConf.load(get_abs_path(config))
self.device = device
# print(self.config.model.model_path)
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
model_path=self.config.model.model_path,
model_base=None,
Expand All @@ -47,6 +46,7 @@ def __init__(self, model_id: str, device: str="cuda:0"):

@torch.no_grad()
def chat(self, messages: List, **generation_kwargs):
assert len(messages) == 1, 'Only support one-turn conversation currently'
for message in messages:
if message["role"] in ["system", "user", "assistant"]:
if message["role"] == "user":
Expand Down Expand Up @@ -76,12 +76,7 @@ def chat(self, messages: List, **generation_kwargs):
image_tensor = None
user_message = message["content"]
qs = user_message
if self.model.config.mm_use_im_start_end:
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + \
DEFAULT_IM_END_TOKEN + '\n' + qs
else:
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs


conv_mode = "share4v_v0"
conv = conv_templates[conv_mode].copy()
conv.append_message(conv.roles[0], qs)
Expand All @@ -101,23 +96,14 @@ def chat(self, messages: List, **generation_kwargs):
keywords, self.tokenizer, input_ids)

with torch.inference_mode():
if isinstance(message["content"], dict):
output_ids = self.model.generate(
input_ids,
images=image_tensor,
do_sample=generation_kwargs.get("do_sample"),
temperature=0.2,
max_new_tokens=generation_kwargs.get("max_new_tokens"),
use_cache=True,
stopping_criteria=[stopping_criteria])
else:
output_ids = self.model.generate(
input_ids,
do_sample=generation_kwargs.get("do_sample"),
temperature=0.2,
max_new_tokens=generation_kwargs.get("max_new_tokens"),
use_cache=True,
stopping_criteria=[stopping_criteria])
output_ids = self.model.generate(
input_ids,
images=image_tensor,
do_sample=generation_kwargs.get("do_sample"),
temperature=0.2,
max_new_tokens=generation_kwargs.get("max_new_tokens"),
use_cache=True,
stopping_criteria=[stopping_criteria])

input_token_len = input_ids.shape[1]
n_diff_input_output = (
Expand Down

0 comments on commit 2d0a1d0

Please sign in to comment.