Skip to content

Commit

Permalink
model info update
Browse files Browse the repository at this point in the history
  • Loading branch information
jankinf committed Jun 24, 2024
1 parent 3d7025d commit c83a588
Show file tree
Hide file tree
Showing 13 changed files with 51 additions and 53 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ celerybeat.pid
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
model:
model_name: 'llava-v1.5-7b'
model_path: './playground/model_weights/LVIS-Instruct4V-Nodetail-mix619k-7b'
# model_path: './playground/model_weights/LVIS-Instruct4V-Nodetail-mix619k-7b'
model_path: 'X2FD/LVIS-Instruct4V-Nodetail-mix619k-7b'

parameters:
sep: ','
Expand Down
3 changes: 2 additions & 1 deletion mmte/configs/models/llava/llava-1.5-13b.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
model:
model_name: 'llava-v1.5-13b'
model_path: './playground/model_weights/llava-v1.5-13b'
# model_path: './playground/model_weights/llava-v1.5-13b'
model_path: 'liuhaotian/llava-v1.5-13b'

parameters:
sep: ','
Expand Down
3 changes: 2 additions & 1 deletion mmte/configs/models/llava/llava-1.5-7b.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
model:
model_name: 'llava-v1.5-7b'
model_path: './playground/model_weights/llava-v1.5-7b'
# model_path: './playground/model_weights/llava-v1.5-7b'
model_path: 'liuhaotian/llava-v1.5-7b'

parameters:
sep: ','
Expand Down
4 changes: 2 additions & 2 deletions mmte/configs/models/llava/llava-rlhf-13b.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
model:
model_name: 'llava-rlhf-13b'
model_path: './playground/model_weights/llava-rlhf-13b/sft_model'
lora_path: "./playground/model_weights/llava-rlhf-13b/rlhf_lora_adapter_model"
model_path : 'zhiqings/LLaVA-RLHF-13b-v1.5-336'

parameters:
sep: ','
temperature: 0
Expand Down
1 change: 0 additions & 1 deletion mmte/configs/models/openai/openai.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
api_key: ""
proxy: "127.0.0.1:7890"
max_retries: 10
timeout: 1
21 changes: 12 additions & 9 deletions mmte/evaluators/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def pearson_corr(y_true, y_pred, nan_to_num: Optional[Union[float, int]] = None)
corr = np.nan
return corr


def failure(y_true, y_pred, fails_num: Optional[Union[float, int]] = np.nan):
# Calculate the proportion of occurrences of fails_num in the y_pred sequence.
x = np.array(y_pred, dtype=np.float32)
Expand All @@ -39,25 +40,26 @@ def failure(y_true, y_pred, fails_num: Optional[Union[float, int]] = np.nan):
failure = (x == fails_num).sum() / x.size
return failure

def iou_judge(self, box1_list, box2_list):
#TODO:目前只针对单个box计算,并不是对所有全部结果的统一计算,需要添加遍历和判断逻辑(iou>0.5则cnt+=1)

def iou_judge(box1_list, box2_list):
cnt = 0
for i in range(len(box1_list)):
x1_min, y1_min, x1_max, y1_max = box1[i]
x2_min, y2_min, x2_max, y2_max = box2[i]
# print(x1_min, y1_min, x1_max, y1_max, x2_min, y2_min, x2_max, y2_max)
box_len = len(box1_list)
for i in range(box_len):
x1_min, y1_min, x1_max, y1_max = box1_list[i]
x2_min, y2_min, x2_max, y2_max = box2_list[i]

x_inter_min = max(x1_min, x2_min)
y_inter_min = max(y1_min, y2_min)
x_inter_max = min(x1_max, x2_max)
y_inter_max = min(y1_max, y2_max)
# print(x_inter_min,y_inter_min,x_inter_max,y_inter_max)

inter_width = max(0, x_inter_max - x_inter_min)
inter_height = max(0, y_inter_max - y_inter_min)
inter_area = inter_width * inter_height
box1_area = (x1_max - x1_min) * (y1_max - y1_min)
box2_area = (x2_max - x2_min) * (y2_max - y2_min)
union_area = box1_area + box2_area - inter_area
# print("area",inter_area, box1_area, box2_area)

if inter_area == 0 or union_area == 0:
iou = 0
else:
Expand All @@ -66,7 +68,8 @@ def iou_judge(self, box1_list, box2_list):
if iou > 0.5:
print("success grounding!")
cnt += 1
grounding_rate = (cnt*1.0/len(box1_list)) *100.0

grounding_rate = (cnt * 1.0 / box_len) * 100.0
return grounding_rate

_supported_metrics = {
Expand Down
4 changes: 0 additions & 4 deletions mmte/evaluators/score_eval.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
import os
import json
import time
from tqdm import tqdm
from typing import Any, Sequence, List, Tuple, Dict
from mmte.evaluators.base import BaseEvaluator
from mmte.utils.registry import registry


@registry.register_evaluator()
class PerspectiveAPIEvaluator(BaseEvaluator):
# https://perspectiveapi.com/
Expand Down
24 changes: 9 additions & 15 deletions mmte/models/llava_rlhf_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,18 @@
BitsAndBytesConfig,
)
from mmte.models.llava.constants import (
IMAGE_TOKEN_INDEX,
DEFAULT_IMAGE_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IM_END_TOKEN,
DEFAULT_IMAGE_PATCH_TOKEN,
)
from mmte.models.llava.conversation import conv_templates, SeparatorStyle
from mmte.models.llava.model.builder import load_pretrained_model
from mmte.models.llava.utils import disable_torch_init
from mmte.models.llava.mm_utils import (
tokenizer_image_token,
get_model_name_from_path,
KeywordsStoppingCriteria,
)
from mmte.models.llava.model import *
from mmte.models.llava.eval.run_llava import chat_model

