Skip to content

Commit

Permalink
Support LLaVA 1.5
Browse files Browse the repository at this point in the history
  • Loading branch information
mzr1996 committed Dec 21, 2023
1 parent e95d9ac commit 29f88a7
Show file tree
Hide file tree
Showing 8 changed files with 323 additions and 169 deletions.
30 changes: 6 additions & 24 deletions configs/llava/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,46 +16,28 @@ Instruction tuning large language models (LLMs) using machine-generated instruct

<!-- [TABS-BEGIN] -->

**Prepare the checkpoint**

According to the license of LLaMA, we cannot provide the merged checkpoint directly. Please use the below
script to download and get the merged the checkpoint.

```shell
python tools/model_converters/llava-delta2mmpre.py huggyllama/llama-7b liuhaotian/LLaVA-Lightning-7B-delta-v1-1 ./LLaVA-Lightning-7B-delta-v1-1.pth
```

**Use the model**

```python
import torch
from mmpretrain import get_model, inference_model

model = get_model('llava-7b-v1_caption', pretrained='MERGED_CHECKPOINT_PATH', device='cuda')
out = inference_model(model, 'demo/cat-dog.png')
out = inference_model('llava-7b-v1_caption', 'demo/cat-dog.png', device='cuda')
print(out)
# {'pred_caption': 'In the image, there are two cats sitting on a blanket.'}
```

**Test Command**

Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).

Test:

```shell
python tools/test.py configs/llava/llava-7b-v1_caption.py MERGED_CHECKPOINT_PATH
```

<!-- [TABS-END] -->

## Models and results

### Image Caption on COCO

| Model | Params (M) | BLEU-4 | CIDER | Config | Download |
| :-------------------- | :--------: | :------: | :------: | :------------------------------: | :--------------------: |
| `llava-7b-v1_caption` | 7045.82 | Upcoming | Upcoming | [config](llava-7b-v1_caption.py) | See the above tutorial |
| Model | Params (M) | Config | Download |
| :-------------------- | :--------: | :------------------------------: | :--------------------: |
| `llava-7b-v1_caption` | 7045.82 | [config](llava-7b-v1_caption.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1_liuhaotian_20231025-c9e119b6.pth) |
| `llava-7b-v1.5_caption` | 7062.90 | [config](llava-7b-v1.5_caption.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth) |
| `llava-7b-v1.5_vqa` | 7062.90 | [config](llava-7b-v1.5_vqa.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth) |

## Citation

Expand Down
76 changes: 76 additions & 0 deletions configs/llava/llava-7b-v1.5_caption.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
_base_ = '../_base_/default_runtime.py'

meta_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." # noqa: E501
image_size = 336
prompt_tmpl = f'''{meta_prompt} User: <image>
Describe the image in detail. ASSISTANT:'''

# model settings
model = dict(
type='Llava',
tokenizer=dict(
type='AutoTokenizer', name_or_path='liuhaotian/llava-v1.5-7b'),
vision_encoder=dict(
type='VisionTransformer',
arch='l',
patch_size=14,
img_size=image_size,
pre_norm=True,
norm_cfg=dict(type='LN', eps=1e-5),
layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
final_norm=False,
out_type='raw',
pretrained='https://download.openmmlab.com/mmclassification/v0/clip/'
'vit-large-p14_clip-openai-pre_336px_20231025-fb1315ed.pth',
),
mm_hidden_size=1024,
use_im_patch=False,
use_im_start_end=False,
mm_proj_depth=2,
lang_encoder=dict(
type='AutoModelForCausalLM',
name_or_path='huggyllama/llama-7b',
),
task='caption',
prompt_tmpl=prompt_tmpl,
generation_cfg=dict(num_beams=3, max_new_tokens=50, length_penalty=-1.0),
)

# data settings
data_preprocessor = dict(
type='MultiModalDataPreprocessor',
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(image_size, image_size),
interpolation='bicubic',
backend='pillow'),
dict(type='PackInputs', meta_keys=['image_id']),
]

