Skip to content

Commit

Permalink
[MiniGPT4] Add MiniGPT4 to SHARK
Browse files Browse the repository at this point in the history
-- This is the first installment of MiniGPT4 in SHARK.

Signed-off-by: Abhishek Varma <[email protected]>
  • Loading branch information
Abhishek-Varma committed Jul 7, 2023
1 parent 3a24cff commit 14a2d31
Show file tree
Hide file tree
Showing 15 changed files with 4,102 additions and 1 deletion.
433 changes: 433 additions & 0 deletions apps/language_models/src/model_wrappers/minigpt4.py

Large diffs are not rendered by default.

1,188 changes: 1,188 additions & 0 deletions apps/language_models/src/pipelines/minigpt4_pipeline.py

Large diffs are not rendered by default.

1,308 changes: 1,308 additions & 0 deletions apps/language_models/src/pipelines/minigpt4_utils/Qformer.py

Large diffs are not rendered by default.

156 changes: 156 additions & 0 deletions apps/language_models/src/pipelines/minigpt4_utils/blip_processors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import re

from omegaconf import OmegaConf
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode


class BaseProcessor:
def __init__(self):
self.transform = lambda x: x
return

def __call__(self, item):
return self.transform(item)

@classmethod
def from_config(cls, cfg=None):
return cls()

def build(self, **kwargs):
cfg = OmegaConf.create(kwargs)

return self.from_config(cfg)


class BlipImageBaseProcessor(BaseProcessor):
def __init__(self, mean=None, std=None):
if mean is None:
mean = (0.48145466, 0.4578275, 0.40821073)
if std is None:
std = (0.26862954, 0.26130258, 0.27577711)

self.normalize = transforms.Normalize(mean, std)


class BlipCaptionProcessor(BaseProcessor):
def __init__(self, prompt="", max_words=50):
self.prompt = prompt
self.max_words = max_words

def __call__(self, caption):
caption = self.prompt + self.pre_caption(caption)

return caption

@classmethod
def from_config(cls, cfg=None):
if cfg is None:
cfg = OmegaConf.create()

prompt = cfg.get("prompt", "")
max_words = cfg.get("max_words", 50)

return cls(prompt=prompt, max_words=max_words)

def pre_caption(self, caption):
caption = re.sub(
r"([.!\"()*#:;~])",
" ",
caption.lower(),
)
caption = re.sub(
r"\s{2,}",
" ",
caption,
)
caption = caption.rstrip("\n")
caption = caption.strip(" ")

# truncate caption
caption_words = caption.split(" ")
if len(caption_words) > self.max_words:
caption = " ".join(caption_words[: self.max_words])

return caption


class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
def __init__(
self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0
):
super().__init__(mean=mean, std=std)

self.transform = transforms.Compose(
[
transforms.RandomResizedCrop(
image_size,
scale=(min_scale, max_scale),
interpolation=InterpolationMode.BICUBIC,
),
transforms.ToTensor(),
self.normalize,
]
)

def __call__(self, item):
return self.transform(item)

@classmethod
def from_config(cls, cfg=None):
if cfg is None:
cfg = OmegaConf.create()

image_size = cfg.get("image_size", 224)

mean = cfg.get("mean", None)
std = cfg.get("std", None)

min_scale = cfg.get("min_scale", 0.5)
max_scale = cfg.get("max_scale", 1.0)

return cls(
image_size=image_size,
mean=mean,
std=std,
min_scale=min_scale,
max_scale=max_scale,
)


class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
def __init__(self, image_size=224, mean=None, std=None):
super().__init__(mean=mean, std=std)

self.transform = transforms.Compose(
[
transforms.Resize(
(image_size, image_size),
interpolation=InterpolationMode.BICUBIC,
),
transforms.ToTensor(),
self.normalize,
]
)

def __call__(self, item):
return self.transform(item)

@classmethod
def from_config(cls, cfg=None):
if cfg is None:
cfg = OmegaConf.create()

image_size = cfg.get("image_size", 224)

mean = cfg.get("mean", None)
std = cfg.get("std", None)

return cls(image_size=image_size, mean=mean, std=std)
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
datasets:
cc_sbu_align:
data_type: images
build_info:
storage: /path/to/cc_sbu_align/
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
model:
arch: mini_gpt4

# vit encoder
image_size: 224
drop_path_rate: 0
use_grad_checkpoint: False
vit_precision: "fp16"
freeze_vit: True
freeze_qformer: True

# Q-Former
num_query_token: 32

# Vicuna
llama_model: "/home/abhishek/vicuna_weights/vicuna_weights/"

# generation configs
prompt: ""

preprocess:
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
eval:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
model:
arch: mini_gpt4
model_type: pretrain_vicuna
freeze_vit: True
freeze_qformer: True
max_txt_len: 160
end_sym: "###"
low_resource: False
prompt_path: "apps/language_models/src/pipelines/minigpt4_utils/prompts/alignment.txt"
prompt_template: '###Human: {} ###Assistant: '
ckpt: '/home/abhishek/prerained_minigpt4_7b.pth'


datasets:
cc_sbu_align:
vis_processor:
train:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"

run:
task: image_text_pretrain
137 changes: 137 additions & 0 deletions apps/language_models/src/pipelines/minigpt4_utils/dist_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import datetime
import functools
import os

import torch
import torch.distributed as dist
import timm.models.hub as timm_hub


def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__

builtin_print = __builtin__.print

def print(*args, **kwargs):
force = kwargs.pop("force", False)
if is_master or force:
builtin_print(*args, **kwargs)

__builtin__.print = print


def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True


def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()


def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()


def is_main_process():
return get_rank() == 0


def init_distributed_mode(args):
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
elif "SLURM_PROCID" in os.environ:
args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.rank % torch.cuda.device_count()
else:
print("Not using distributed mode")
args.distributed = False
return

args.distributed = True

torch.cuda.set_device(args.gpu)
args.dist_backend = "nccl"
print(
"| distributed init (rank {}, world {}): {}".format(
args.rank, args.world_size, args.dist_url
),
flush=True,
)
torch.distributed.init_process_group(
backend=args.dist_backend,
init_method=args.dist_url,
world_size=args.world_size,
rank=args.rank,
timeout=datetime.timedelta(
days=365
), # allow auto-downloading and de-compressing
)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)


def get_dist_info():
if torch.__version__ < "1.0":
initialized = dist._initialized
else:
initialized = dist.is_initialized()
if initialized:
rank = dist.get_rank()
world_size = dist.get_world_size()
else: # non-distributed training
rank = 0
world_size = 1
return rank, world_size


def main_process(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
rank, _ = get_dist_info()
if rank == 0:
return func(*args, **kwargs)

return wrapper


def download_cached_file(url, check_hash=True, progress=False):
"""
Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
"""

def get_cached_file_path():
# a hack to sync the file path across processes
parts = torch.hub.urlparse(url)
filename = os.path.basename(parts.path)
cached_file = os.path.join(timm_hub.get_cache_dir(), filename)

return cached_file

if is_main_process():
timm_hub.download_cached_file(url, check_hash, progress)

if is_dist_avail_and_initialized():
dist.barrier()

return get_cached_file_path()
Loading

0 comments on commit 14a2d31

Please sign in to comment.