diff --git a/apps/language_models/src/model_wrappers/minigpt4.py b/apps/language_models/src/model_wrappers/minigpt4.py
new file mode 100644
index 0000000000..6521a0a5fe
--- /dev/null
+++ b/apps/language_models/src/model_wrappers/minigpt4.py
@@ -0,0 +1,433 @@
+import torch
+import dataclasses
+from enum import auto, Enum
+from typing import List, Any
+from transformers import StoppingCriteria, StoppingCriteriaList
+
+
+class LayerNorm(torch.nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ ret = super().forward(x.type(torch.float32))
+ return ret.type(orig_type)
+
+
+class VisionModel(torch.nn.Module):
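+ """EVA-ViT visual encoder followed by its LayerNorm, wrapped so the pair
+ can be compiled as a single SHARK module."""
+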
+ def __init__(self, ln_vision, visual_encoder):
+ super().__init__()
+ self.ln_vision = ln_vision
+ self.visual_encoder = visual_encoder
+
+ def forward(self, image):
+ image_embeds = self.ln_vision(self.visual_encoder(image))
+ return image_embeds
+
+
+class QformerBertModel(torch.nn.Module):
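+ """Q-Former BERT encoder wrapper: cross-attends the learned query tokens
+ to the image embeddings and returns the last hidden state."""
+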
+ def __init__(self, qformer_bert):
+ super().__init__()
+ self.qformer_bert = qformer_bert
+
+ def forward(self, query_tokens, image_embeds, image_atts):
+ query_output = self.qformer_bert(
+ query_embeds=query_tokens,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+ return query_output.last_hidden_state
+
+
+class FirstLlamaModel(torch.nn.Module):
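+ """Prefill wrapper around LLaMA: consumes the whole prompt as
+ inputs_embeds with no KV cache and returns the logits followed by every
+ layer's key/value tensors as one flat tuple, so the graph can be traced
+ and compiled."""
+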
+ def __init__(self, model):
+ super().__init__()
+ self.model = model
+ print("SHARK: Loading LLAMA Done")
+
+ def forward(self, inputs_embeds, position_ids, attention_mask):
+ print("************************************")
+ print(
+ "inputs_embeds: ",
+ inputs_embeds.shape,
+ " dtype: ",
+ inputs_embeds.dtype,
+ )
+ print(
+ "position_ids: ",
+ position_ids.shape,
+ " dtype: ",
+ position_ids.dtype,
+ )
+ print(
+ "attention_mask: ",
+ attention_mask.shape,
+ " dtype: ",
+ attention_mask.dtype,
+ )
+ print("************************************")
+ config = {
+ "inputs_embeds": inputs_embeds,
+ "position_ids": position_ids,
+ "past_key_values": None,
+ "use_cache": True,
+ "attention_mask": attention_mask,
+ }
+ output = self.model(
+ **config,
+ return_dict=True,
+ output_attentions=False,
+ output_hidden_states=False,
+ )
+ return_vals = []
+ return_vals.append(output.logits)
+ temp_past_key_values = output.past_key_values
+ for item in temp_past_key_values:
+ return_vals.append(item[0])
+ return_vals.append(item[1])
+ return tuple(return_vals)
+
+
+class SecondLlamaModel(torch.nn.Module):
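+ """Decode-step wrapper around LLaMA: takes one new input id plus the 64
+ cached key/value tensors (32 layers x key/value) as flat arguments,
+ because the compiled module cannot take nested tuples, and returns the
+ new logits together with the updated cache."""
+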
+ def __init__(self, model):
+ super().__init__()
+ self.model = model
+ print("SHARK: Loading LLAMA Done")
+
+ def forward(
+ self,
+ input_ids,
+ position_ids,
+ attention_mask,
+ i1,
+ i2,
+ i3,
+ i4,
+ i5,
+ i6,
+ i7,
+ i8,
+ i9,
+ i10,
+ i11,
+ i12,
+ i13,
+ i14,
+ i15,
+ i16,
+ i17,
+ i18,
+ i19,
+ i20,
+ i21,
+ i22,
+ i23,
+ i24,
+ i25,
+ i26,
+ i27,
+ i28,
+ i29,
+ i30,
+ i31,
+ i32,
+ i33,
+ i34,
+ i35,
+ i36,
+ i37,
+ i38,
+ i39,
+ i40,
+ i41,
+ i42,
+ i43,
+ i44,
+ i45,
+ i46,
+ i47,
+ i48,
+ i49,
+ i50,
+ i51,
+ i52,
+ i53,
+ i54,
+ i55,
+ i56,
+ i57,
+ i58,
+ i59,
+ i60,
+ i61,
+ i62,
+ i63,
+ i64,
+ ):
+ print("************************************")
+ print("input_ids: ", input_ids.shape, " dtype: ", input_ids.dtype)
+ print(
+ "position_ids: ",
+ position_ids.shape,
+ " dtype: ",
+ position_ids.dtype,
+ )
+ print(
+ "attention_mask: ",
+ attention_mask.shape,
+ " dtype: ",
+ attention_mask.dtype,
+ )
+ print("past_key_values: ", i1.shape, i2.shape, i63.shape, i64.shape)
+ print("past_key_values dtype: ", i1.dtype)
+ print("************************************")
+ config = {
+ "input_ids": input_ids,
+ "position_ids": position_ids,
+ "past_key_values": (
+ (i1, i2),
+ (
+ i3,
+ i4,
+ ),
+ (
+ i5,
+ i6,
+ ),
+ (
+ i7,
+ i8,
+ ),
+ (
+ i9,
+ i10,
+ ),
+ (
+ i11,
+ i12,
+ ),
+ (
+ i13,
+ i14,
+ ),
+ (
+ i15,
+ i16,
+ ),
+ (
+ i17,
+ i18,
+ ),
+ (
+ i19,
+ i20,
+ ),
+ (
+ i21,
+ i22,
+ ),
+ (
+ i23,
+ i24,
+ ),
+ (
+ i25,
+ i26,
+ ),
+ (
+ i27,
+ i28,
+ ),
+ (
+ i29,
+ i30,
+ ),
+ (
+ i31,
+ i32,
+ ),
+ (
+ i33,
+ i34,
+ ),
+ (
+ i35,
+ i36,
+ ),
+ (
+ i37,
+ i38,
+ ),
+ (
+ i39,
+ i40,
+ ),
+ (
+ i41,
+ i42,
+ ),
+ (
+ i43,
+ i44,
+ ),
+ (
+ i45,
+ i46,
+ ),
+ (
+ i47,
+ i48,
+ ),
+ (
+ i49,
+ i50,
+ ),
+ (
+ i51,
+ i52,
+ ),
+ (
+ i53,
+ i54,
+ ),
+ (
+ i55,
+ i56,
+ ),
+ (
+ i57,
+ i58,
+ ),
+ (
+ i59,
+ i60,
+ ),
+ (
+ i61,
+ i62,
+ ),
+ (
+ i63,
+ i64,
+ ),
+ ),
+ "use_cache": True,
+ "attention_mask": attention_mask,
+ }
+ output = self.model(
+ **config,
+ return_dict=True,
+ output_attentions=False,
+ output_hidden_states=False,
+ )
+ return_vals = []
+ return_vals.append(output.logits)
+ temp_past_key_values = output.past_key_values
+ for item in temp_past_key_values:
+ return_vals.append(item[0])
+ return_vals.append(item[1])
+ return tuple(return_vals)
+
+
+class SeparatorStyle(Enum):
+ """Different separator style."""
+
+ SINGLE = auto()
+ TWO = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ """A class that keeps all conversation history."""
+
+ system: str
+ roles: List[str]
+ messages: List[List[str]]
+ offset: int
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+ sep: str = "###"
+ sep2: str = None
+
+ skip_next: bool = False
+ conv_id: Any = None
+
+ def get_prompt(self):
+ if self.sep_style == SeparatorStyle.SINGLE:
+ ret = self.system + self.sep
+ for role, message in self.messages:
+ if message:
+ ret += role + ": " + message + self.sep
+ else:
+ ret += role + ":"
+ return ret
+ elif self.sep_style == SeparatorStyle.TWO:
+ seps = [self.sep, self.sep2]
+ ret = self.system + seps[0]
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ ret += role + ": " + message + seps[i % 2]
+ else:
+ ret += role + ":"
+ return ret
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def to_gradio_chatbot(self):
+ ret = []
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
+ if i % 2 == 0:
+ ret.append([msg, None])
+ else:
+ ret[-1][-1] = msg
+ return ret
+
+ def copy(self):
+ return Conversation(
+ system=self.system,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ sep2=self.sep2,
+ conv_id=self.conv_id,
+ )
+
+ def dict(self):
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": self.messages,
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ "conv_id": self.conv_id,
+ }
+
+
+class StoppingCriteriaSub(StoppingCriteria):
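+ """Stops generation once the tail of the generated ids matches any of
+ the given stop sequences (e.g. the two tokenizations of '###')."""
+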
+ def __init__(self, stops=[], encounters=1):
+ super().__init__()
+ self.stops = stops
+
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+ for stop in self.stops:
+ if torch.all((stop == input_ids[0][-len(stop) :])).item():
+ return True
+
+ return False
+
+
+CONV_VISION = Conversation(
+ system="Give the following image: ImageContent. "
+ "You will be able to see the image once I provide it to you. Please answer my questions.",
+ roles=("Human", "Assistant"),
+ messages=[],
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
diff --git a/apps/language_models/src/pipelines/minigpt4_pipeline.py b/apps/language_models/src/pipelines/minigpt4_pipeline.py
new file mode 100644
index 0000000000..43615c3a6d
--- /dev/null
+++ b/apps/language_models/src/pipelines/minigpt4_pipeline.py
@@ -0,0 +1,1154 @@
+from apps.language_models.src.model_wrappers.minigpt4 import (
+ LayerNorm,
+ VisionModel,
+ QformerBertModel,
+ FirstLlamaModel,
+ SecondLlamaModel,
+ StoppingCriteriaSub,
+ CONV_VISION,
+)
+from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
+from apps.language_models.utils import (
+ get_vmfb_from_path,
+ get_vmfb_from_config,
+)
+from omegaconf import OmegaConf
+from pathlib import Path
+from shark.shark_downloader import download_public_file
+from transformers import LlamaTokenizer, LlamaForCausalLM
+from transformers import StoppingCriteriaList
+from transformers.generation import GenerationConfig, LogitsProcessorList
+
+import re
+import torch
+import os
+from PIL import Image
+import sys
+
+# SHARK dependencies
+from shark.shark_compile import (
+ shark_compile_through_fx,
+)
+import random
+import contextlib
+from transformers import BertTokenizer
+from transformers import LlamaTokenizer, LlamaForCausalLM
+from transformers.generation import GenerationConfig, LogitsProcessorList
+import copy
+import tempfile
+
+# QFormer, eva_vit, blip_processor, dist_utils
+from apps.language_models.src.pipelines.minigpt4_utils.Qformer import (
+ BertConfig,
+ BertLMHeadModel,
+)
+from apps.language_models.src.pipelines.minigpt4_utils.dist_utils import (
+ download_cached_file,
+)
+from apps.language_models.src.pipelines.minigpt4_utils.eva_vit import (
+ create_eva_vit_g,
+)
+from apps.language_models.src.pipelines.minigpt4_utils.blip_processors import (
+ Blip2ImageEvalProcessor,
+)
+
+import argparse
+
+parser = argparse.ArgumentParser(
+ prog="MiniGPT4 runner",
+ description="runs MiniGPT4",
+)
+
+parser.add_argument(
+ "--precision", "-p", default="fp16", help="fp32, fp16, int8, int4"
+)
+parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
+parser.add_argument(
+ "--vision_model_vmfb_path",
+ default=None,
+ help="path to vision model's vmfb",
+)
+parser.add_argument(
+ "--qformer_vmfb_path",
+ default=None,
+ help="path to qformer model's vmfb",
+)
+parser.add_argument(
+ "--image_path",
+ type=str,
+ default="",
+ help="path to the input image",
+)
+parser.add_argument(
+ "--load_mlir_from_shark_tank",
+ default=False,
+ action=argparse.BooleanOptionalAction,
+ help="download precompile mlir from shark tank",
+)
+parser.add_argument(
+ "--cli",
+ default=True,
+ action=argparse.BooleanOptionalAction,
+ help="Run model in cli mode",
+)
+parser.add_argument(
+ "--compile",
+ default=False,
+ action=argparse.BooleanOptionalAction,
+ help="Compile all models",
+)
+parser.add_argument(
+ "--max_length",
+ type=int,
+ default=2000,
+ help="Max length of the entire conversation",
+)
+parser.add_argument(
+ "--max_new_tokens",
+ type=int,
+ default=300,
+ help="Maximum no. of new tokens that can be generated for a query",
+)
+
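+# Example invocation from the SHARK repo root (flags and paths are
+# illustrative):
+#   python -m apps.language_models.src.pipelines.minigpt4_pipeline \
+#       --device=cuda --precision=fp16 --image_path=./sample.jpg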
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def is_url(input_url):
+ """
+ Check if an input string is a URL, looking for http(s):// and ignoring case.
+ """
+ is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
+ return is_url
+
+
+class MiniGPT4BaseModel(torch.nn.Module):
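+ """MiniGPT-4 base model: EVA-ViT visual encoder + Q-Former, whose output
+ is projected into the LLaMA embedding space by a single linear layer
+ (llama_proj)."""
+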
+ @classmethod
+ def from_config(cls, cfg):
+ vit_model = cfg.get("vit_model", "eva_clip_g")
+ q_former_model = cfg.get(
+ "q_former_model",
+ "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
+ )
+ img_size = cfg.get("image_size")
+ num_query_token = cfg.get("num_query_token")
+ llama_model = cfg.get("llama_model")
+
+ drop_path_rate = cfg.get("drop_path_rate", 0)
+ use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
+ vit_precision = cfg.get("vit_precision", "fp16")
+ freeze_vit = cfg.get("freeze_vit", True)
+ freeze_qformer = cfg.get("freeze_qformer", True)
+ low_resource = cfg.get("low_resource", False)
+ device_8bit = cfg.get("device_8bit", 0)
+
+ prompt_path = cfg.get("prompt_path", "")
+ prompt_template = cfg.get("prompt_template", "")
+ max_txt_len = cfg.get("max_txt_len", 32)
+ end_sym = cfg.get("end_sym", "\n")
+
+ model = cls(
+ vit_model=vit_model,
+ q_former_model=q_former_model,
+ img_size=img_size,
+ drop_path_rate=drop_path_rate,
+ use_grad_checkpoint=use_grad_checkpoint,
+ vit_precision=vit_precision,
+ freeze_vit=freeze_vit,
+ freeze_qformer=freeze_qformer,
+ num_query_token=num_query_token,
+ llama_model=llama_model,
+ prompt_path=prompt_path,
+ prompt_template=prompt_template,
+ max_txt_len=max_txt_len,
+ end_sym=end_sym,
+ low_resource=low_resource,
+ device_8bit=device_8bit,
+ )
+
+ ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
+ if ckpt_path:
+ print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path))
+ ckpt = torch.load(ckpt_path, map_location="cpu")
+ model.load_state_dict(ckpt["model"], strict=False)
+
+ return model
+
+ PRETRAINED_MODEL_CONFIG_DICT = {
+ "pretrain_vicuna": "configs/minigpt4.yaml",
+ }
+
+ def maybe_autocast(self, dtype=torch.float32):
+ # if on cpu, don't use autocast
+ # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
+ # enable_autocast = self.device != torch.device("cpu")
+ enable_autocast = True
+
+ if enable_autocast:
+ return torch.cuda.amp.autocast(dtype=dtype)
+ else:
+ return contextlib.nullcontext()
+
+ def init_tokenizer(cls):
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+ tokenizer.add_special_tokens({"bos_token": "[DEC]"})
+ return tokenizer
+
+ def init_vision_encoder(
+ self,
+ model_name,
+ img_size,
+ drop_path_rate,
+ use_grad_checkpoint,
+ precision,
+ ):
+ assert (
+ model_name == "eva_clip_g"
+ ), "vit model must be eva_clip_g for current version of MiniGPT-4"
+ visual_encoder = create_eva_vit_g(
+ img_size, drop_path_rate, use_grad_checkpoint, precision
+ )
+
+ ln_vision = LayerNorm(visual_encoder.num_features)
+ return visual_encoder, ln_vision
+
+ def init_Qformer(
+ cls, num_query_token, vision_width, cross_attention_freq=2
+ ):
+ encoder_config = BertConfig.from_pretrained("bert-base-uncased")
+ encoder_config.encoder_width = vision_width
+ # insert cross-attention layer every other block
+ encoder_config.add_cross_attention = True
+ encoder_config.cross_attention_freq = cross_attention_freq
+ encoder_config.query_length = num_query_token
+ Qformer = BertLMHeadModel(config=encoder_config)
+ query_tokens = torch.nn.Parameter(
+ torch.zeros(1, num_query_token, encoder_config.hidden_size)
+ )
+ query_tokens.data.normal_(
+ mean=0.0, std=encoder_config.initializer_range
+ )
+ return Qformer, query_tokens
+
+ def load_from_pretrained(self, url_or_filename):
+ if is_url(url_or_filename):
+ cached_file = download_cached_file(
+ url_or_filename, check_hash=False, progress=True
+ )
+ checkpoint = torch.load(cached_file, map_location="cpu")
+ elif os.path.isfile(url_or_filename):
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
+ else:
+ raise RuntimeError("checkpoint url or path is invalid")
+
+ state_dict = checkpoint["model"]
+
+ self.load_state_dict(state_dict, strict=False)
+
+ def __init__(
+ self,
+ vit_model="eva_clip_g",
+ q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
+ img_size=224,
+ drop_path_rate=0,
+ use_grad_checkpoint=False,
+ vit_precision="fp16",
+ freeze_vit=True,
+ freeze_qformer=True,
+ num_query_token=32,
+ llama_model="",
+ prompt_path="",
+ prompt_template="",
+ max_txt_len=32,
+ end_sym="\n",
+ low_resource=False, # use 8 bit and put vit in cpu
+ device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore.
+ ):
+ super().__init__()
+ self.tokenizer = self.init_tokenizer()
+ self.low_resource = low_resource
+
+ print("Loading VIT")
+ self.visual_encoder, self.ln_vision = self.init_vision_encoder(
+ vit_model,
+ img_size,
+ drop_path_rate,
+ use_grad_checkpoint,
+ vit_precision,
+ )
+ if freeze_vit:
+ for _, param in self.visual_encoder.named_parameters():
+ param.requires_grad = False
+ self.visual_encoder = self.visual_encoder.eval()
+ self.visual_encoder.train = disabled_train
+ for _, param in self.ln_vision.named_parameters():
+ param.requires_grad = False
+ self.ln_vision = self.ln_vision.eval()
+ self.ln_vision.train = disabled_train
+ # logging.info("freeze vision encoder")
+ print("Loading VIT Done")
+
+ print("Loading Q-Former")
+ self.Qformer, self.query_tokens = self.init_Qformer(
+ num_query_token, self.visual_encoder.num_features
+ )
+ self.Qformer.cls = None
+ self.Qformer.bert.embeddings.word_embeddings = None
+ self.Qformer.bert.embeddings.position_embeddings = None
+ for layer in self.Qformer.bert.encoder.layer:
+ layer.output = None
+ layer.intermediate = None
+ self.load_from_pretrained(url_or_filename=q_former_model)
+
+ if freeze_qformer:
+ for _, param in self.Qformer.named_parameters():
+ param.requires_grad = False
+ self.Qformer = self.Qformer.eval()
+ self.Qformer.train = disabled_train
+ self.query_tokens.requires_grad = False
+ # logging.info("freeze Qformer")
+ print("Loading Q-Former Done")
+
+ print(f"Loading Llama model from {llama_model}")
+ self.llama_tokenizer = LlamaTokenizer.from_pretrained(
+ llama_model, use_fast=False
+ )
+ # self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
+
+ if self.low_resource:
+ self.llama_model = LlamaForCausalLM.from_pretrained(
+ llama_model,
+ torch_dtype=torch.float16,
+ load_in_8bit=True,
+ device_map={"": device_8bit},
+ )
+ else:
+ self.llama_model = LlamaForCausalLM.from_pretrained(
+ llama_model,
+ torch_dtype=torch.float32,
+ )
+
+ for _, param in self.llama_model.named_parameters():
+ param.requires_grad = False
+ print("Loading Llama Done")
+
+ self.llama_proj = torch.nn.Linear(
+ self.Qformer.config.hidden_size,
+ self.llama_model.config.hidden_size,
+ )
+ self.max_txt_len = max_txt_len
+ self.end_sym = end_sym
+
+ if prompt_path:
+ with open(prompt_path, "r") as f:
+ raw_prompts = f.read().splitlines()
+ filtered_prompts = [
+ raw_prompt
+ for raw_prompt in raw_prompts
+ if "<ImageHere>" in raw_prompt
+ ]
+ self.prompt_list = [
+ prompt_template.format(p) for p in filtered_prompts
+ ]
+ print("Load {} training prompts".format(len(self.prompt_list)))
+ print(
+ "Prompt Example \n{}".format(random.choice(self.prompt_list))
+ )
+ else:
+ self.prompt_list = []
+
+
+class MiniGPT4(SharkLLMBase):
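+ """SHARK MiniGPT-4 chat pipeline: compiles the vision model, the
+ Q-Former and the two LLaMA wrappers (prefill + decode) to vmfb and runs
+ the chat loop over the compiled modules."""
+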
+ def __init__(
+ self,
+ model_name,
+ hf_model_path=None,
+ max_new_tokens=300,
+ device="cuda",
+ precision="fp16",
+ _compile=False,
+ vision_model_vmfb_path=Path("vision_model_fp16_cuda.vmfb"),
+ qformer_vmfb_path=Path("qformer_fp32_cuda.vmfb"),
+ ) -> None:
+ self.model_name = model_name
+ self.shark_model = None
+ super().__init__(model_name, hf_model_path, max_new_tokens)
+ self.download_dependencies()
+ self.device = device
+ self.precision = precision
+ self._compile = _compile
+
+ self.vision_model_vmfb_path = vision_model_vmfb_path
+ self.qformer_vmfb_path = qformer_vmfb_path
+ self.first_llama_vmfb_path = None
+ self.second_llama_vmfb_path = None
+
+ print("Initializing Chat")
+ config = OmegaConf.load(
+ "apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4_eval.yaml"
+ )
+ model_config = OmegaConf.create()
+ model_config = OmegaConf.merge(
+ model_config,
+ OmegaConf.load(
+ "apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4.yaml"
+ ),
+ {"model": config["model"]},
+ )
+ model_config = model_config["model"]
+ model_config.device_8bit = 0
+ model = MiniGPT4BaseModel.from_config(model_config).to("cpu")
+ datasets = config.get("datasets", None)
+ dataset_config = OmegaConf.create()
+ for dataset_name in datasets:
+ dataset_config_path = "apps/language_models/src/pipelines/minigpt4_utils/configs/cc_sbu_align.yaml"
+ dataset_config = OmegaConf.merge(
+ dataset_config,
+ OmegaConf.load(dataset_config_path),
+ {"datasets": {dataset_name: config["datasets"][dataset_name]}},
+ )
+ dataset_config = dataset_config["datasets"]
+ vis_processor_cfg = dataset_config.cc_sbu_align.vis_processor.train
+ vis_processor = Blip2ImageEvalProcessor.from_config(vis_processor_cfg)
+ print("Initialization complete")
+
+ self.model = model
+ self.vis_processor = vis_processor
+ stop_words_ids = [
+ torch.tensor([835]).to("cpu"),
+ torch.tensor([2277, 29937]).to("cpu"),
+ ] # '###' can be encoded in two different ways.
+ self.stopping_criteria = StoppingCriteriaList(
+ [StoppingCriteriaSub(stops=stop_words_ids)]
+ )
+
+ self.first_llama = None
+ self.second_llama = None
+
+ def download_dependencies(self):
+ pretrained_file = "prerained_minigpt4_7b.pth"
+ pretrained_file_url = f"gs://shark_tank/MiniGPT4/{pretrained_file}"
+ if not os.path.isfile(pretrained_file):
+ download_public_file(
+ pretrained_file_url,
+ Path("prerained_minigpt4_7b.pth").absolute(),
+ single_file=True,
+ )
+
+ if os.path.isfile(pretrained_file):
+ print(f"File downloaded successfully: {pretrained_file}")
+ else:
+ print(f"Error downloading {pretrained_file}")
+ sys.exit()
+
+ # Compiles the VisionModel for the requested precision/device (fp16/cuda by default).
+ def compile_vision_model(self):
+ if not self._compile:
+ vmfb = get_vmfb_from_path(
+ self.vision_model_vmfb_path, self.device, "tm_tensor"
+ )
+ if vmfb is not None:
+ return vmfb
+ else:
+ vmfb = get_vmfb_from_config(
+ self.model_name,
+ "vision_model",
+ self.precision,
+ self.device,
+ self.vision_model_vmfb_path,
+ )
+ if vmfb is not None:
+ return vmfb
+
+ visionModel = VisionModel(
+ self.model.ln_vision, self.model.visual_encoder
+ )
+ extended_model_name = f"vision_model_{self.precision}_{self.device}"
+ print(f"Going to compile {extended_model_name}")
+ # Inputs for VisionModel.
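+ # A single 224x224 RGB image tensor, as produced by Blip2ImageEvalProcessor.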
+ inputs = [torch.randint(3, (1, 3, 224, 224), dtype=torch.float32)]
+ is_f16 = False
+ if self.precision == "fp16":
+ is_f16 = True
+ shark_visionModel, _ = shark_compile_through_fx(
+ visionModel,
+ inputs,
+ extended_model_name=extended_model_name,
+ is_f16=is_f16,
+ f16_input_mask=None,
+ save_dir=tempfile.gettempdir(),
+ debug=False,
+ generate_or_load_vmfb=True,
+ extra_args=[],
+ device=self.device,
+ mlir_dialect="tm_tensor",
+ )
+ print(f"Generated {extended_model_name}.vmfb")
+ return shark_visionModel
+
+ def compile_qformer_model(self):
+ if not self._compile:
+ vmfb = get_vmfb_from_path(
+ self.qformer_vmfb_path, self.device, "tm_tensor"
+ )
+ if vmfb is not None:
+ return vmfb
+ else:
+ vmfb = get_vmfb_from_config(
+ self.model_name,
+ "qformer",
+ "fp32",
+ self.device,
+ self.qformer_vmfb_path,
+ )
+ if vmfb is not None:
+ return vmfb
+
+ qformerBertModel = QformerBertModel(self.model.Qformer.bert)
+ extended_model_name = f"qformer_fp32_{self.device}"
+ print(f"Going to compile {extended_model_name}")
+ # Inputs for QFormer.
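+ # Query tokens (1, 32, 768), ViT image embeddings (1, 257, 1408)
+ # and the image attention mask (1, 257).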
+ inputs = [
+ torch.randint(3, (1, 32, 768), dtype=torch.float32),
+ torch.randint(3, (1, 257, 1408), dtype=torch.float32),
+ torch.randint(3, (1, 257), dtype=torch.int64),
+ ]
+ is_f16 = False
+ f16_input_mask = []
+ shark_QformerBertModel, _ = shark_compile_through_fx(
+ qformerBertModel,
+ inputs,
+ extended_model_name=extended_model_name,
+ is_f16=is_f16,
+ f16_input_mask=f16_input_mask,
+ save_dir=tempfile.gettempdir(),
+ debug=False,
+ generate_or_load_vmfb=True,
+ extra_args=[],
+ device=self.device,
+ mlir_dialect="tm_tensor",
+ )
+ print(f"Generated {extended_model_name}.vmfb")
+ return shark_QformerBertModel
+
+ def compile_first_llama(self, padding):
+ self.first_llama_vmfb_path = Path(
+ f"first_llama_{self.precision}_{self.device}_{padding}.vmfb"
+ )
+ if not self._compile:
+ vmfb = get_vmfb_from_path(
+ self.first_llama_vmfb_path, self.device, "tm_tensor"
+ )
+ if vmfb is not None:
+ self.first_llama = vmfb
+ return vmfb
+ else:
+ vmfb = get_vmfb_from_config(
+ self.model_name,
+ "first_llama",
+ self.precision,
+ self.device,
+ self.first_llama_vmfb_path,
+ padding,
+ )
+ if vmfb is not None:
+ self.first_llama = vmfb
+ return vmfb
+
+ firstLlamaModel = FirstLlamaModel(
+ copy.deepcopy(self.model.llama_model)
+ )
+ extended_model_name = (
+ f"first_llama_{self.precision}_{self.device}_{padding}"
+ )
+ print(f"Going to compile {extended_model_name}")
+ # Inputs for FirstLlama.
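+ # Prompt embeddings padded to `padding` tokens (hidden size 4096),
+ # with matching position ids and attention mask.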
+ inputs_embeds = torch.ones((1, padding, 4096), dtype=torch.float32)
+ position_ids = torch.ones((1, padding), dtype=torch.int64)
+ attention_mask = torch.ones((1, padding), dtype=torch.int32)
+ inputs = [inputs_embeds, position_ids, attention_mask]
+ is_f16 = False
+ f16_input_mask = []
+ if self.precision == "fp16":
+ is_f16 = True
+ f16_input_mask = [True, False, False]
+ shark_firstLlamaModel, _ = shark_compile_through_fx(
+ firstLlamaModel,
+ inputs,
+ extended_model_name=extended_model_name,
+ is_f16=is_f16,
+ f16_input_mask=f16_input_mask,
+ save_dir=tempfile.gettempdir(),
+ debug=False,
+ generate_or_load_vmfb=True,
+ extra_args=[],
+ device=self.device,
+ mlir_dialect="tm_tensor",
+ )
+ print(f"Generated {extended_model_name}.vmfb")
+ self.first_llama = shark_firstLlamaModel
+ return shark_firstLlamaModel
+
+ def compile_second_llama(self, padding):
+ self.second_llama_vmfb_path = Path(
+ f"second_llama_{self.precision}_{self.device}_{padding}.vmfb"
+ )
+ if not self._compile:
+ vmfb = get_vmfb_from_path(
+ self.second_llama_vmfb_path, self.device, "tm_tensor"
+ )
+ if vmfb is not None:
+ self.second_llama = vmfb
+ return vmfb
+ else:
+ vmfb = get_vmfb_from_config(
+ self.model_name,
+ "second_llama",
+ self.precision,
+ self.device,
+ self.second_llama_vmfb_path,
+ padding,
+ )
+ if vmfb is not None:
+ self.second_llama = vmfb
+ return vmfb
+
+ secondLlamaModel = SecondLlamaModel(
+ copy.deepcopy(self.model.llama_model)
+ )
+ extended_model_name = (
+ f"second_llama_{self.precision}_{self.device}_{padding}"
+ )
+ print(f"Going to compile {extended_model_name}")
+ # Inputs for SecondLlama.
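+ # A single new token id/position, an attention mask of length
+ # padding + 1, and the 64 cached key/value tensors (32 layers x 2),
+ # each of shape (1, 32, padding, 128).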
+ input_ids = torch.zeros((1, 1), dtype=torch.int64)
+ position_ids = torch.zeros((1, 1), dtype=torch.int64)
+ attention_mask = torch.zeros((1, padding + 1), dtype=torch.int32)
+ past_key_value = []
+ for i in range(64):
+ past_key_value.append(
+ torch.zeros(1, 32, padding, 128, dtype=torch.float32)
+ )
+ inputs = [input_ids, position_ids, attention_mask, *past_key_value]
+ is_f16 = False
+ f16_input_mask = []
+ if self.precision == "fp16":
+ is_f16 = True
+ f16_input_mask = [False, False, False]
+ for i in past_key_value:
+ f16_input_mask.append(True)
+
+ shark_secondLlamaModel, _ = shark_compile_through_fx(
+ secondLlamaModel,
+ inputs,
+ extended_model_name=extended_model_name,
+ is_f16=is_f16,
+ f16_input_mask=f16_input_mask,
+ save_dir=tempfile.gettempdir(),
+ debug=False,
+ generate_or_load_vmfb=True,
+ extra_args=[],
+ device=self.device,
+ mlir_dialect="tm_tensor",
+ )
+ print(f"Generated {extended_model_name}.vmfb")
+ self.second_llama = shark_secondLlamaModel
+ return shark_secondLlamaModel
+
+ # Not used directly; each submodel is compiled via its own compile_* method.
+ def compile(self):
+ pass
+
+ # Going to use `answer` instead.
+ def generate(self, prompt):
+ pass
+
+ # Might use within `answer`, if needed.
+ def generate_new_token(self, params):
+ pass
+
+ # Not needed yet because MiniGPT4BaseModel already loads this - will revisit later,
+ # if required.
+ def get_tokenizer(self):
+ pass
+
+ # Placeholder - MiniGPT4BaseModel already does the intended work,
+ # i.e. loading the Llama model, etc.
+ def get_src_model(self):
+ pass
+
+ def ask(self, text, conv):
+ if (
+ len(conv.messages) > 0
+ and conv.messages[-1][0] == conv.roles[0]
+ and conv.messages[-1][1][-6:] == "</Img>"
+ ): # last message is image.
+ conv.messages[-1][1] = " ".join([conv.messages[-1][1], text])
+ else:
+ conv.append_message(conv.roles[0], text)
+
+ def answer(
+ self,
+ conv,
+ img_list,
+ max_new_tokens=300,
+ num_beams=1,
+ min_length=1,
+ top_p=0.9,
+ repetition_penalty=1.0,
+ length_penalty=1,
+ temperature=1.0,
+ max_length=2000,
+ ):
+ conv.append_message(conv.roles[1], None)
+ embs = self.get_context_emb(
+ conv, img_list, max_length - max_new_tokens
+ )
+ padding = max_length - max_new_tokens
+
+ current_max_len = embs.shape[1] + max_new_tokens
+
+ if current_max_len - max_length > 0:
+ print(
+ "Warning: The number of tokens in current conversation exceeds the max length. "
+ "The model will not see the contexts outside the range."
+ )
+ begin_idx = max(0, current_max_len - max_length)
+
+ embs = embs[:, begin_idx:]
+
+ #########################################################################################################
+
+ generation_config = GenerationConfig.from_model_config(
+ self.model.llama_model.config
+ )
+ kwargs = {
+ "inputs_embeds": embs,
+ "max_new_tokens": max_new_tokens,
+ "num_beams": num_beams,
+ "do_sample": True,
+ "min_length": min_length,
+ "top_p": top_p,
+ "repetition_penalty": repetition_penalty,
+ "length_penalty": length_penalty,
+ "temperature": temperature,
+ }
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(**kwargs)
+ logits_processor = LogitsProcessorList()
+ stopping_criteria = self.stopping_criteria
+ inputs = None
+ (
+ inputs_tensor,
+ model_input_name,
+ model_kwargs,
+ ) = self.model.llama_model._prepare_model_inputs(
+ inputs, generation_config.bos_token_id, model_kwargs
+ )
+ model_kwargs["output_attentions"] = generation_config.output_attentions
+ model_kwargs[
+ "output_hidden_states"
+ ] = generation_config.output_hidden_states
+ model_kwargs["use_cache"] = generation_config.use_cache
+ generation_config.pad_token_id = (
+ self.model.llama_tokenizer.pad_token_id
+ )
+ pad_token_id = generation_config.pad_token_id
+ embs_for_pad_token_id = self.model.llama_model.model.embed_tokens(
+ torch.tensor([pad_token_id])
+ )
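+ # Recover the attention mask from the padded embeddings: a position is
+ # padding iff its embedding equals the pad token's embedding; first_zero
+ # and last_zero then mark the padded span inside the prompt.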
+ model_kwargs["attention_mask"] = torch.logical_not(
+ torch.tensor(
+ [
+ torch.all(
+ torch.eq(inputs_tensor[:, d, :], embs_for_pad_token_id)
+ ).int()
+ for d in range(inputs_tensor.shape[1])
+ ]
+ ).unsqueeze(0)
+ ).int()
+ attention_meta_data = (model_kwargs["attention_mask"][0] == 0).nonzero(
+ as_tuple=True
+ )[0]
+ first_zero = attention_meta_data[0].item()
+ last_zero = attention_meta_data[-1].item()
+ input_ids = (
+ inputs_tensor
+ if model_input_name == "input_ids"
+ else model_kwargs.pop("input_ids")
+ )
+ input_ids_seq_length = input_ids.shape[-1]
+ generation_config.max_length = (
+ generation_config.max_new_tokens + input_ids_seq_length
+ )
+ logits_warper = self.model.llama_model._get_logits_warper(
+ generation_config
+ )
+ (
+ input_ids,
+ model_kwargs,
+ ) = self.model.llama_model._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_return_sequences,
+ is_encoder_decoder=False,
+ **model_kwargs,
+ )
+ # DOUBT: stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
+ logits_warper = (
+ logits_warper
+ if logits_warper is not None
+ else LogitsProcessorList()
+ )
+ pad_token_id = generation_config.pad_token_id
+ eos_token_id = generation_config.eos_token_id
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+ eos_token_id_tensor = (
+ torch.tensor(eos_token_id).to(input_ids.device)
+ if eos_token_id is not None
+ else None
+ )
+ scores = None
+
+ # keep track of which sequences are already finished
+ unfinished_sequences = torch.ones(
+ input_ids.shape[0], dtype=torch.long, device=input_ids.device
+ )
+ i = 0
+ timesRan = 0
+ is_fp16 = self.precision == "fp16"
+ llama_list = []
+ isPyTorchVariant = False
+ while True:
+ print("****** Iteration %d ******" % (i))
+ # prepare model inputs
+ model_inputs = (
+ self.model.llama_model.prepare_inputs_for_generation(
+ input_ids, **model_kwargs
+ )
+ )
+
+ # forward pass to get next token
+ if i == 0:
+ shark_inputs = []
+ if is_fp16:
+ model_inputs["inputs_embeds"] = model_inputs[
+ "inputs_embeds"
+ ].to(torch.float16)
+ shark_inputs.append(model_inputs["inputs_embeds"].detach())
+ shark_inputs.append(model_inputs["position_ids"].detach())
+ shark_inputs.append(model_inputs["attention_mask"].detach())
+
+ if self.first_llama is None:
+ self.compile_first_llama(padding)
+ outputs_shark = self.first_llama("forward", shark_inputs)
+ outputs = []
+ for out_shark in outputs_shark:
+ outputs.append(torch.from_numpy(out_shark))
+ del outputs_shark
+ else:
+ shark_inputs = []
+ shark_inputs.append(model_inputs["input_ids"].detach())
+ shark_inputs.append(model_inputs["position_ids"].detach())
+ shark_inputs.append(model_inputs["attention_mask"].detach())
+ for pkv in list(model_inputs["past_key_values"]):
+ shark_inputs.append(pkv.detach())
+ if self.second_llama is None:
+ self.compile_second_llama(padding)
+ outputs_shark = self.second_llama("forward", shark_inputs)
+ outputs = []
+ for out_shark in outputs_shark:
+ outputs.append(torch.from_numpy(out_shark))
+ del outputs_shark
+
+ outputs_logits = outputs[0]
+ next_token_logits = outputs_logits[:, -1, :]
+ if is_fp16:
+ next_token_logits = next_token_logits.to(torch.float32)
+
+ # pre-process distribution
+ next_token_scores = logits_processor(input_ids, next_token_logits)
+ next_token_scores = logits_warper(input_ids, next_token_scores)
+ probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+
+ # finished sentences should have their next token be a padding token
+ if eos_token_id is not None:
+ if pad_token_id is None:
+ raise ValueError(
+ "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
+ )
+ next_tokens = (
+ next_tokens * unfinished_sequences
+ + pad_token_id * (1 - unfinished_sequences)
+ )
+
+ # update generated ids, model inputs, and length for next step
+ outputs_for_update_func = {}
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+ model_kwargs = (
+ self.model.llama_model._update_model_kwargs_for_generation(
+ outputs_for_update_func,
+ model_kwargs,
+ is_encoder_decoder=False,
+ )
+ )
+ model_kwargs["past_key_values"] = outputs[1:]
+ if timesRan >= 1:
+ tmp_attention_mask = torch.cat(
+ (
+ model_kwargs["attention_mask"][:, :first_zero],
+ model_kwargs["attention_mask"][:, first_zero + 1 :],
+ ),
+ dim=1,
+ )
+ model_kwargs["attention_mask"] = tmp_attention_mask
+ pkv_list = []
+ for pkv_pair_tuple in model_kwargs["past_key_values"]:
+ x = torch.cat(
+ (
+ pkv_pair_tuple[:, :, :first_zero, :],
+ pkv_pair_tuple[:, :, first_zero + 1 :, :],
+ ),
+ dim=2,
+ )
+ if is_fp16:
+ x = x.to(torch.float16)
+ pkv_list.append(x)
+ model_kwargs["past_key_values"] = tuple(pkv_list)
+
+ # if eos_token was found in one sentence, set sentence to finished
+ if eos_token_id_tensor is not None:
+ unfinished_sequences = unfinished_sequences.mul(
+ next_tokens.tile(eos_token_id_tensor.shape[0], 1)
+ .ne(eos_token_id_tensor.unsqueeze(1))
+ .prod(dim=0)
+ )
+
+ # stop when each sentence is finished, or if we exceed the maximum length
+ if unfinished_sequences.max() == 0 or stopping_criteria(
+ input_ids, scores
+ ):
+ break
+
+ i = i + 1
+ timesRan += 1
+ llama_list.clear()
+ output_token = input_ids[0]
+
+ if (
+ output_token[0] == 0
+ ): # the model might output an unknown token at the beginning. remove it
+ output_token = output_token[1:]
+ if (
+ output_token[0] == 1
+ ): # some users find that there is a start token at the beginning. remove it
+ output_token = output_token[1:]
+ output_text = self.model.llama_tokenizer.decode(
+ output_token, add_special_tokens=False
+ )
+ output_text = output_text.split("###")[0] # remove the stop sign '###'
+ output_text = output_text.split("Assistant:")[-1].strip()
+ conv.messages[-1][1] = output_text
+ return output_text, output_token.cpu().numpy()
+
+ def upload_img(self, image, conv, img_list):
+ if isinstance(image, str): # is a image path
+ raw_image = Image.open(image).convert("RGB")
+ image = self.vis_processor(raw_image).unsqueeze(0).to("cpu")
+ elif isinstance(image, Image.Image):
+ raw_image = image
+ image = self.vis_processor(raw_image).unsqueeze(0).to("cpu")
+ elif isinstance(image, torch.Tensor):
+ if len(image.shape) == 3:
+ image = image.unsqueeze(0)
+ image = image.to("cpu")
+
+ device = image.device
+ if self.model.low_resource:
+ self.model.vit_to_cpu()
+ image = image.to("cpu")
+
+ with self.model.maybe_autocast():
+ shark_visionModel = self.compile_vision_model()
+ if self.precision == "fp16":
+ image = image.to(torch.float16)
+ image_embeds = shark_visionModel("forward", (image,))
+ # image_embeds = shark_visionModel.forward(image)
+ image_embeds = torch.from_numpy(image_embeds)
+ image_embeds = image_embeds.to(device).to(torch.float32)
+ image_atts = torch.ones(
+ image_embeds.size()[:-1], dtype=torch.long
+ ).to(device)
+
+ query_tokens = self.model.query_tokens.expand(
+ image_embeds.shape[0], -1, -1
+ ).to(device)
+ # if self.precision == "fp16":
+ # query_tokens = query_tokens.to(torch.float16)
+ shark_QformerBertModel = self.compile_qformer_model()
+ query_output = shark_QformerBertModel(
+ "forward",
+ (
+ query_tokens,
+ image_embeds,
+ image_atts,
+ ),
+ )
+ # query_output = shark_QformerBertModel.forward(query_tokens, image_embeds, image_atts)
+ query_output = torch.from_numpy(query_output)
+
+ inputs_llama = self.model.llama_proj(query_output)
+ image_emb = inputs_llama
+ img_list.append(image_emb)
+ conv.append_message(conv.roles[0], "")
+ msg = "Received."
+ return msg
+
+ # """
+ def get_context_emb(self, conv, img_list, max_allowed_tokens=200):
+ self.model.llama_tokenizer.padding_side = "left"
+ prompt = conv.get_prompt()
+ prompt_segs = prompt.split("<ImageHere>")
+ assert (
+ len(prompt_segs) == len(img_list) + 1
+ ), "Unmatched numbers of image placeholders and images."
+ prompt_segs_pre = prompt_segs[:-1]
+ seg_tokens_pre = []
+ for i, seg in enumerate(prompt_segs_pre):
+ # only add bos to the first seg
+ add_special_tokens = i == 0
+ stp = (
+ self.model.llama_tokenizer(
+ seg,
+ return_tensors="pt",
+ add_special_tokens=add_special_tokens,
+ )
+ .to("cpu")
+ .input_ids
+ )
+ seg_tokens_pre.append(stp)
+ print("Before :-\nLlama model pad token : ",self.model.llama_model.config.pad_token_id)
+ print("Llama tokenizer pad token : ",self.model.llama_tokenizer.pad_token_id)
+ self.model.llama_model.config.pad_token_id = (
+ self.model.llama_tokenizer.pad_token_id
+ )
+ print("After :-\nLlama model pad token : ",self.model.llama_model.config.pad_token_id)
+ print("Llama tokenizer pad token : ",self.model.llama_tokenizer.pad_token_id)
+ print("seg_t :", seg_tokens_pre[0])
+
+ seg_embs_pre = [
+ self.model.llama_model.model.embed_tokens(seg_t)
+ for seg_t in seg_tokens_pre
+ ]
+ mixed_embs_pre = [
+ emb.to("cpu")
+ for pair in zip(seg_embs_pre, img_list)
+ for emb in pair
+ ]
+ mixed_embs_pre = torch.cat(mixed_embs_pre, dim=1)
+ max_allowed_tokens = max_allowed_tokens - mixed_embs_pre.shape[1]
+ final_prompt = prompt_segs[-1]
+ seg_tokens_post = [
+ self.model.llama_tokenizer(
+ seg,
+ return_tensors="pt",
+ padding="max_length",
+ max_length=max_allowed_tokens,
+ add_special_tokens=False,
+ )
+ .to("cpu")
+ .input_ids
+ # no bos for the trailing segment
+ for seg in [final_prompt]
+ ]
+ seg_tokens_post = seg_tokens_post[0]
+ seg_embs_post = [
+ self.model.llama_model.model.embed_tokens(seg_t)
+ for seg_t in seg_tokens_post
+ ]
+ mixed_embs_post = [seg_embs_post[0].to("cpu")]
+ mixed_embs_post = torch.unsqueeze(mixed_embs_post[0], 0)
+ mixed_embs = [mixed_embs_pre] + [mixed_embs_post]
+ mixed_embs = torch.cat(mixed_embs, dim=1)
+ return mixed_embs
+
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+
+ device = args.device
+ precision = args.precision
+ _compile = args.compile
+ max_length = args.max_length
+ max_new_tokens = args.max_new_tokens
+ print("Will run SHARK MultiModal for the following paramters :-\n")
+ print(
+ f"Device={device} precision={precision} compile={_compile} max_length={max_length} max_new_tokens={max_new_tokens}"
+ )
+
+ padding = max_length - max_new_tokens
+ assert (
+ padding > 0
+ ), "max_length should be strictly greater than max_new_tokens"
+
+ if args.image_path == "":
+ print(
+ "To run MiniGPT4 in CLI mode please provide an image's path using --image_path"
+ )
+ sys.exit()
+
+ vision_model_vmfb_path = (
+ Path("vision_model_fp16_cuda.vmfb")
+ if args.vision_model_vmfb_path is None
+ else Path(args.vision_model_vmfb_path)
+ )
+ qformer_vmfb_path = (
+ Path("qformer_fp32_cuda.vmfb")
+ if args.qformer_vmfb_path is None
+ else Path(args.qformer_vmfb_path)
+ )
+ chat = MiniGPT4(
+ model_name="MiniGPT4",
+ hf_model_path=None,
+ max_new_tokens=30,
+ device=device,
+ precision=precision,
+ _compile=_compile,
+ vision_model_vmfb_path=vision_model_vmfb_path,
+ qformer_vmfb_path=qformer_vmfb_path,
+ )
+
+ chat_state = CONV_VISION.copy()
+ img_list = []
+ chat.upload_img(args.image_path, chat_state, img_list)
+ print(
+ "Uploaded image successfully to the bot. You may now start chatting with the bot. Enter 'END' without quotes to end the interaction"
+ )
+ continue_execution = True
+
+ while continue_execution:
+ user_message = input("User: ")
+ if user_message == "END":
+ print("Bot: Good bye.\n")
+ break
+ chat.ask(user_message, chat_state)
+ bot_message = chat.answer(
+ conv=chat_state,
+ img_list=img_list,
+ num_beams=1,
+ temperature=1.0,
+ max_new_tokens=max_new_tokens,
+ max_length=max_length,
+ )[0]
+ print("Bot: ", bot_message)
+
+ del chat_state, img_list, chat
diff --git a/apps/language_models/src/pipelines/minigpt4_utils/Qformer.py b/apps/language_models/src/pipelines/minigpt4_utils/Qformer.py
new file mode 100644
index 0000000000..6944a9dd91
--- /dev/null
+++ b/apps/language_models/src/pipelines/minigpt4_utils/Qformer.py
@@ -0,0 +1,1297 @@
+"""
+ * Copyright (c) 2023, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+ * Based on huggingface code base
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
+"""
+
+import math
+from dataclasses import dataclass
+from typing import Tuple, Dict, Any
+
+import torch
+from torch import Tensor, device, nn
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+)
+from transformers.modeling_utils import (
+ PreTrainedModel,
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ prune_linear_layer,
+)
+from transformers.utils import logging
+from transformers.models.bert.configuration_bert import BertConfig
+
+logger = logging.get_logger(__name__)
+
+
+class BertEmbeddings(nn.Module):
+ """Construct the embeddings from word and position embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(
+ config.vocab_size,
+ config.hidden_size,
+ padding_idx=config.pad_token_id,
+ )
+ self.position_embeddings = nn.Embedding(
+ config.max_position_embeddings, config.hidden_size
+ )
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(
+ config.hidden_size, eps=config.layer_norm_eps, device="cpu"
+ )
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer(
+ "position_ids",
+ torch.arange(config.max_position_embeddings).expand((1, -1)),
+ )
+ self.position_embedding_type = getattr(
+ config, "position_embedding_type", "absolute"
+ )
+
+ self.config = config
+
+ def forward(
+ self,
+ input_ids=None,
+ position_ids=None,
+ query_embeds=None,
+ past_key_values_length=0,
+ ):
+ if input_ids is not None:
+ seq_length = input_ids.size()[1]
+ else:
+ seq_length = 0
+
+ if position_ids is None:
+ position_ids = self.position_ids[
+ :, past_key_values_length : seq_length + past_key_values_length
+ ].clone()
+
+ if input_ids is not None:
+ embeddings = self.word_embeddings(input_ids)
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings = embeddings + position_embeddings
+
+ if query_embeds is not None:
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
+ else:
+ embeddings = query_embeds
+
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class BertSelfAttention(nn.Module):
+ def __init__(self, config, is_cross_attention):
+ super().__init__()
+ self.config = config
+ if (
+ config.hidden_size % config.num_attention_heads != 0
+ and not hasattr(config, "embedding_size")
+ ):
+ raise ValueError(
+ "The hidden size (%d) is not a multiple of the number of attention "
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(
+ config.hidden_size / config.num_attention_heads
+ )
+ self.all_head_size = (
+ self.num_attention_heads * self.attention_head_size
+ )
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ if is_cross_attention:
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
+ else:
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = getattr(
+ config, "position_embedding_type", "absolute"
+ )
+ if (
+ self.position_embedding_type == "relative_key"
+ or self.position_embedding_type == "relative_key_query"
+ ):
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(
+ 2 * config.max_position_embeddings - 1,
+ self.attention_head_size,
+ )
+ self.save_attention = False
+
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
+ def save_attention_map(self, attention_map):
+ self.attention_map = attention_map
+
+ def get_attention_map(self):
+ return self.attention_map
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (
+ self.num_attention_heads,
+ self.attention_head_size,
+ )
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention:
+ key_layer = self.transpose_for_scores(
+ self.key(encoder_hidden_states)
+ )
+ value_layer = self.transpose_for_scores(
+ self.value(encoder_hidden_states)
+ )
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ mixed_query_layer = self.query(hidden_states)
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ past_key_value = (key_layer, value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(
+ query_layer, key_layer.transpose(-1, -2)
+ )
+
+ if (
+ self.position_embedding_type == "relative_key"
+ or self.position_embedding_type == "relative_key_query"
+ ):
+ seq_length = hidden_states.size()[1]
+ position_ids_l = torch.arange(
+ seq_length, dtype=torch.long, device=hidden_states.device
+ ).view(-1, 1)
+ position_ids_r = torch.arange(
+ seq_length, dtype=torch.long, device=hidden_states.device
+ ).view(1, -1)
+ distance = position_ids_l - position_ids_r
+ positional_embedding = self.distance_embedding(
+ distance + self.max_position_embeddings - 1
+ )
+ positional_embedding = positional_embedding.to(
+ dtype=query_layer.dtype
+ ) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum(
+ "bhld,lrd->bhlr", query_layer, positional_embedding
+ )
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum(
+ "bhld,lrd->bhlr", query_layer, positional_embedding
+ )
+ relative_position_scores_key = torch.einsum(
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
+ )
+ attention_scores = (
+ attention_scores
+ + relative_position_scores_query
+ + relative_position_scores_key
+ )
+
+ attention_scores = attention_scores / math.sqrt(
+ self.attention_head_size
+ )
+ if attention_mask is not None:
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ if is_cross_attention and self.save_attention:
+ self.save_attention_map(attention_probs)
+ attention_probs.register_hook(self.save_attn_gradients)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs_dropped = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs_dropped = attention_probs_dropped * head_mask
+
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (
+ self.all_head_size,
+ )
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (
+ (context_layer, attention_probs)
+ if output_attentions
+ else (context_layer,)
+ )
+
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+
+class BertSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(
+ config.hidden_size, eps=config.layer_norm_eps
+ )
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertAttention(nn.Module):
+ def __init__(self, config, is_cross_attention=False):
+ super().__init__()
+ self.self = BertSelfAttention(config, is_cross_attention)
+ self.output = BertSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads,
+ self.self.num_attention_heads,
+ self.self.attention_head_size,
+ self.pruned_heads,
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(
+ heads
+ )
+ self.self.all_head_size = (
+ self.self.attention_head_size * self.self.num_attention_heads
+ )
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+
+ outputs = (attention_output,) + self_outputs[
+ 1:
+ ] # add attentions if we output them
+ return outputs
+
+
+class BertIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class BertOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(
+ config.hidden_size, eps=config.layer_norm_eps
+ )
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertLayer(nn.Module):
+ def __init__(self, config, layer_num):
+ super().__init__()
+ self.config = config
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = BertAttention(config)
+ self.layer_num = layer_num
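+ # A cross-attention block (attending to the image features) is only
+ # added every `cross_attention_freq` layers; the Q-Former uses every
+ # other layer.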
+ if (
+ self.config.add_cross_attention
+ and layer_num % self.config.cross_attention_freq == 0
+ ):
+ self.crossattention = BertAttention(
+ config, is_cross_attention=self.config.add_cross_attention
+ )
+ self.has_cross_attention = True
+ else:
+ self.has_cross_attention = False
+ self.intermediate = BertIntermediate(config)
+ self.output = BertOutput(config)
+
+ self.intermediate_query = BertIntermediate(config)
+ self.output_query = BertOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ query_length=0,
+ ):
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = (
+ past_key_value[:2] if past_key_value is not None else None
+ )
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ past_key_value=self_attn_past_key_value,
+ )
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:-1]
+
+ present_key_value = self_attention_outputs[-1]
+
+ if query_length > 0:
+ query_attention_output = attention_output[:, :query_length, :]
+
+ if self.has_cross_attention:
+ assert (
+ encoder_hidden_states is not None
+ ), "encoder_hidden_states must be given for cross-attention layers"
+ cross_attention_outputs = self.crossattention(
+ query_attention_output,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+ query_attention_output = cross_attention_outputs[0]
+ outputs = (
+ outputs + cross_attention_outputs[1:-1]
+ ) # add cross attentions if we output attention weights
+
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk_query,
+ self.chunk_size_feed_forward,
+ self.seq_len_dim,
+ query_attention_output,
+ )
+ if attention_output.shape[1] > query_length:
+ layer_output_text = apply_chunking_to_forward(
+ self.feed_forward_chunk,
+ self.chunk_size_feed_forward,
+ self.seq_len_dim,
+ attention_output[:, query_length:, :],
+ )
+ layer_output = torch.cat(
+ [layer_output, layer_output_text], dim=1
+ )
+ else:
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk,
+ self.chunk_size_feed_forward,
+ self.seq_len_dim,
+ attention_output,
+ )
+ outputs = (layer_output,) + outputs
+
+ outputs = outputs + (present_key_value,)
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+ def feed_forward_chunk_query(self, attention_output):
+ intermediate_output = self.intermediate_query(attention_output)
+ layer_output = self.output_query(intermediate_output, attention_output)
+ return layer_output
+
+
+class BertEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList(
+ [BertLayer(config, i) for i in range(config.num_hidden_layers)]
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ query_length=0,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = (
+ ()
+ if output_attentions and self.config.add_cross_attention
+ else None
+ )
+
+ next_decoder_cache = () if use_cache else None
+
+ for i in range(self.config.num_hidden_layers):
+ layer_module = self.layer[i]
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ past_key_value = (
+ past_key_values[i] if past_key_values is not None else None
+ )
+
+ if (
+ getattr(self.config, "gradient_checkpointing", False)
+ and self.training
+ ):
+ if use_cache:
+ logger.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(
+ *inputs,
+ past_key_value,
+ output_attentions,
+ query_length
+ )
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ query_length,
+ )
+
+ hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1],)
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+ all_cross_attentions = all_cross_attentions + (
+ layer_outputs[2],
+ )
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class BertPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(
+ config.hidden_size, eps=config.layer_norm_eps
+ )
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = BertPredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(
+ config.hidden_size, config.vocab_size, bias=False
+ )
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = BertLMPredictionHead(config)
+
+ def forward(self, sequence_output):
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+class BertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = BertConfig
+ base_model_prefix = "bert"
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(
+ mean=0.0, std=self.config.initializer_range
+ )
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+
+class BertModel(BertPreTrainedModel):
+ """
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+ To behave as a decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration
+ set to :obj:`True`. To be used in a Seq2Seq model, the model needs to be initialized with both the :obj:`is_decoder`
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
+ input to the forward pass.
+ """
+
+ def __init__(self, config, add_pooling_layer=False):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = BertEmbeddings(config)
+
+ self.encoder = BertEncoder(config)
+
+ self.pooler = BertPooler(config) if add_pooling_layer else None
+
+ self.init_weights()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ def get_extended_attention_mask(
+ self,
+ attention_mask: Tensor,
+ input_shape: Tuple[int],
+ device: device,
+ is_decoder: bool,
+ has_query: bool = False,
+ ) -> Tensor:
+ """
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+ Arguments:
+ attention_mask (:obj:`torch.Tensor`):
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+ input_shape (:obj:`Tuple[int]`):
+ The shape of the input to the model.
+ device: (:obj:`torch.device`):
+ The device of the input to the model.
+
+ Returns:
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
+ """
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if attention_mask.dim() == 3:
+ extended_attention_mask = attention_mask[:, None, :, :]
+ elif attention_mask.dim() == 2:
+ # Provided a padding mask of dimensions [batch_size, seq_length]
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if is_decoder:
+ batch_size, seq_length = input_shape
+
+ seq_ids = torch.arange(seq_length, device=device)
+ causal_mask = (
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
+ <= seq_ids[None, :, None]
+ )
+
+ # add a prefix ones mask to the causal mask
+ # causal and attention masks must have same type with pytorch version < 1.3
+ causal_mask = causal_mask.to(attention_mask.dtype)
+
+ if causal_mask.shape[1] < attention_mask.shape[1]:
+ prefix_seq_len = (
+ attention_mask.shape[1] - causal_mask.shape[1]
+ )
+ if has_query: # UniLM style attention mask
+ causal_mask = torch.cat(
+ [
+ torch.zeros(
+ (batch_size, prefix_seq_len, seq_length),
+ device=device,
+ dtype=causal_mask.dtype,
+ ),
+ causal_mask,
+ ],
+ axis=1,
+ )
+ causal_mask = torch.cat(
+ [
+ torch.ones(
+ (
+ batch_size,
+ causal_mask.shape[1],
+ prefix_seq_len,
+ ),
+ device=device,
+ dtype=causal_mask.dtype,
+ ),
+ causal_mask,
+ ],
+ axis=-1,
+ )
+ extended_attention_mask = (
+ causal_mask[:, None, :, :]
+ * attention_mask[:, None, None, :]
+ )
+ else:
+ extended_attention_mask = attention_mask[:, None, None, :]
+ else:
+ raise ValueError(
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+ input_shape, attention_mask.shape
+ )
+ )
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
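+ # For example (sketch): a padding mask [1.0, 1.0, 0.0] becomes
+ # (1.0 - [1.0, 1.0, 0.0]) * -10000.0 = [0.0, 0.0, -10000.0], which is then
+ # added to the raw attention scores.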
+ extended_attention_mask = extended_attention_mask.to(
+ dtype=self.dtype
+ ) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ return extended_attention_mask
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ query_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ is_decoder=False,
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ """
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict
+ if return_dict is not None
+ else self.config.use_return_dict
+ )
+
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if input_ids is None:
+ assert (
+ query_embeds is not None
+ ), "You have to specify query_embeds when input_ids is None"
+
+ # past_key_values_length
+ past_key_values_length = (
+ past_key_values[0][0].shape[2] - self.config.query_length
+ if past_key_values is not None
+ else 0
+ )
+
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ query_embeds=query_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+
+ input_shape = embedding_output.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = embedding_output.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ ((batch_size, seq_length + past_key_values_length)),
+ device=device,
+ )
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if is_decoder:
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask,
+ input_ids.shape,
+ device,
+ is_decoder,
+ has_query=(query_embeds is not None),
+ )
+ else:
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask, input_shape, device, is_decoder
+ )
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if encoder_hidden_states is not None:
+ if type(encoder_hidden_states) == list:
+ (
+ encoder_batch_size,
+ encoder_sequence_length,
+ _,
+ ) = encoder_hidden_states[0].size()
+ else:
+ (
+ encoder_batch_size,
+ encoder_sequence_length,
+ _,
+ ) = encoder_hidden_states.size()
+ encoder_hidden_shape = (
+ encoder_batch_size,
+ encoder_sequence_length,
+ )
+
+ if type(encoder_attention_mask) == list:
+ encoder_extended_attention_mask = [
+ self.invert_attention_mask(mask)
+ for mask in encoder_attention_mask
+ ]
+ elif encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(
+ encoder_hidden_shape, device=device
+ )
+ encoder_extended_attention_mask = self.invert_attention_mask(
+ encoder_attention_mask
+ )
+ else:
+ encoder_extended_attention_mask = self.invert_attention_mask(
+ encoder_attention_mask
+ )
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(
+ head_mask, self.config.num_hidden_layers
+ )
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ query_length=query_length,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = (
+ self.pooler(sequence_output) if self.pooler is not None else None
+ )
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
+class BertLMHeadModel(BertPreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [
+ r"position_ids",
+ r"predictions.decoder.bias",
+ ]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.cls = BertOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ query_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ labels=None,
+ past_key_values=None,
+ use_cache=True,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ return_logits=False,
+ is_decoder=True,
+ reduction="mean",
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
+ ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ Returns:
+ Example::
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
+ >>> import torch
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> prediction_logits = outputs.logits
+ """
+ return_dict = (
+ return_dict
+ if return_dict is not None
+ else self.config.use_return_dict
+ )
+ if labels is not None:
+ use_cache = False
+ if past_key_values is not None:
+ query_embeds = None
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ query_embeds=query_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ is_decoder=is_decoder,
+ )
+
+ sequence_output = outputs[0]
+ if query_embeds is not None:
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
+
+ prediction_scores = self.cls(sequence_output)
+
+ if return_logits:
+ return prediction_scores[:, :-1, :].contiguous()
+
+ lm_loss = None
+ if labels is not None:
+ # we are doing next-token prediction; shift prediction scores and input ids by one
+ shifted_prediction_scores = prediction_scores[
+ :, :-1, :
+ ].contiguous()
+ labels = labels[:, 1:].contiguous()
+ loss_fct = CrossEntropyLoss(
+ reduction=reduction, label_smoothing=0.1
+ )
+ lm_loss = loss_fct(
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
+ labels.view(-1),
+ )
+ if reduction == "none":
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=lm_loss,
+ logits=prediction_scores,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ query_embeds,
+ past=None,
+ attention_mask=None,
+ **model_kwargs
+ ):
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_ids.shape)
+ query_mask = input_ids.new_ones(query_embeds.shape[:-1])
+ attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
+
+ # cut decoder_input_ids if past is used
+ if past is not None:
+ input_ids = input_ids[:, -1:]
+
+ return {
+ "input_ids": input_ids,
+ "query_embeds": query_embeds,
+ "attention_mask": attention_mask,
+ "past_key_values": past,
+ "encoder_hidden_states": model_kwargs.get(
+ "encoder_hidden_states", None
+ ),
+ "encoder_attention_mask": model_kwargs.get(
+ "encoder_attention_mask", None
+ ),
+ "is_decoder": True,
+ }
+
+ def _reorder_cache(self, past, beam_idx):
+ reordered_past = ()
+ for layer_past in past:
+ reordered_past += (
+ tuple(
+ past_state.index_select(0, beam_idx)
+ for past_state in layer_past
+ ),
+ )
+ return reordered_past
+
+
+class BertForMaskedLM(BertPreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [
+ r"position_ids",
+ r"predictions.decoder.bias",
+ ]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.cls = BertOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ query_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ return_logits=False,
+ is_decoder=False,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+ """
+
+ return_dict = (
+ return_dict
+ if return_dict is not None
+ else self.config.use_return_dict
+ )
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ query_embeds=query_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ is_decoder=is_decoder,
+ )
+
+ # strip the query tokens so the MLM head only scores text positions;
+ # fall back to the full sequence when no query embeddings are given
+ sequence_output = outputs[0]
+ if query_embeds is not None:
+ sequence_output = sequence_output[:, query_embeds.shape[1] :, :]
+ prediction_scores = self.cls(sequence_output)
+
+ if return_logits:
+ return prediction_scores
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
+ masked_lm_loss = loss_fct(
+ prediction_scores.view(-1, self.config.vocab_size),
+ labels.view(-1),
+ )
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return (
+ ((masked_lm_loss,) + output)
+ if masked_lm_loss is not None
+ else output
+ )
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/apps/language_models/src/pipelines/minigpt4_utils/blip_processors.py b/apps/language_models/src/pipelines/minigpt4_utils/blip_processors.py
new file mode 100644
index 0000000000..8c10c65916
--- /dev/null
+++ b/apps/language_models/src/pipelines/minigpt4_utils/blip_processors.py
@@ -0,0 +1,68 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+from omegaconf import OmegaConf
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+
+
+class BaseProcessor:
+ def __init__(self):
+ self.transform = lambda x: x
+ return
+
+ def __call__(self, item):
+ return self.transform(item)
+
+ @classmethod
+ def from_config(cls, cfg=None):
+ return cls()
+
+ def build(self, **kwargs):
+ cfg = OmegaConf.create(kwargs)
+
+ return self.from_config(cfg)
+
+
+class BlipImageBaseProcessor(BaseProcessor):
+ def __init__(self, mean=None, std=None):
+ if mean is None:
+ mean = (0.48145466, 0.4578275, 0.40821073)
+ if std is None:
+ std = (0.26862954, 0.26130258, 0.27577711)
+
+ self.normalize = transforms.Normalize(mean, std)
+
+
+class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
+ def __init__(self, image_size=224, mean=None, std=None):
+ super().__init__(mean=mean, std=std)
+
+ self.transform = transforms.Compose(
+ [
+ transforms.Resize(
+ (image_size, image_size),
+ interpolation=InterpolationMode.BICUBIC,
+ ),
+ transforms.ToTensor(),
+ self.normalize,
+ ]
+ )
+
+ def __call__(self, item):
+ return self.transform(item)
+
+ @classmethod
+ def from_config(cls, cfg=None):
+ if cfg is None:
+ cfg = OmegaConf.create()
+
+ image_size = cfg.get("image_size", 224)
+
+ mean = cfg.get("mean", None)
+ std = cfg.get("std", None)
+
+ return cls(image_size=image_size, mean=mean, std=std)
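+
+
+# Minimal usage sketch (assumed example, not exercised by this module itself):
+# from PIL import Image
+# vis_processor = Blip2ImageEvalProcessor(image_size=224)
+# image_tensor = vis_processor(Image.open("example.jpg").convert("RGB"))
+# # -> float tensor of shape [3, 224, 224], bicubically resized and normalized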
diff --git a/apps/language_models/src/pipelines/minigpt4_utils/configs/cc_sbu_align.yaml b/apps/language_models/src/pipelines/minigpt4_utils/configs/cc_sbu_align.yaml
new file mode 100644
index 0000000000..5710834200
--- /dev/null
+++ b/apps/language_models/src/pipelines/minigpt4_utils/configs/cc_sbu_align.yaml
@@ -0,0 +1,5 @@
+datasets:
+ cc_sbu_align:
+ data_type: images
+ build_info:
+ storage: /path/to/cc_sbu_align/
diff --git a/apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4.yaml b/apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4.yaml
new file mode 100644
index 0000000000..803d0a7ee4
--- /dev/null
+++ b/apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4.yaml
@@ -0,0 +1,33 @@
+model:
+ arch: mini_gpt4
+
+ # vit encoder
+ image_size: 224
+ drop_path_rate: 0
+ use_grad_checkpoint: False
+ vit_precision: "fp16"
+ freeze_vit: True
+ freeze_qformer: True
+
+ # Q-Former
+ num_query_token: 32
+
+ # Vicuna
+ llama_model: "lmsys/vicuna-7b-v1.3"
+
+ # generation configs
+ prompt: ""
+
+preprocess:
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 224
+ eval:
+ name: "blip2_image_eval"
+ image_size: 224
+ text_processor:
+ train:
+ name: "blip_caption"
+ eval:
+ name: "blip_caption"
diff --git a/apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4_eval.yaml b/apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4_eval.yaml
new file mode 100644
index 0000000000..e54f44cc64
--- /dev/null
+++ b/apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4_eval.yaml
@@ -0,0 +1,25 @@
+model:
+ arch: mini_gpt4
+ model_type: pretrain_vicuna
+ freeze_vit: True
+ freeze_qformer: True
+ max_txt_len: 160
+ end_sym: "###"
+ low_resource: False
+ prompt_path: "apps/language_models/src/pipelines/minigpt4_utils/prompts/alignment.txt"
+ prompt_template: '###Human: {} ###Assistant: '
+ ckpt: 'prerained_minigpt4_7b.pth'
+
+
+datasets:
+ cc_sbu_align:
+ vis_processor:
+ train:
+ name: "blip2_image_eval"
+ image_size: 224
+ text_processor:
+ train:
+ name: "blip_caption"
+
+run:
+ task: image_text_pretrain
diff --git a/apps/language_models/src/pipelines/minigpt4_utils/dist_utils.py b/apps/language_models/src/pipelines/minigpt4_utils/dist_utils.py
new file mode 100644
index 0000000000..8310a363ea
--- /dev/null
+++ b/apps/language_models/src/pipelines/minigpt4_utils/dist_utils.py
@@ -0,0 +1,53 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import os
+
+import torch
+import torch.distributed as dist
+import timm.models.hub as timm_hub
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def download_cached_file(url, check_hash=True, progress=False):
+ """
+ Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
+ If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
+ """
+
+ def get_cached_file_path():
+ # a hack to sync the file path across processes
+ parts = torch.hub.urlparse(url)
+ filename = os.path.basename(parts.path)
+ cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
+
+ return cached_file
+
+ if is_main_process():
+ timm_hub.download_cached_file(url, check_hash, progress)
+
+ if is_dist_avail_and_initialized():
+ dist.barrier()
+
+ return get_cached_file_path()
diff --git a/apps/language_models/src/pipelines/minigpt4_utils/eva_vit.py b/apps/language_models/src/pipelines/minigpt4_utils/eva_vit.py
new file mode 100644
index 0000000000..d3782e0d10
--- /dev/null
+++ b/apps/language_models/src/pipelines/minigpt4_utils/eva_vit.py
@@ -0,0 +1,627 @@
+# Based on EVA, BEIT, timm and DeiT code bases
+# https://github.com/baaivision/EVA
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm
+# https://github.com/microsoft/unilm/tree/master/beit
+# https://github.com/facebookresearch/deit/
+# https://github.com/facebookresearch/dino
+# --------------------------------------------------------'
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+
+from apps.language_models.src.pipelines.minigpt4_utils.dist_utils import (
+ download_cached_file,
+)
+
+
+def _cfg(url="", **kwargs):
+ return {
+ "url": url,
+ "num_classes": 1000,
+ "input_size": (3, 224, 224),
+ "pool_size": None,
+ "crop_pct": 0.9,
+ "interpolation": "bicubic",
+ "mean": (0.5, 0.5, 0.5),
+ "std": (0.5, 0.5, 0.5),
+ **kwargs,
+ }
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return "p={}".format(self.drop_prob)
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.0,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ # x = self.drop(x)
+ # dropout after fc1 is commented out here, following the original BERT implementation
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ window_size=None,
+ attn_head_dim=None,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+
+ if window_size:
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
+ 2 * window_size[1] - 1
+ ) + 3
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads)
+ ) # 2*Wh-1 * 2*Ww-1, nH
+ # cls to token & token to cls & cls to cls
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(
+ torch.meshgrid([coords_h, coords_w])
+ ) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = (
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
+ ) # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(
+ 1, 2, 0
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += (
+ window_size[0] - 1
+ ) # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = torch.zeros(
+ size=(window_size[0] * window_size[1] + 1,) * 2,
+ dtype=relative_coords.dtype,
+ )
+ relative_position_index[1:, 1:] = relative_coords.sum(
+ -1
+ ) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer(
+ "relative_position_index", relative_position_index
+ )
+ else:
+ self.window_size = None
+ self.relative_position_bias_table = None
+ self.relative_position_index = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, rel_pos_bias=None):
+ B, N, C = x.shape
+ qkv_bias = None
+ if self.q_bias is not None:
+ qkv_bias = torch.cat(
+ (
+ self.q_bias,
+ torch.zeros_like(self.v_bias, requires_grad=False),
+ self.v_bias,
+ )
+ )
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = (
+ qkv[0],
+ qkv[1],
+ qkv[2],
+ ) # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+
+ if self.relative_position_bias_table is not None:
+ relative_position_bias = self.relative_position_bias_table[
+ self.relative_position_index.view(-1)
+ ].view(
+ self.window_size[0] * self.window_size[1] + 1,
+ self.window_size[0] * self.window_size[1] + 1,
+ -1,
+ ) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(
+ 2, 0, 1
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if rel_pos_bias is not None:
+ attn = attn + rel_pos_bias
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ init_values=None,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ window_size=None,
+ attn_head_dim=None,
+ ):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ window_size=window_size,
+ attn_head_dim=attn_head_dim,
+ )
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = (
+ DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ )
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+
+ if init_values is not None and init_values > 0:
+ self.gamma_1 = nn.Parameter(
+ init_values * torch.ones((dim)), requires_grad=True
+ )
+ self.gamma_2 = nn.Parameter(
+ init_values * torch.ones((dim)), requires_grad=True
+ )
+ else:
+ self.gamma_1, self.gamma_2 = None, None
+
+ def forward(self, x, rel_pos_bias=None):
+ if self.gamma_1 is None:
+ x = x + self.drop_path(
+ self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
+ )
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ else:
+ x = x + self.drop_path(
+ self.gamma_1
+ * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
+ )
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """Image to Patch Embedding"""
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (
+ img_size[0] // patch_size[0]
+ )
+ self.patch_shape = (
+ img_size[0] // patch_size[0],
+ img_size[1] // patch_size[1],
+ )
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.proj = nn.Conv2d(
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
+ )
+
+ def forward(self, x, **kwargs):
+ B, C, H, W = x.shape
+ # FIXME look at relaxing size constraints
+ assert (
+ H == self.img_size[0] and W == self.img_size[1]
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ return x
+
+
+class RelativePositionBias(nn.Module):
+ def __init__(self, window_size, num_heads):
+ super().__init__()
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
+ 2 * window_size[1] - 1
+ ) + 3
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads)
+ ) # 2*Wh-1 * 2*Ww-1, nH
+ # cls to token & token to cls & cls to cls
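+ # Worked example (illustration): for a 16x16 patch grid,
+ # num_relative_distance = (2*16 - 1) * (2*16 - 1) + 3 = 961 + 3 = 964,
+ # where the 3 extra slots cover cls-to-token, token-to-cls and cls-to-cls.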
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = (
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
+ ) # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(
+ 1, 2, 0
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = torch.zeros(
+ size=(window_size[0] * window_size[1] + 1,) * 2,
+ dtype=relative_coords.dtype,
+ )
+ relative_position_index[1:, 1:] = relative_coords.sum(
+ -1
+ ) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer(
+ "relative_position_index", relative_position_index
+ )
+
+ # trunc_normal_(self.relative_position_bias_table, std=.02)
+
+ def forward(self):
+ relative_position_bias = self.relative_position_bias_table[
+ self.relative_position_index.view(-1)
+ ].view(
+ self.window_size[0] * self.window_size[1] + 1,
+ self.window_size[0] * self.window_size[1] + 1,
+ -1,
+ ) # Wh*Ww,Wh*Ww,nH
+ return relative_position_bias.permute(
+ 2, 0, 1
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
+
+
+class VisionTransformer(nn.Module):
+ """Vision Transformer with support for patch or hybrid CNN input stage"""
+
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ num_classes=1000,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.0,
+ norm_layer=nn.LayerNorm,
+ init_values=None,
+ use_abs_pos_emb=True,
+ use_rel_pos_bias=False,
+ use_shared_rel_pos_bias=False,
+ use_mean_pooling=True,
+ init_scale=0.001,
+ use_checkpoint=False,
+ ):
+ super().__init__()
+ self.image_size = img_size
+ self.num_classes = num_classes
+ self.num_features = (
+ self.embed_dim
+ ) = embed_dim # num_features for consistency with other models
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ )
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ if use_abs_pos_emb:
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, num_patches + 1, embed_dim)
+ )
+ else:
+ self.pos_embed = None
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ if use_shared_rel_pos_bias:
+ self.rel_pos_bias = RelativePositionBias(
+ window_size=self.patch_embed.patch_shape, num_heads=num_heads
+ )
+ else:
+ self.rel_pos_bias = None
+ self.use_checkpoint = use_checkpoint
+
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
+ ] # stochastic depth decay rule
+ self.use_rel_pos_bias = use_rel_pos_bias
+ self.blocks = nn.ModuleList(
+ [
+ Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ init_values=init_values,
+ window_size=self.patch_embed.patch_shape
+ if use_rel_pos_bias
+ else None,
+ )
+ for i in range(depth)
+ ]
+ )
+ # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
+ # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
+ # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ if self.pos_embed is not None:
+ trunc_normal_(self.pos_embed, std=0.02)
+ trunc_normal_(self.cls_token, std=0.02)
+ # trunc_normal_(self.mask_token, std=.02)
+ # if isinstance(self.head, nn.Linear):
+ # trunc_normal_(self.head.weight, std=.02)
+ self.apply(self._init_weights)
+ self.fix_init_weight()
+
+ # if isinstance(self.head, nn.Linear):
+ # self.head.weight.data.mul_(init_scale)
+ # self.head.bias.data.mul_(init_scale)
+
+ def fix_init_weight(self):
+ def rescale(param, layer_id):
+ param.div_(math.sqrt(2.0 * layer_id))
+
+ for layer_id, layer in enumerate(self.blocks):
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
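+ # NOTE: the classification head is stripped for feature extraction (see the
+ # commented-out `self.head` above), so `get_classifier` is only meaningful
+ # after `reset_classifier` has created a new head.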
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=""):
+ self.num_classes = num_classes
+ self.head = (
+ nn.Linear(self.embed_dim, num_classes)
+ if num_classes > 0
+ else nn.Identity()
+ )
+
+ def forward_features(self, x):
+ x = self.patch_embed(x)
+ batch_size, seq_len, _ = x.size()
+
+ cls_tokens = self.cls_token.expand(
+ batch_size, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+ if self.pos_embed is not None:
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ rel_pos_bias = (
+ self.rel_pos_bias() if self.rel_pos_bias is not None else None
+ )
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x, rel_pos_bias)
+ else:
+ x = blk(x, rel_pos_bias)
+ return x
+
+ # x = self.norm(x)
+
+ # if self.fc_norm is not None:
+ # t = x[:, 1:, :]
+ # return self.fc_norm(t.mean(1))
+ # else:
+ # return x[:, 0]
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ # x = self.head(x)
+ return x
+
+ def get_intermediate_layers(self, x):
+ x = self.patch_embed(x)
+ batch_size, seq_len, _ = x.size()
+
+ cls_tokens = self.cls_token.expand(
+ batch_size, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+ if self.pos_embed is not None:
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ features = []
+ rel_pos_bias = (
+ self.rel_pos_bias() if self.rel_pos_bias is not None else None
+ )
+ for blk in self.blocks:
+ x = blk(x, rel_pos_bias)
+ features.append(x)
+
+ return features
+
+
+def interpolate_pos_embed(model, checkpoint_model):
+ if "pos_embed" in checkpoint_model:
+ pos_embed_checkpoint = checkpoint_model["pos_embed"].float()
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.patch_embed.num_patches
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int(
+ (pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5
+ )
+ # height (== width) for the new position embedding
+ new_size = int(num_patches**0.5)
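+ # e.g. (illustrative): EVA-g at 224x224 with patch size 14 has
+ # num_patches = (224 // 14) ** 2 = 256, so new_size = 16 and a checkpoint
+ # trained at another resolution is bicubically resized to 16x16 below.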
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ print(
+ "Position interpolate from %dx%d to %dx%d"
+ % (orig_size, orig_size, new_size, new_size)
+ )
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(
+ -1, orig_size, orig_size, embedding_size
+ ).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens,
+ size=(new_size, new_size),
+ mode="bicubic",
+ align_corners=False,
+ )
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model["pos_embed"] = new_pos_embed
+
+
+def convert_weights_to_fp16(model: nn.Module):
+ """Convert applicable model parameters to fp16"""
+
+ def _convert_weights_to_fp16(l):
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+ # l.weight.data = l.weight.data.half()
+ l.weight.data = l.weight.data
+ if l.bias is not None:
+ # l.bias.data = l.bias.data.half()
+ l.bias.data = l.bias.data
+
+ # if isinstance(l, (nn.MultiheadAttention, Attention)):
+ # for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+ # tensor = getattr(l, attr)
+ # if tensor is not None:
+ # tensor.data = tensor.data.half()
+
+ model.apply(_convert_weights_to_fp16)
+
+
+def create_eva_vit_g(
+ img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16"
+):
+ model = VisionTransformer(
+ img_size=img_size,
+ patch_size=14,
+ use_mean_pooling=False,
+ embed_dim=1408,
+ depth=39,
+ num_heads=1408 // 88,
+ mlp_ratio=4.3637,
+ qkv_bias=True,
+ drop_path_rate=drop_path_rate,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ use_checkpoint=use_checkpoint,
+ )
+ url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
+ cached_file = download_cached_file(url, check_hash=False, progress=True)
+ state_dict = torch.load(cached_file, map_location="cpu")
+ interpolate_pos_embed(model, state_dict)
+
+ incompatible_keys = model.load_state_dict(state_dict, strict=False)
+ # print(incompatible_keys)
+
+ if precision == "fp16":
+ # model.to("cuda")
+ convert_weights_to_fp16(model)
+ return model
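+
+
+# Hedged usage sketch (assumed, not called here): build the EVA-g visual
+# encoder used as MiniGPT-4's frozen image backbone and run a dummy batch.
+# visual_encoder = create_eva_vit_g(img_size=224, precision="fp16")
+# features = visual_encoder(torch.randn(1, 3, 224, 224))  # -> [1, 257, 1408]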
diff --git a/apps/language_models/src/pipelines/minigpt4_utils/prompts/alignment.txt b/apps/language_models/src/pipelines/minigpt4_utils/prompts/alignment.txt
new file mode 100644
index 0000000000..38ae75a9ce
--- /dev/null
+++ b/apps/language_models/src/pipelines/minigpt4_utils/prompts/alignment.txt
@@ -0,0 +1,4 @@
+<Img><ImageHere></Img> Describe this image in detail.
+<Img><ImageHere></Img> Take a look at this image and describe what you notice.
+<Img><ImageHere></Img> Please provide a detailed description of the picture.
+<Img><ImageHere></Img> Could you describe the contents of this image for me?
\ No newline at end of file
diff --git a/apps/language_models/utils.py b/apps/language_models/utils.py
index 64fb25f387..c9892ed5dd 100644
--- a/apps/language_models/utils.py
+++ b/apps/language_models/utils.py
@@ -3,6 +3,7 @@
from torch._decomp import get_decompositions
from typing import List
from pathlib import Path
+from shark.shark_downloader import download_public_file
# expects a Path / str as arg
@@ -17,9 +18,23 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
return None
print("Loading vmfb from: ", vmfb_path)
+ print("Device from get_vmfb_from_path - ", device)
shark_module = SharkInference(
None, device=device, mlir_dialect=mlir_dialect
)
shark_module.load_module(vmfb_path)
print("Successfully loaded vmfb")
return shark_module
+
+
+def get_vmfb_from_config(
+ shark_container, model, precision, device, vmfb_path, padding=None
+):
+ vmfb_url = (
+ f"gs://shark_tank/{shark_container}/{model}_{precision}_{device}"
+ )
+ if padding:
+ vmfb_url = vmfb_url + f"_{padding}"
+ vmfb_url = vmfb_url + ".vmfb"
+ download_public_file(vmfb_url, vmfb_path.absolute(), single_file=True)
+ return get_vmfb_from_path(vmfb_path, device, "tm_tensor")
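+
+
+# Illustrative call (sketch; container/model names are placeholders, not a
+# guaranteed shark_tank entry):
+# module = get_vmfb_from_config(
+#     "minigpt4", "vision_model", "fp16", "cuda",
+#     Path("vision_model_fp16_cuda.vmfb"),
+# )
+# which fetches gs://shark_tank/minigpt4/vision_model_fp16_cuda.vmfb and loads it.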
diff --git a/apps/stable_diffusion/shark_studio_imports.py b/apps/stable_diffusion/shark_studio_imports.py
index 811893d086..3d2e7e6a9a 100644
--- a/apps/stable_diffusion/shark_studio_imports.py
+++ b/apps/stable_diffusion/shark_studio_imports.py
@@ -33,6 +33,7 @@
datas += collect_data_files("iree")
datas += collect_data_files("google_cloud_storage")
datas += collect_data_files("shark")
+datas += collect_data_files("timm", include_py_files=True)
datas += collect_data_files("tkinter")
datas += collect_data_files("webview")
datas += collect_data_files("sentencepiece")
diff --git a/apps/stable_diffusion/web/index.py b/apps/stable_diffusion/web/index.py
index d8d6c16ffb..9e6e0fbce3 100644
--- a/apps/stable_diffusion/web/index.py
+++ b/apps/stable_diffusion/web/index.py
@@ -147,6 +147,7 @@ def resource_path(relative_path):
modelmanager_sendto_outpaint,
modelmanager_sendto_upscaler,
stablelm_chat,
+ minigpt4_web,
outputgallery_web,
outputgallery_tab_select,
outputgallery_watch,
@@ -212,8 +213,10 @@ def register_outputgallery_button(button, selectedid, inputs, outputs):
stablelm_chat.render()
with gr.TabItem(label="LoRA Training(Experimental)", id=7):
lora_train_web.render()
+ with gr.TabItem(label="MultiModal (Experimental)", id=8):
+ minigpt4_web.render()
if args.output_gallery:
- with gr.TabItem(label="Output Gallery", id=8) as og_tab:
+ with gr.TabItem(label="Output Gallery", id=9) as og_tab:
outputgallery_web.render()
# extra output gallery configuration
diff --git a/apps/stable_diffusion/web/ui/__init__.py b/apps/stable_diffusion/web/ui/__init__.py
index 08755fa368..2c03ffa156 100644
--- a/apps/stable_diffusion/web/ui/__init__.py
+++ b/apps/stable_diffusion/web/ui/__init__.py
@@ -75,6 +75,7 @@
)
from apps.stable_diffusion.web.ui.lora_train_ui import lora_train_web
from apps.stable_diffusion.web.ui.stablelm_ui import stablelm_chat
+from apps.stable_diffusion.web.ui.minigpt4_ui import minigpt4_web
from apps.stable_diffusion.web.ui.outputgallery_ui import (
outputgallery_web,
outputgallery_tab_select,
diff --git a/apps/stable_diffusion/web/ui/minigpt4_ui.py b/apps/stable_diffusion/web/ui/minigpt4_ui.py
new file mode 100644
index 0000000000..1a49159cdc
--- /dev/null
+++ b/apps/stable_diffusion/web/ui/minigpt4_ui.py
@@ -0,0 +1,165 @@
+# ========================================
+# Gradio Setting
+# ========================================
+import gradio as gr
+from apps.language_models.src.pipelines.minigpt4_pipeline import (
+ MiniGPT4,
+ CONV_VISION,
+)
+
+chat = None
+
+
+def gradio_reset(chat_state, img_list):
+ if chat_state is not None:
+ chat_state.messages = []
+ if img_list is not None:
+ img_list = []
+ return (
+ None,
+ gr.update(value=None, interactive=True),
+ gr.update(
+ placeholder="Please upload your image first", interactive=False
+ ),
+ gr.update(value="Upload & Start Chat", interactive=True),
+ chat_state,
+ img_list,
+ )
+
+
+def upload_img(gr_img, text_input, chat_state, device):
+ global chat
+ if chat is None:
+ from apps.language_models.src.pipelines.minigpt4_pipeline import (
+ MiniGPT4,
+ )
+
+ chat = MiniGPT4(
+ model_name="MiniGPT4",
+ hf_model_path=None,
+ max_new_tokens=30,
+ device=device,
+ precision="fp16",
+ )
+ if gr_img is None:
+ return None, None, gr.update(interactive=True), chat_state, None
+ chat_state = CONV_VISION.copy()
+ img_list = []
+ llm_message = chat.upload_img(gr_img, chat_state, img_list)
+ return (
+ gr.update(interactive=False),
+ gr.update(interactive=True, placeholder="Type and press Enter"),
+ gr.update(value="Start Chatting", interactive=False),
+ chat_state,
+ img_list,
+ )
+
+
+def gradio_ask(user_message, chatbot, chat_state):
+ if len(user_message) == 0:
+ return (
+ gr.update(
+ interactive=True, placeholder="Input should not be empty!"
+ ),
+ chatbot,
+ chat_state,
+ )
+ chat.ask(user_message, chat_state)
+ chatbot = chatbot + [[user_message, None]]
+ return "", chatbot, chat_state
+
+
+def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
+ llm_message = chat.answer(
+ conv=chat_state,
+ img_list=img_list,
+ num_beams=num_beams,
+ temperature=temperature,
+ max_new_tokens=300,
+ max_length=2000,
+ )[0]
+ print(llm_message)
+ print("************")
+ chatbot[-1][1] = llm_message
+ return chatbot, chat_state, img_list
+
+
+title = """MultiModal SHARK (experimental)
"""
+description = """Upload your images and start chatting!
"""
+article = """
+"""
+
+# TODO show examples below
+
+with gr.Blocks() as minigpt4_web:
+ gr.Markdown(title)
+ gr.Markdown(description)
+
+ with gr.Row():
+ with gr.Column(scale=0.5):
+ image = gr.Image(type="pil")
+ upload_button = gr.Button(
+ value="Upload & Start Chat",
+ interactive=True,
+ variant="primary",
+ )
+ clear = gr.Button("Restart")
+
+ num_beams = gr.Slider(
+ minimum=1,
+ maximum=10,
+ value=1,
+ step=1,
+ interactive=True,
+ label="beam search numbers)",
+ )
+
+ temperature = gr.Slider(
+ minimum=0.1,
+ maximum=2.0,
+ value=1.0,
+ step=0.1,
+ interactive=True,
+ label="Temperature",
+ )
+
+ device = gr.Dropdown(
+ label="Device",
+ value="cuda",
+ # only CUDA is supported for now
+ choices=["cuda"],
+ interactive=False,
+ )
+
+ with gr.Column():
+ chat_state = gr.State()
+ img_list = gr.State()
+ chatbot = gr.Chatbot(label="MiniGPT-4")
+ text_input = gr.Textbox(
+ label="User",
+ placeholder="Please upload your image first",
+ interactive=False,
+ )
+
+ upload_button.click(
+ upload_img,
+ [image, text_input, chat_state, device],
+ [image, text_input, upload_button, chat_state, img_list],
+ )
+
+ text_input.submit(
+ gradio_ask,
+ [text_input, chatbot, chat_state],
+ [text_input, chatbot, chat_state],
+ ).then(
+ gradio_answer,
+ [chatbot, chat_state, img_list, num_beams, temperature],
+ [chatbot, chat_state, img_list],
+ )
+ clear.click(
+ gradio_reset,
+ [chat_state, img_list],
+ [chatbot, image, text_input, upload_button, chat_state, img_list],
+ queue=False,
+ )
diff --git a/process_skipfiles.py b/process_skipfiles.py
index 4c7ec1eede..7038f2cb2c 100644
--- a/process_skipfiles.py
+++ b/process_skipfiles.py
@@ -56,3 +56,12 @@
)
else:
print(line, end="")
+
+# For getting around timm's packaging.
+# Refer: https://github.com/pyinstaller/pyinstaller/issues/5673#issuecomment-808731505
+path_to_timm_activations = Path(
+ get_python_lib() + "/timm/layers/activations_jit.py"
+)
+for line in fileinput.input(path_to_timm_activations, inplace=True):
+ if "@torch.jit.script" in line:
+ print("@torch.jit._script_if_tracing", end="\n")
+ else:
+ print(line, end="")
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 1f1ccd2cdc..61850d1acb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -31,6 +31,7 @@ pywebview
sentencepiece
py-cpuinfo
tiktoken # for codegen
+timm # for MiniGPT4
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile