Skip to content

Commit

Permalink
[MiniGPT4] Add MiniGPT4 to SHARK
Browse files Browse the repository at this point in the history
-- This is the first installment of MiniGPT4 in SHARK.

Signed-off-by: Abhishek Varma <[email protected]>
  • Loading branch information
Abhishek-Varma committed Jul 17, 2023
1 parent c471d17 commit 5ac977d
Show file tree
Hide file tree
Showing 17 changed files with 3,898 additions and 1 deletion.
433 changes: 433 additions & 0 deletions apps/language_models/src/model_wrappers/minigpt4.py

Large diffs are not rendered by default.

1,157 changes: 1,157 additions & 0 deletions apps/language_models/src/pipelines/minigpt4_pipeline.py

Large diffs are not rendered by default.

1,297 changes: 1,297 additions & 0 deletions apps/language_models/src/pipelines/minigpt4_utils/Qformer.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from omegaconf import OmegaConf
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode


class BaseProcessor:
    """Minimal processor whose transform hands its input back unchanged.

    The callable stored in ``self.transform`` is what ``__call__`` applies;
    a subclass may replace it and override :meth:`from_config`.
    """

    def __init__(self):
        # Identity transform: the default processor performs no work.
        self.transform = lambda item: item

    def __call__(self, item):
        """Apply ``self.transform`` to ``item`` and return the result."""
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        """Create an instance; ``cfg`` is accepted but not read here."""
        return cls()

    def build(self, **kwargs):
        """Wrap ``kwargs`` in an OmegaConf config and pass it to ``from_config``."""
        return self.from_config(OmegaConf.create(kwargs))


class BlipImageBaseProcessor(BaseProcessor):
    """Base image processor that prepares a per-channel normalisation step.

    Args:
        mean: Per-channel means handed to ``transforms.Normalize``. When
            ``None``, the built-in three-value default below is used.
        std: Per-channel standard deviations handed to
            ``transforms.Normalize``. When ``None``, the built-in
            three-value default below is used.

    Attributes set:
        normalize: The ``transforms.Normalize(mean, std)`` instance.
    """

    def __init__(self, mean=None, std=None):
        # Run the base initialiser so ``self.transform`` always exists;
        # otherwise the inherited ``__call__`` fails with AttributeError on
        # an instance of this class.
        super().__init__()

        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)

        self.normalize = transforms.Normalize(mean, std)


class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
    """Image processor: bicubic square resize, tensor conversion, normalise.

    Args:
        image_size: Edge length used for both dimensions of the resize.
        mean: Forwarded to the parent initialiser.
        std: Forwarded to the parent initialiser.
    """

    def __init__(self, image_size=224, mean=None, std=None):
        super().__init__(mean=mean, std=std)

        resize = transforms.Resize(
            (image_size, image_size),
            interpolation=InterpolationMode.BICUBIC,
        )
        # Steps run in list order: resize, convert to tensor, then normalise
        # with the step prepared by the parent initialiser.
        steps = [resize, transforms.ToTensor(), self.normalize]
        self.transform = transforms.Compose(steps)

    def __call__(self, item):
        """Run ``item`` through the composed transform and return the result."""
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from ``cfg``.

        Reads the keys ``"image_size"`` (default 224), ``"mean"`` and
        ``"std"`` (both default ``None``). An empty OmegaConf config is used
        when ``cfg`` is ``None``.
        """
        cfg = OmegaConf.create() if cfg is None else cfg

        return cls(
            image_size=cfg.get("image_size", 224),
            mean=cfg.get("mean", None),
            std=cfg.get("std", None),
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Dataset entry for cc_sbu_align.
datasets:
cc_sbu_align:
data_type: images
build_info:
# NOTE(review): looks like a placeholder path — presumably it must be
# replaced with the real dataset location; confirm.
storage: /path/to/cc_sbu_align/
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Default settings for the mini_gpt4 architecture plus the pre-processor
# names used for images and text.
model:
arch: mini_gpt4

# ViT image encoder settings
image_size: 224
drop_path_rate: 0
use_grad_checkpoint: False
vit_precision: "fp16"
freeze_vit: True
freeze_qformer: True

# Q-Former: number of query tokens
num_query_token: 32

# Vicuna language model identifier
# (presumably a Hugging Face Hub model id — TODO confirm)
llama_model: "lmsys/vicuna-7b-v1.3"

# generation configs
prompt: ""

# Processor names per split; image_size here matches model.image_size above.
preprocess:
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
eval:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Model, dataset-processor, and run settings for the mini_gpt4 architecture
# with the pretrain_vicuna model type.
model:
arch: mini_gpt4
model_type: pretrain_vicuna
freeze_vit: True
freeze_qformer: True
max_txt_len: 160
end_sym: "###"
low_resource: False
prompt_path: "apps/language_models/src/pipelines/minigpt4_utils/prompts/alignment.txt"
# The "{}" placeholder is presumably filled with the user prompt — confirm
# against the code that formats this template.
prompt_template: '###Human: {} ###Assistant: '
# NOTE(review): "prerained" looks like a misspelling of "pretrained", but this
# value is a file name; it must match the checkpoint file as it exists on
# disk, so do not change it without confirming the real file name.
ckpt: 'prerained_minigpt4_7b.pth'


datasets:
cc_sbu_align:
vis_processor:
# NOTE(review): the train split is configured with the eval image processor
# name — confirm this is intentional.
train:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"

run:
task: image_text_pretrain
53 changes: 53 additions & 0 deletions apps/language_models/src/pipelines/minigpt4_utils/dist_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import os

import torch
import torch.distributed as dist
import timm.models.hub as timm_hub


def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is both available and initialised."""
    # is_initialized() is consulted only after is_available() succeeds.
    if dist.is_available() and dist.is_initialized():
        return True
    return False


def get_rank():
    """Return this process's torch.distributed rank, or 0 when not distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0


def is_main_process():
    """Return True when the current process has rank 0."""
    rank = get_rank()
    return rank == 0


def download_cached_file(url, check_hash=True, progress=False):
    """
    Fetch ``url`` through timm's hub cache and return the local file path.

    Only the main process (rank 0) invokes timm's downloader. When
    torch.distributed is initialised, every process then meets at a barrier
    before the cached path is computed and returned.

    Args:
        url: Location of the file to fetch.
        check_hash: Forwarded positionally to ``timm_hub.download_cached_file``.
        progress: Forwarded positionally to ``timm_hub.download_cached_file``.

    Returns:
        timm's cache directory joined with the basename of the URL's path
        component.
    """
    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    if is_dist_avail_and_initialized():
        dist.barrier()

    # The path is rebuilt from the URL instead of using the downloader's
    # return value, so every process derives the same path on its own.
    filename = os.path.basename(torch.hub.urlparse(url).path)
    return os.path.join(timm_hub.get_cache_dir(), filename)
Loading

0 comments on commit 5ac977d

Please sign in to comment.