test_dataloader = dict(
batch_size=8,
num_workers=5,
dataset=dict(
type='COCOCaption',
data_root='data/coco',
ann_file='annotations/coco_karpathy_val.json',
pipeline=test_pipeline,
),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)

test_evaluator = dict(
type='COCOCaption',
ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
)

# schedule settings
test_cfg = dict()
76 changes: 76 additions & 0 deletions configs/llava/llava-7b-v1.5_vqa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
_base_ = '../_base_/default_runtime.py'

meta_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." # noqa: E501
image_size = 336
prompt_tmpl = f'''{meta_prompt} User: <image>
{{question}} ASSISTANT:'''

# model settings
model = dict(
type='Llava',
tokenizer=dict(
type='AutoTokenizer', name_or_path='liuhaotian/llava-v1.5-7b'),
vision_encoder=dict(
type='VisionTransformer',
arch='l',
patch_size=14,
img_size=image_size,
pre_norm=True,
norm_cfg=dict(type='LN', eps=1e-5),
layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
final_norm=False,
out_type='raw',
pretrained='https://download.openmmlab.com/mmclassification/v0/clip/'
'vit-large-p14_clip-openai-pre_336px_20231025-fb1315ed.pth',
),
mm_hidden_size=1024,
use_im_patch=False,
use_im_start_end=False,
mm_proj_depth=2,
lang_encoder=dict(
type='AutoModelForCausalLM',
name_or_path='huggyllama/llama-7b',
),
task='vqa',
prompt_tmpl=prompt_tmpl,
generation_cfg=dict(max_new_tokens=100),
)

# data settings
data_preprocessor = dict(
type='MultiModalDataPreprocessor',
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(image_size, image_size),
interpolation='bicubic',
backend='pillow'),
dict(type='PackInputs', meta_keys=['image_id', 'question']),
]

test_dataloader = dict(
batch_size=8,
num_workers=5,
dataset=dict(
type='COCOCaption',
data_root='data/coco',
ann_file='annotations/coco_karpathy_val.json',
pipeline=test_pipeline,
),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)

test_evaluator = dict(
type='COCOCaption',
ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
)

# schedule settings
test_cfg = dict()
22 changes: 9 additions & 13 deletions configs/llava/llava-7b-v1_caption.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
_base_ = '../_base_/default_runtime.py'

meta_prompt = 'You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab.You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.Follow the instructions carefully and explain your answers in detail.' # noqa: E501
im_patch_token = '<im_patch>'
patch_size = 14
image_size = 224
num_patches = (image_size // patch_size)**2
caption_prompt = ' '.join([
meta_prompt,
'User: a photo of\n',
im_patch_token * num_patches,
'ASSISTANT:',
])
prompt_tmpl = f'''{meta_prompt} User: <im_start><image><im_end>
Describe the image in detail. ASSISTANT:'''

# model settings
model = dict(
Expand All @@ -22,6 +15,7 @@
type='VisionTransformer',
arch='l',
patch_size=14,
img_size=image_size,
pre_norm=True,
norm_cfg=dict(type='LN', eps=1e-5),
layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
Expand All @@ -32,15 +26,17 @@
'vit-large-p14_clip-openai-pre_3rdparty_20230517-95e2af0b.pth'),
),
mm_hidden_size=1024,
use_im_start_end=False,
use_mm_proj=True,
use_im_patch=False,
use_im_start_end=True,
mm_proj_depth=1,
lang_encoder=dict(
type='AutoModelForCausalLM',
name_or_path='huggyllama/llama-7b',
),
task='caption',
prompt_tmpl=caption_prompt,
generation_cfg=dict(num_beams=3, max_new_tokens=20, length_penalty=-2.0),
prompt_tmpl=prompt_tmpl,
# generation_cfg=dict(num_beams=3, max_new_tokens=50, length_penalty=-1.0),
generation_cfg=dict(max_new_tokens=50),
)

