This is the first installment of MiniGPT4 in SHARK.

Signed-off-by: Abhishek Varma <[email protected]>
Commit 14a2d31 (1 parent: 3a24cff)
Showing 15 changed files with 4,102 additions and 1 deletion.
apps/language_models/src/pipelines/minigpt4_pipeline.py (1,188 additions & 0 deletions)
Large diff not rendered by default.
apps/language_models/src/pipelines/minigpt4_utils/Qformer.py (1,308 additions & 0 deletions)
Large diff not rendered by default.
apps/language_models/src/pipelines/minigpt4_utils/blip_processors.py (156 additions & 0 deletions)
@@ -0,0 +1,156 @@
""" | ||
Copyright (c) 2022, salesforce.com, inc. | ||
All rights reserved. | ||
SPDX-License-Identifier: BSD-3-Clause | ||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause | ||
""" | ||
|
||
import re | ||
|
||
from omegaconf import OmegaConf | ||
from torchvision import transforms | ||
from torchvision.transforms.functional import InterpolationMode | ||
|
||
|
||
class BaseProcessor: | ||
def __init__(self): | ||
self.transform = lambda x: x | ||
return | ||
|
||
def __call__(self, item): | ||
return self.transform(item) | ||
|
||
@classmethod | ||
def from_config(cls, cfg=None): | ||
return cls() | ||
|
||
def build(self, **kwargs): | ||
cfg = OmegaConf.create(kwargs) | ||
|
||
return self.from_config(cfg) | ||
|
||
|
||
class BlipImageBaseProcessor(BaseProcessor): | ||
def __init__(self, mean=None, std=None): | ||
if mean is None: | ||
mean = (0.48145466, 0.4578275, 0.40821073) | ||
if std is None: | ||
std = (0.26862954, 0.26130258, 0.27577711) | ||
|
||
self.normalize = transforms.Normalize(mean, std) | ||
|
||
|
||
class BlipCaptionProcessor(BaseProcessor): | ||
def __init__(self, prompt="", max_words=50): | ||
self.prompt = prompt | ||
self.max_words = max_words | ||
|
||
def __call__(self, caption): | ||
caption = self.prompt + self.pre_caption(caption) | ||
|
||
return caption | ||
|
||
@classmethod | ||
def from_config(cls, cfg=None): | ||
if cfg is None: | ||
cfg = OmegaConf.create() | ||
|
||
prompt = cfg.get("prompt", "") | ||
max_words = cfg.get("max_words", 50) | ||
|
||
return cls(prompt=prompt, max_words=max_words) | ||
|
||
def pre_caption(self, caption): | ||
caption = re.sub( | ||
r"([.!\"()*#:;~])", | ||
" ", | ||
caption.lower(), | ||
) | ||
caption = re.sub( | ||
r"\s{2,}", | ||
" ", | ||
caption, | ||
) | ||
caption = caption.rstrip("\n") | ||
caption = caption.strip(" ") | ||
|
||
# truncate caption | ||
caption_words = caption.split(" ") | ||
if len(caption_words) > self.max_words: | ||
caption = " ".join(caption_words[: self.max_words]) | ||
|
||
return caption | ||
|
||
|
||
class Blip2ImageTrainProcessor(BlipImageBaseProcessor): | ||
def __init__( | ||
self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0 | ||
): | ||
super().__init__(mean=mean, std=std) | ||
|
||
self.transform = transforms.Compose( | ||
[ | ||
transforms.RandomResizedCrop( | ||
image_size, | ||
scale=(min_scale, max_scale), | ||
interpolation=InterpolationMode.BICUBIC, | ||
), | ||
transforms.ToTensor(), | ||
self.normalize, | ||
] | ||
) | ||
|
||
def __call__(self, item): | ||
return self.transform(item) | ||
|
||
@classmethod | ||
def from_config(cls, cfg=None): | ||
if cfg is None: | ||
cfg = OmegaConf.create() | ||
|
||
image_size = cfg.get("image_size", 224) | ||
|
||
mean = cfg.get("mean", None) | ||
std = cfg.get("std", None) | ||
|
||
min_scale = cfg.get("min_scale", 0.5) | ||
max_scale = cfg.get("max_scale", 1.0) | ||
|
||
return cls( | ||
image_size=image_size, | ||
mean=mean, | ||
std=std, | ||
min_scale=min_scale, | ||
max_scale=max_scale, | ||
) | ||
|
||
|
||
class Blip2ImageEvalProcessor(BlipImageBaseProcessor): | ||
def __init__(self, image_size=224, mean=None, std=None): | ||
super().__init__(mean=mean, std=std) | ||
|
||
self.transform = transforms.Compose( | ||
[ | ||
transforms.Resize( | ||
(image_size, image_size), | ||
interpolation=InterpolationMode.BICUBIC, | ||
), | ||
transforms.ToTensor(), | ||
self.normalize, | ||
] | ||
) | ||
|
||
def __call__(self, item): | ||
return self.transform(item) | ||
|
||
@classmethod | ||
def from_config(cls, cfg=None): | ||
if cfg is None: | ||
cfg = OmegaConf.create() | ||
|
||
image_size = cfg.get("image_size", 224) | ||
|
||
mean = cfg.get("mean", None) | ||
std = cfg.get("std", None) | ||
|
||
return cls(image_size=image_size, mean=mean, std=std) |
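These processors are plain classes built from OmegaConf nodes, with no registry involved. A minimal usage sketch (not part of this commit; the image file name and caption text are made up) of the eval image processor and the caption processor:

from PIL import Image
from omegaconf import OmegaConf

# Build the eval-time image processor; missing keys fall back to the defaults
# in from_config (224 px, CLIP mean/std).
vis_processor = Blip2ImageEvalProcessor.from_config(
    OmegaConf.create({"image_size": 224})
)

# Build the caption processor with an empty prompt and a 50-word cap.
txt_processor = BlipCaptionProcessor.from_config(
    OmegaConf.create({"prompt": "", "max_words": 50})
)

image = Image.open("example.jpg").convert("RGB")  # hypothetical input file
image_tensor = vis_processor(image)  # torch.Tensor of shape (3, 224, 224)

caption = txt_processor("A DOG running: in the park!!")
# -> "a dog running in the park"  (lowercased, punctuation stripped, word-capped)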
apps/language_models/src/pipelines/minigpt4_utils/configs/cc_sbu_align.yaml (5 additions & 0 deletions)
@@ -0,0 +1,5 @@
datasets:
  cc_sbu_align:
    data_type: images
    build_info:
      storage: /path/to/cc_sbu_align/
apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4.yaml (33 additions & 0 deletions)
@@ -0,0 +1,33 @@
model:
  arch: mini_gpt4

  # vit encoder
  image_size: 224
  drop_path_rate: 0
  use_grad_checkpoint: False
  vit_precision: "fp16"
  freeze_vit: True
  freeze_qformer: True

  # Q-Former
  num_query_token: 32

  # Vicuna
  llama_model: "/home/abhishek/vicuna_weights/vicuna_weights/"

  # generation configs
  prompt: ""

preprocess:
  vis_processor:
    train:
      name: "blip2_image_train"
      image_size: 224
    eval:
      name: "blip2_image_eval"
      image_size: 224
  text_processor:
    train:
      name: "blip_caption"
    eval:
      name: "blip_caption"
apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4_eval.yaml (25 additions & 0 deletions)
@@ -0,0 +1,25 @@
model:
  arch: mini_gpt4
  model_type: pretrain_vicuna
  freeze_vit: True
  freeze_qformer: True
  max_txt_len: 160
  end_sym: "###"
  low_resource: False
  prompt_path: "apps/language_models/src/pipelines/minigpt4_utils/prompts/alignment.txt"
  prompt_template: '###Human: {} ###Assistant: '
  ckpt: '/home/abhishek/prerained_minigpt4_7b.pth'


datasets:
  cc_sbu_align:
    vis_processor:
      train:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"

run:
  task: image_text_pretrain
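The eval config mostly supplies the conversation format: prompt_path points at the alignment prompts and prompt_template wraps each user turn. A hedged sketch (not part of this commit; the user turn and the <Img><ImageHere></Img> placeholder are illustrative) of how such a template is typically filled:

prompt_template = "###Human: {} ###Assistant: "
end_sym = "###"

user_turn = "<Img><ImageHere></Img> Describe this image in detail."
prompt = prompt_template.format(user_turn)
# '###Human: <Img><ImageHere></Img> Describe this image in detail. ###Assistant: '
# A typical decoding loop stops once the model emits end_sym ("###").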
apps/language_models/src/pipelines/minigpt4_utils/dist_utils.py (137 additions & 0 deletions)
@@ -0,0 +1,137 @@
""" | ||
Copyright (c) 2022, salesforce.com, inc. | ||
All rights reserved. | ||
SPDX-License-Identifier: BSD-3-Clause | ||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause | ||
""" | ||
|
||
import datetime | ||
import functools | ||
import os | ||
|
||
import torch | ||
import torch.distributed as dist | ||
import timm.models.hub as timm_hub | ||
|
||
|
||
def setup_for_distributed(is_master): | ||
""" | ||
This function disables printing when not in master process | ||
""" | ||
import builtins as __builtin__ | ||
|
||
builtin_print = __builtin__.print | ||
|
||
def print(*args, **kwargs): | ||
force = kwargs.pop("force", False) | ||
if is_master or force: | ||
builtin_print(*args, **kwargs) | ||
|
||
__builtin__.print = print | ||
|
||
|
||
def is_dist_avail_and_initialized(): | ||
if not dist.is_available(): | ||
return False | ||
if not dist.is_initialized(): | ||
return False | ||
return True | ||
|
||
|
||
def get_world_size(): | ||
if not is_dist_avail_and_initialized(): | ||
return 1 | ||
return dist.get_world_size() | ||
|
||
|
||
def get_rank(): | ||
if not is_dist_avail_and_initialized(): | ||
return 0 | ||
return dist.get_rank() | ||
|
||
|
||
def is_main_process(): | ||
return get_rank() == 0 | ||
|
||
|
||
def init_distributed_mode(args): | ||
if "RANK" in os.environ and "WORLD_SIZE" in os.environ: | ||
args.rank = int(os.environ["RANK"]) | ||
args.world_size = int(os.environ["WORLD_SIZE"]) | ||
args.gpu = int(os.environ["LOCAL_RANK"]) | ||
elif "SLURM_PROCID" in os.environ: | ||
args.rank = int(os.environ["SLURM_PROCID"]) | ||
args.gpu = args.rank % torch.cuda.device_count() | ||
else: | ||
print("Not using distributed mode") | ||
args.distributed = False | ||
return | ||
|
||
args.distributed = True | ||
|
||
torch.cuda.set_device(args.gpu) | ||
args.dist_backend = "nccl" | ||
print( | ||
"| distributed init (rank {}, world {}): {}".format( | ||
args.rank, args.world_size, args.dist_url | ||
), | ||
flush=True, | ||
) | ||
torch.distributed.init_process_group( | ||
backend=args.dist_backend, | ||
init_method=args.dist_url, | ||
world_size=args.world_size, | ||
rank=args.rank, | ||
timeout=datetime.timedelta( | ||
days=365 | ||
), # allow auto-downloading and de-compressing | ||
) | ||
torch.distributed.barrier() | ||
setup_for_distributed(args.rank == 0) | ||
|
||
|
||
def get_dist_info(): | ||
if torch.__version__ < "1.0": | ||
initialized = dist._initialized | ||
else: | ||
initialized = dist.is_initialized() | ||
if initialized: | ||
rank = dist.get_rank() | ||
world_size = dist.get_world_size() | ||
else: # non-distributed training | ||
rank = 0 | ||
world_size = 1 | ||
return rank, world_size | ||
|
||
|
||
def main_process(func): | ||
@functools.wraps(func) | ||
def wrapper(*args, **kwargs): | ||
rank, _ = get_dist_info() | ||
if rank == 0: | ||
return func(*args, **kwargs) | ||
|
||
return wrapper | ||
|
||
|
||
def download_cached_file(url, check_hash=True, progress=False): | ||
""" | ||
Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again. | ||
If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded. | ||
""" | ||
|
||
def get_cached_file_path(): | ||
# a hack to sync the file path across processes | ||
parts = torch.hub.urlparse(url) | ||
filename = os.path.basename(parts.path) | ||
cached_file = os.path.join(timm_hub.get_cache_dir(), filename) | ||
|
||
return cached_file | ||
|
||
if is_main_process(): | ||
timm_hub.download_cached_file(url, check_hash, progress) | ||
|
||
if is_dist_avail_and_initialized(): | ||
dist.barrier() | ||
|
||
return get_cached_file_path() |
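In a single-process run these helpers degrade gracefully: get_rank() returns 0, get_world_size() returns 1, and the barrier calls are skipped. A hedged sketch (not part of this commit; the checkpoint URL and the save_checkpoint function are hypothetical) of the main_process decorator and the cached download:

# Runs only on rank 0; on other ranks the wrapper returns None.
@main_process
def save_checkpoint(state, path):
    torch.save(state, path)

print(get_rank(), get_world_size())  # 0 1 when torch.distributed is not initialized

# Downloads once (on the main process), then every process resolves the same
# local path inside timm's hub cache directory.
ckpt_path = download_cached_file(
    "https://example.com/prerained_minigpt4_7b.pth",  # hypothetical URL
    check_hash=False,
    progress=True,
)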