update llava dependent folder and upgrade transformers 4.33.2 -> 4.37.2
jankinf authored and Aries-iai committed Jul 8, 2024
2 parents ae09538 + 7c6ee70 commit aca47ef
Showing 50 changed files with 2,814 additions and 2,370 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -186,4 +186,7 @@ playground
log_reload.py

*competition*
apikey_local.yml
apikey_local.yml
llava-old/
mmte/models/llava/serve/
test_env/
2 changes: 1 addition & 1 deletion env/requirements.txt
@@ -35,7 +35,7 @@ ninja==1.11.1.1
torch==2.3.0+cu121
torchvision==0.18.0+cu121
tqdm==4.66.1
transformers==4.33.2
transformers==4.37.2
webdataset==0.2.86
datasets==2.18.0
openai==1.9.0
12 changes: 5 additions & 7 deletions mmte/models/__init__.py
@@ -1,22 +1,20 @@
from mmte.utils.registry import registry
from mmte.models.base import BaseChat, Response
import transformers
from mmte.models.llava_chat import LLaVAChat
from mmte.models.llava_rlhf_chat import LLaVARLHFChat
from mmte.models.mplug_owl2_chat import mPLUGOwl2Chat
from mmte.models.internvl_chat import InternVLChat
from mmte.models.lrv_instruction_chat import LRVInstructionChat
from mmte.models.openai_chat import OpenAIChat
from mmte.models.google_chat import GoogleChat
from mmte.models.claude3_chat import ClaudeChat
from mmte.models.qwen_plus_chat import QwenPlusChat
from mmte.models.minigpt4_chat import MiniGPT4Chat
from mmte.models.instructblip_chat import InstructBLIPChat
if str(transformers.__version__)<'4.37.2':
from mmte.models.llava_chat import LLaVAChat
from mmte.models.llava_rlhf_chat import LLaVARLHFChat
from mmte.models.mplug_owl2_chat import mPLUGOwl2Chat
from mmte.models.internvl_chat import InternVLChat
from mmte.models.qwen_chat import QwenChat
from mmte.models.otter_chat import OtterChat
from mmte.models.mplug_owl_chat import mPLUGOwlChat
from mmte.models.internlm_xcomposer_chat import InternLMXComposerChat
from mmte.models.lrv_instruction_chat import LRVInstructionChat
from mmte.models.sharegpt4v_chat import ShareGPT4VChat
from mmte.models.cogvlm_chat import CogVLMChat
from mmte.models.phi3_chat import Phi3Chat
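A note on the new guard: `str(transformers.__version__) < '4.37.2'` compares the two strings lexicographically, so a hypothetical release such as 4.4.0 would not be treated as older than 4.37.2. A minimal sketch of a numeric comparison instead, assuming the `packaging` package (a transformers dependency) is available:

import transformers
from packaging import version

# Compare parsed release numbers instead of raw strings; "4.4.0" parses as older
# than "4.37.2" even though it sorts after it lexicographically.
if version.parse(transformers.__version__) < version.parse("4.37.2"):
    from mmte.models.llava_chat import LLaVAChat
    from mmte.models.llava_rlhf_chat import LLaVARLHFChat
    from mmte.models.mplug_owl2_chat import mPLUGOwl2Chat
    from mmte.models.internvl_chat import InternVLChat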
135 changes: 75 additions & 60 deletions mmte/models/llava/conversation.py
@@ -1,6 +1,9 @@
import dataclasses
from enum import auto, Enum
from typing import List, Tuple
import base64
from io import BytesIO
from PIL import Image


class SeparatorStyle(Enum):
@@ -68,7 +71,7 @@ def get_prompt(self):
else:
ret += role
elif self.sep_style == SeparatorStyle.LLAMA_2:
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
ret = ""

@@ -106,79 +109,66 @@ def get_prompt(self):
def append_message(self, role, message):
self.messages.append([role, message])

def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
if image_process_mode == "Pad":
def expand2square(pil_img, background_color=(122, 116, 104)):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image)
elif image_process_mode in ["Default", "Crop"]:
pass
elif image_process_mode == "Resize":
image = image.resize((336, 336))
else:
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
if max(image.size) > max_len:
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
if return_pil:
return image
else:
buffered = BytesIO()
image.save(buffered, format=image_format)
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
return img_b64_str

def get_images(self, return_pil=False):
images = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO
from PIL import Image
msg, image, image_process_mode = msg
if image_process_mode == "Pad":
def expand2square(pil_img, background_color=(122, 116, 104)):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image)
elif image_process_mode in ["Default", "Crop"]:
pass
elif image_process_mode == "Resize":
image = image.resize((336, 336))
else:
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if longest_edge != max(image.size):
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
if return_pil:
images.append(image)
else:
buffered = BytesIO()
image.save(buffered, format="PNG")
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
images.append(img_b64_str)
image = self.process_image(image, image_process_mode, return_pil=return_pil)
images.append(image)
return images
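The refactor above folds the duplicated pad/resize/encode logic from `get_images` and `to_gradio_chatbot` into the new `process_image` helper. A minimal usage sketch, assuming Pillow is installed; the image path is hypothetical:

from PIL import Image
from mmte.models.llava.conversation import default_conversation as conv

img = Image.open("example.jpg").convert("RGB")  # hypothetical local file

# Pad to a square and keep the PIL image (e.g. for the vision encoder).
padded = conv.process_image(img, "Pad", return_pil=True)

# Or resize to 336x336 and get a base64-encoded PNG string for a web UI.
b64_png = conv.process_image(img, "Resize", return_pil=False, image_format="PNG")
print(padded.size, len(b64_png))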

def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO
msg, image, image_process_mode = msg
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
img_b64_str = self.process_image(
image, "Default", return_pil=False,
image_format='JPEG')
img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
msg = img_str + msg.replace('<image>', '').strip()
ret.append([msg, None])
else:
@@ -357,13 +347,38 @@ def dict(self):
version="v1_mmtag",
)

conv_mistral_instruct = Conversation(
system="",
roles=("USER", "ASSISTANT"),
version="llama_v2",
messages=(),
offset=0,
sep_style=SeparatorStyle.LLAMA_2,
sep="",
sep2="</s>",
)

conv_chatml_direct = Conversation(
system="""<|im_start|>system
Answer the questions.""",
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
version="mpt",
messages=(),
offset=0,
sep_style=SeparatorStyle.MPT,
sep="<|im_end|>",
)

default_conversation = conv_vicuna_v1
conv_templates = {
"default": conv_vicuna_v0,
"v0": conv_vicuna_v0,
"v1": conv_vicuna_v1,
"vicuna_v1": conv_vicuna_v1,
"llama_2": conv_llama_2,
"mistral_instruct": conv_mistral_instruct,
"chatml_direct": conv_chatml_direct,
"mistral_direct": conv_chatml_direct,

"plain": conv_llava_plain,
"v0_plain": conv_llava_plain,
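The two new templates are registered in the `conv_templates` dict above under the keys `mistral_instruct`, `chatml_direct`, and `mistral_direct`. A short sketch of selecting one and building a prompt; it assumes the `Conversation.copy()` helper from the upstream LLaVA class, which this excerpt does not show:

from mmte.models.llava.conversation import conv_templates

# "chatml_direct" and "mistral_direct" map to the same ChatML-style template.
conv = conv_templates["chatml_direct"].copy()  # copy() assumed from upstream LLaVA
conv.append_message(conv.roles[0], "Describe the image.\n<image>")
conv.append_message(conv.roles[1], None)  # leave the assistant slot open for generation
print(conv.get_prompt())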
113 changes: 113 additions & 0 deletions mmte/models/llava/eval/eval_gpt_review.py
@@ -0,0 +1,113 @@
import argparse
import json
import os

import openai
import tqdm
import ray
import time

NUM_SECONDS_TO_SLEEP = 3

@ray.remote(num_cpus=4)
def get_eval(content: str, max_tokens: int):
while True:
try:
response = openai.ChatCompletion.create(
model='gpt-4',
messages=[{
'role': 'system',
'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
}, {
'role': 'user',
'content': content,
}],
temperature=0.2, # TODO: figure out which temperature is best for evaluation
max_tokens=max_tokens,
)
break
except openai.error.RateLimitError:
pass
except Exception as e:
print(e)
time.sleep(NUM_SECONDS_TO_SLEEP)

print('success!')
return response['choices'][0]['message']['content']


def parse_score(review):
try:
score_pair = review.split('\n')[0]
score_pair = score_pair.replace(',', ' ')
sp = score_pair.split(' ')
if len(sp) == 2:
return [float(sp[0]), float(sp[1])]
else:
print('error', review)
return [-1, -1]
except Exception as e:
print(e)
print('error', review)
return [-1, -1]


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
parser.add_argument('-q', '--question')
# parser.add_argument('-a', '--answer')
parser.add_argument('-a', '--answer-list', nargs='+', default=[])
parser.add_argument('-r', '--rule')
parser.add_argument('-o', '--output')
parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
args = parser.parse_args()

ray.init()

f_q = open(os.path.expanduser(args.question))
f_ans1 = open(os.path.expanduser(args.answer_list[0]))
f_ans2 = open(os.path.expanduser(args.answer_list[1]))
rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))

review_file = open(f'{args.output}', 'w')

js_list = []
handles = []
idx = 0
for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
# if idx == 1:
# break

ques = json.loads(ques_js)
ans1 = json.loads(ans1_js)
ans2 = json.loads(ans2_js)

category = json.loads(ques_js)['category']
if category in rule_dict:
rule = rule_dict[category]
else:
rule = rule_dict['default']
prompt = rule['prompt']
role = rule['role']
content = (f'[Question]\n{ques["text"]}\n\n'
f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
f'[System]\n{prompt}\n\n')
js_list.append({
'id': idx+1,
'question_id': ques['question_id'],
'answer1_id': ans1['answer_id'],
'answer2_id': ans2['answer_id'],
'category': category})
idx += 1
handles.append(get_eval.remote(content, args.max_tokens))
# To avoid the rate limit set by OpenAI
time.sleep(NUM_SECONDS_TO_SLEEP)

reviews = ray.get(handles)
for idx, review in enumerate(reviews):
scores = parse_score(review)
js_list[idx]['content'] = review
js_list[idx]['tuple'] = scores
review_file.write(json.dumps(js_list[idx]) + '\n')
review_file.close()
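One compatibility note on the new script: it calls the legacy `openai.ChatCompletion` / `openai.error` interface, while `env/requirements.txt` in this same commit pins `openai==1.9.0`, where that interface was removed. A hedged sketch of the equivalent `get_eval` against the 1.x client, with the `@ray.remote` decoration omitted for brevity:

import time
from openai import OpenAI, RateLimitError

NUM_SECONDS_TO_SLEEP = 3
client = OpenAI()  # reads OPENAI_API_KEY from the environment

def get_eval(content: str, max_tokens: int) -> str:
    while True:
        try:
            response = client.chat.completions.create(
                model='gpt-4',
                messages=[
                    {'role': 'system',
                     'content': 'You are a helpful and precise assistant for checking the quality of the answer.'},
                    {'role': 'user', 'content': content},
                ],
                temperature=0.2,
                max_tokens=max_tokens,
            )
            return response.choices[0].message.content
        except RateLimitError:
            time.sleep(NUM_SECONDS_TO_SLEEP)  # back off rather than spin on rate limits
        except Exception as e:
            print(e)
            time.sleep(NUM_SECONDS_TO_SLEEP)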
