-
Notifications
You must be signed in to change notification settings - Fork 171
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
-- This is the first installment of MiniGPT4 in SHARK. Signed-off-by: Abhishek Varma <[email protected]>
- Loading branch information
1 parent
c471d17
commit aaefacf
Showing
17 changed files
with
3,923 additions
and
1 deletion.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
1,180 changes: 1,180 additions & 0 deletions
1,180
apps/language_models/src/pipelines/minigpt4_pipeline.py
Large diffs are not rendered by default.
Oops, something went wrong.
1,297 changes: 1,297 additions & 0 deletions
1,297
apps/language_models/src/pipelines/minigpt4_utils/Qformer.py
Large diffs are not rendered by default.
Oops, something went wrong.
68 changes: 68 additions & 0 deletions
68
apps/language_models/src/pipelines/minigpt4_utils/blip_processors.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
""" | ||
Copyright (c) 2022, salesforce.com, inc. | ||
All rights reserved. | ||
SPDX-License-Identifier: BSD-3-Clause | ||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause | ||
""" | ||
from omegaconf import OmegaConf | ||
from torchvision import transforms | ||
from torchvision.transforms.functional import InterpolationMode | ||
|
||
|
||
class BaseProcessor:
    """Minimal processor whose default transform returns its input unchanged.

    Subclasses may assign their own callable to ``self.transform``.
    """

    def __init__(self):
        # Identity transform: ``__call__`` hands the item back untouched.
        self.transform = lambda x: x

    def __call__(self, item):
        """Apply ``self.transform`` to ``item`` and return the result."""
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        """Create an instance; ``cfg`` is accepted but not read here."""
        return cls()

    def build(self, **kwargs):
        """Wrap ``kwargs`` in an OmegaConf config and pass it to ``from_config``."""
        return self.from_config(OmegaConf.create(kwargs))
|
||
|
||
class BlipImageBaseProcessor(BaseProcessor):
    """Base image processor that prepares a per-channel ``Normalize`` transform.

    Args:
        mean: Per-channel means passed to ``transforms.Normalize``. When
            ``None``, (0.48145466, 0.4578275, 0.40821073) is used.
        std: Per-channel standard deviations passed to
            ``transforms.Normalize``. When ``None``,
            (0.26862954, 0.26130258, 0.27577711) is used.
    """

    def __init__(self, mean=None, std=None):
        # Initialise the base class so ``self.transform`` always exists.
        # Without this, an instance that never assigns its own transform
        # raises AttributeError when called.
        super().__init__()

        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)

        self.normalize = transforms.Normalize(mean, std)
|
||
|
||
class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
    """Image pipeline: bicubic resize, tensor conversion, then normalisation.

    Args:
        image_size: Side length used for both dimensions of the resize.
        mean: Forwarded to the parent constructor.
        std: Forwarded to the parent constructor.
    """

    def __init__(self, image_size=224, mean=None, std=None):
        super().__init__(mean=mean, std=std)

        # Resize to a square of ``image_size`` using bicubic interpolation.
        resize = transforms.Resize(
            (image_size, image_size),
            interpolation=InterpolationMode.BICUBIC,
        )
        self.transform = transforms.Compose(
            [resize, transforms.ToTensor(), self.normalize]
        )

    def __call__(self, item):
        """Run the composed transform on ``item``."""
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from ``cfg``, using defaults for missing keys.

        Reads ``image_size`` (default 224), ``mean`` and ``std`` (default
        ``None``). An empty OmegaConf config is used when ``cfg`` is ``None``.
        """
        cfg = OmegaConf.create() if cfg is None else cfg
        return cls(
            image_size=cfg.get("image_size", 224),
            mean=cfg.get("mean", None),
            std=cfg.get("std", None),
        )
5 changes: 5 additions & 0 deletions
5
apps/language_models/src/pipelines/minigpt4_utils/configs/cc_sbu_align.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
datasets: | ||
cc_sbu_align: | ||
data_type: images | ||
build_info: | ||
storage: /path/to/cc_sbu_align/ |
33 changes: 33 additions & 0 deletions
33
apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
model: | ||
arch: mini_gpt4 | ||
|
||
# vit encoder | ||
image_size: 224 | ||
drop_path_rate: 0 | ||
use_grad_checkpoint: False | ||
vit_precision: "fp16" | ||
freeze_vit: True | ||
freeze_qformer: True | ||
|
||
# Q-Former | ||
num_query_token: 32 | ||
|
||
# Vicuna | ||
llama_model: "lmsys/vicuna-7b-v1.3" | ||
|
||
# generation configs | ||
prompt: "" | ||
|
||
preprocess: | ||
vis_processor: | ||
train: | ||
name: "blip2_image_train" | ||
image_size: 224 | ||
eval: | ||
name: "blip2_image_eval" | ||
image_size: 224 | ||
text_processor: | ||
train: | ||
name: "blip_caption" | ||
eval: | ||
name: "blip_caption" |
25 changes: 25 additions & 0 deletions
25
apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4_eval.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
model: | ||
arch: mini_gpt4 | ||
model_type: pretrain_vicuna | ||
freeze_vit: True | ||
freeze_qformer: True | ||
max_txt_len: 160 | ||
end_sym: "###" | ||
low_resource: False | ||
prompt_path: "apps/language_models/src/pipelines/minigpt4_utils/prompts/alignment.txt" | ||
prompt_template: '###Human: {} ###Assistant: ' | ||
ckpt: 'prerained_minigpt4_7b.pth' | ||
|
||
|
||
datasets: | ||
cc_sbu_align: | ||
vis_processor: | ||
train: | ||
name: "blip2_image_eval" | ||
image_size: 224 | ||
text_processor: | ||
train: | ||
name: "blip_caption" | ||
|
||
run: | ||
task: image_text_pretrain |
53 changes: 53 additions & 0 deletions
53
apps/language_models/src/pipelines/minigpt4_utils/dist_utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
""" | ||
Copyright (c) 2022, salesforce.com, inc. | ||
All rights reserved. | ||
SPDX-License-Identifier: BSD-3-Clause | ||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause | ||
""" | ||
|
||
import os | ||
|
||
import torch | ||
import torch.distributed as dist | ||
import timm.models.hub as timm_hub | ||
|
||
|
||
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is both available and initialised."""
    # ``and`` short-circuits, so ``is_initialized`` is only queried when the
    # distributed package is available.
    return dist.is_available() and dist.is_initialized()
|
||
|
||
def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
|
||
|
||
def is_main_process():
    """Return True when ``get_rank()`` reports rank 0."""
    rank = get_rank()
    return rank == 0
|
||
|
||
def download_cached_file(url, check_hash=True, progress=False):
    """Download ``url`` through timm's cache helper and return the local path.

    Only the process for which ``is_main_process()`` is true calls
    ``timm_hub.download_cached_file``. When torch.distributed is initialised,
    every process then meets at a barrier before the path is computed.

    Args:
        url: Location of the file to fetch.
        check_hash: Forwarded to ``timm_hub.download_cached_file``.
        progress: Forwarded to ``timm_hub.download_cached_file``.

    Returns:
        The timm cache directory joined with the basename of the URL path.
    """
    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    if is_dist_avail_and_initialized():
        dist.barrier()

    # Rebuild the path from the URL instead of sharing it between processes,
    # so each process derives the same location on its own.
    filename = os.path.basename(torch.hub.urlparse(url).path)
    return os.path.join(timm_hub.get_cache_dir(), filename)
Oops, something went wrong.