from PIL import Image
import math
from peft import PeftModel
import time
from glob import glob
from huggingface_hub.constants import HF_HUB_CACHE
from huggingface_hub import snapshot_download

@registry.register_chatmodel()
class LLaVARLHFChat(BaseChat):
"""
Expand All @@ -53,8 +44,11 @@ def __init__(self, model_id: str, device: str="cuda:0"):
print(self.config)

self.model_name = self.config.model.model_name
model_path = self.config.model.model_path
lora_path = self.config.model.lora_path
download_path = self.config.model.model_path
snapshot_download(repo_id=download_path, force_download=False)
model_path = glob("{}/models--zhiqings--LLaVA-RLHF-13b-v1.5-336/snapshots/*/sft_model".format(HF_HUB_CACHE))[0]
lora_path = glob("{}/models--zhiqings--LLaVA-RLHF-13b-v1.5-336/snapshots/*/rlhf_lora_adapter_model".format(HF_HUB_CACHE))[0]

self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
bits = 16
self.dtype = torch.bfloat16
Expand Down
1 change: 0 additions & 1 deletion mmte/models/mplug_owl2_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ class mPLUGOwl2Chat(BaseChat):
Chat class for mPLUG-Owl2 models
"""

# TODO: update model config
MODEL_CONFIG = {
"mplug-owl2-llama2-7b": 'configs/models/mplug-owl2/mplug-owl2-llama2-7b.yaml',
}
Expand Down
1 change: 0 additions & 1 deletion mmte/models/mplug_owl_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class mPLUGOwlChat(BaseChat):
Chat class for mPLUG-Owl models
"""

# TODO: update model config
MODEL_CONFIG = {
"mplug-owl-llama-7b": 'configs/models/mplug-owl/mplug-owl-llama-7b.yaml',
}
Expand Down
1 change: 0 additions & 1 deletion mmte/models/sharegpt4v_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class ShareGPT4VChat(BaseChat):
Chat class for ShareGPT4V models
"""

# TODO: update model config
MODEL_CONFIG = {
"ShareGPT4V-7B": 'configs/models/sharegpt4v/ShareGPT4V-7B.yaml',
"ShareGPT4V-13B": 'configs/models/sharegpt4v/ShareGPT4V-13B.yaml',
Expand Down
37 changes: 22 additions & 15 deletions run_task.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import sys
sys.path.append('/data/zhangyichi/MMTrustEval-dev/chang/')

import yaml
import argparse
import warnings
from pprint import pprint
warnings.filterwarnings("ignore")
from mmte.tasks.base import BaseTask
from mmte.utils.registry import registry
from mmte.evaluators.metrics import _supported_metrics
from mmte.utils.utils import DictAction, merge_config
import argparse
import yaml

def parse_args():
parser = argparse.ArgumentParser()
Expand All @@ -24,35 +24,42 @@ def parse_args():
args = parser.parse_args()
return args

print("models: ", registry.list_chatmodels())
print("datasets: ", registry.list_datasets())
print("methods: ", registry.list_methods())
print("evaluators: ", registry.list_evaluators())
print("metrics: ", list(_supported_metrics.keys()))


if __name__ == '__main__':
'''
# List all available modules:
pprint("models: ", registry.list_chatmodels())
pprint("datasets: ", registry.list_datasets())
pprint("methods: ", registry.list_methods())
pprint("evaluators: ", registry.list_evaluators())
pprint("metrics: ", list(_supported_metrics.keys()))
'''

args = parse_args()
config = args.config

with open(config, 'r') as f:
cfg = yaml.load(f, Loader=yaml.FullLoader)
if hasattr(args, "cfg_options") and args.cfg_options is not None:
cfg = merge_config(cfg, args.cfg_options)
print(config)
print(cfg)

model_id = cfg.get('model_id')
dataset_id = cfg.get('dataset_id')
log_file = cfg.get('log_file')
method_cfg = cfg.get('method_cfg', {})
dataset_cfg = cfg.get('dataset_cfg', {})
generation_kwargs = cfg.get('generation_kwargs', {})
evaluator_seq_cfgs = cfg.get('evaluator_seq_cfgs', [])

if 'max_new_tokens' not in generation_kwargs.keys():
generation_kwargs['max_new_tokens'] = 50
if 'do_sample' not in generation_kwargs.keys():
generation_kwargs['do_sample'] = False


cfg['generation_kwargs'] = generation_kwargs
cfg['config_path'] = config

pprint(cfg, width=150)

runner = BaseTask(dataset_id=dataset_id, model_id=model_id, method_cfg=method_cfg, dataset_cfg=dataset_cfg, generation_kwargs=generation_kwargs, log_file=log_file, evaluator_seq_cfgs=evaluator_seq_cfgs)
runner.pipeline()

0 comments on commit c83a588

Please sign in to comment.