# data settings
Expand Down
28 changes: 27 additions & 1 deletion configs/llava/metafile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,31 @@ Models:
Metrics:
BLEU-4: null
CIDER: null
Weights: null
Weights: https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1_liuhaotian_20231025-c9e119b6.pth
Config: configs/llava/llava-7b-v1_caption.py
- Name: llava-7b-v1.5_caption
Metadata:
FLOPs: null
Parameters: 7062900736
In Collection: LLaVA
Results:
- Task: Image Caption
Dataset: COCO
Metrics:
BLEU-4: null
CIDER: null
Weights: https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth
Config: configs/llava/llava-7b-v1.5_caption.py
- Name: llava-7b-v1.5_vqa
Metadata:
FLOPs: null
Parameters: 7062900736
In Collection: LLaVA
Results:
- Task: Visual Question Answering
Dataset: COCO
Metrics:
BLEU-4: null
CIDER: null
Weights: https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth
Config: configs/llava/llava-7b-v1.5_vqa.py
34 changes: 21 additions & 13 deletions mmpretrain/models/multimodal/llava/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ class Llava(BaseModel):
use_im_start_end (bool): Whether to use the im_start and im_end tokens
mm_vision_select_layer (int): The index from vision encoder output.
Defaults to -1.
use_mm_proj (bool): Whether to enable multi-modal projection.
Defaults to True.
mm_proj_depth (int): The number of linear layers for multi-modal
projection. Defaults to 1.
load_lang_pretrained (bool): Whether to load the pretrained model of
language encoder. Defaults to False.
generation_cfg (dict): The extra generation config, accept the keyword
Expand All @@ -51,9 +51,10 @@ def __init__(self,
mm_hidden_size: int,
prompt_tmpl: str,
task: str = 'caption',
use_im_patch: bool = True,
use_im_start_end: bool = False,
mm_vision_select_layer: int = -1,
use_mm_proj: bool = True,
mm_proj_depth: int = 1,
generation_cfg: dict = dict(),
load_lang_pretrained: bool = False,
data_preprocessor: Optional[dict] = None,
Expand All @@ -75,7 +76,8 @@ def __init__(self,
# init tokenizer
self.tokenizer = TOKENIZER.build(tokenizer)
# add Llava special tokens to the tokenizer
self.tokenizer.add_tokens([self.im_patch_token], special_tokens=True)
if use_im_patch:
self.tokenizer.add_tokens([self.im_patch_token], special_tokens=True)
if use_im_start_end:
self.tokenizer.add_tokens([self.im_start_token, self.im_end_token],
special_tokens=True)
Expand Down Expand Up @@ -108,14 +110,12 @@ def __init__(self,
vision_encoder=vision_encoder,
lang_encoder=lang_encoder,
mm_hidden_size=mm_hidden_size,
use_mm_proj=use_mm_proj,
mm_proj_depth=mm_proj_depth,
use_im_start_end=use_im_start_end,
im_start_token=self.tokenizer.convert_tokens_to_ids(
self.im_start_token),
im_end_token=self.tokenizer.convert_tokens_to_ids(
self.im_end_token),
im_patch_token=self.tokenizer.convert_tokens_to_ids(
self.im_patch_token),
mm_vision_select_layer=mm_vision_select_layer)

self.generation_cfg = generation_cfg
Expand Down Expand Up @@ -207,16 +207,24 @@ def preprocess_text(self, data_samples: List[DataSample],
Returns:
List[DataSample]: Return list of data samples.
"""
prompts = []
tokens = []
for sample in data_samples:
final_prompt = self.prompt_tmpl.format(**sample.to_dict())
prompts.append(final_prompt)
prompt = self.prompt_tmpl.format(**sample.to_dict())
input_ids = []
while '<image>' in prompt:
prefix, _, prompt = prompt.partition('<image>')
input_ids.extend(
self.tokenizer(prefix, add_special_tokens=False).input_ids)
input_ids.append(-200)
if prompt:
input_ids.extend(
self.tokenizer(prompt, add_special_tokens=False).input_ids)
tokens.append(dict(input_ids=input_ids))

self.tokenizer.padding_side = 'left'
input_text = self.tokenizer(
prompts,
input_text = self.tokenizer.pad(
tokens,
padding='longest',
truncation=True,
return_tensors='pt',
max_length=2000,
).to(device)
Expand Down
Loading

0 comments on commit 29f88a7

Please sign in to comment.