From aaefacf324c040d9c0cc1767841aaafaa2d420aa Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Fri, 16 Jun 2023 15:02:06 +0000 Subject: [PATCH] [MiniGPT4] Add MiniGPT4 to SHARK -- This is the first installment of MiniGPT4 in SHARK. Signed-off-by: Abhishek Varma --- .../src/model_wrappers/minigpt4.py | 433 ++++++ .../src/pipelines/minigpt4_pipeline.py | 1180 +++++++++++++++ .../src/pipelines/minigpt4_utils/Qformer.py | 1297 +++++++++++++++++ .../minigpt4_utils/blip_processors.py | 68 + .../minigpt4_utils/configs/cc_sbu_align.yaml | 5 + .../minigpt4_utils/configs/minigpt4.yaml | 33 + .../minigpt4_utils/configs/minigpt4_eval.yaml | 25 + .../pipelines/minigpt4_utils/dist_utils.py | 53 + .../src/pipelines/minigpt4_utils/eva_vit.py | 627 ++++++++ .../minigpt4_utils/prompts/alignment.txt | 4 + apps/language_models/utils.py | 15 + apps/stable_diffusion/shark_studio_imports.py | 1 + apps/stable_diffusion/web/index.py | 5 +- apps/stable_diffusion/web/ui/__init__.py | 1 + apps/stable_diffusion/web/ui/minigpt4_ui.py | 165 +++ process_skipfiles.py | 11 + requirements.txt | 1 + 17 files changed, 3923 insertions(+), 1 deletion(-) create mode 100644 apps/language_models/src/model_wrappers/minigpt4.py create mode 100644 apps/language_models/src/pipelines/minigpt4_pipeline.py create mode 100644 apps/language_models/src/pipelines/minigpt4_utils/Qformer.py create mode 100644 apps/language_models/src/pipelines/minigpt4_utils/blip_processors.py create mode 100644 apps/language_models/src/pipelines/minigpt4_utils/configs/cc_sbu_align.yaml create mode 100644 apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4.yaml create mode 100644 apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4_eval.yaml create mode 100644 apps/language_models/src/pipelines/minigpt4_utils/dist_utils.py create mode 100644 apps/language_models/src/pipelines/minigpt4_utils/eva_vit.py create mode 100644 apps/language_models/src/pipelines/minigpt4_utils/prompts/alignment.txt create mode 100644 apps/stable_diffusion/web/ui/minigpt4_ui.py diff --git a/apps/language_models/src/model_wrappers/minigpt4.py b/apps/language_models/src/model_wrappers/minigpt4.py new file mode 100644 index 0000000000..6521a0a5fe --- /dev/null +++ b/apps/language_models/src/model_wrappers/minigpt4.py @@ -0,0 +1,433 @@ +import torch +import dataclasses +from enum import auto, Enum +from typing import List, Any +from transformers import StoppingCriteria, StoppingCriteriaList + + +class LayerNorm(torch.nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class VisionModel(torch.nn.Module): + def __init__(self, ln_vision, visual_encoder): + super().__init__() + self.ln_vision = ln_vision + self.visual_encoder = visual_encoder + + def forward(self, image): + image_embeds = self.ln_vision(self.visual_encoder(image)) + return image_embeds + + +class QformerBertModel(torch.nn.Module): + def __init__(self, qformer_bert): + super().__init__() + self.qformer_bert = qformer_bert + + def forward(self, query_tokens, image_embeds, image_atts): + query_output = self.qformer_bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + return query_output.last_hidden_state + + +class FirstLlamaModel(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + print("SHARK: Loading LLAMA Done") + + def 
forward(self, inputs_embeds, position_ids, attention_mask): + print("************************************") + print( + "inputs_embeds: ", + inputs_embeds.shape, + " dtype: ", + inputs_embeds.dtype, + ) + print( + "position_ids: ", + position_ids.shape, + " dtype: ", + position_ids.dtype, + ) + print( + "attention_mask: ", + attention_mask.shape, + " dtype: ", + attention_mask.dtype, + ) + print("************************************") + config = { + "inputs_embeds": inputs_embeds, + "position_ids": position_ids, + "past_key_values": None, + "use_cache": True, + "attention_mask": attention_mask, + } + output = self.model( + **config, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + return_vals = [] + return_vals.append(output.logits) + temp_past_key_values = output.past_key_values + for item in temp_past_key_values: + return_vals.append(item[0]) + return_vals.append(item[1]) + return tuple(return_vals) + + +class SecondLlamaModel(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + print("SHARK: Loading LLAMA Done") + + def forward( + self, + input_ids, + position_ids, + attention_mask, + i1, + i2, + i3, + i4, + i5, + i6, + i7, + i8, + i9, + i10, + i11, + i12, + i13, + i14, + i15, + i16, + i17, + i18, + i19, + i20, + i21, + i22, + i23, + i24, + i25, + i26, + i27, + i28, + i29, + i30, + i31, + i32, + i33, + i34, + i35, + i36, + i37, + i38, + i39, + i40, + i41, + i42, + i43, + i44, + i45, + i46, + i47, + i48, + i49, + i50, + i51, + i52, + i53, + i54, + i55, + i56, + i57, + i58, + i59, + i60, + i61, + i62, + i63, + i64, + ): + print("************************************") + print("input_ids: ", input_ids.shape, " dtype: ", input_ids.dtype) + print( + "position_ids: ", + position_ids.shape, + " dtype: ", + position_ids.dtype, + ) + print( + "attention_mask: ", + attention_mask.shape, + " dtype: ", + attention_mask.dtype, + ) + print("past_key_values: ", i1.shape, i2.shape, i63.shape, i64.shape) + print("past_key_values dtype: ", i1.dtype) + print("************************************") + config = { + "input_ids": input_ids, + "position_ids": position_ids, + "past_key_values": ( + (i1, i2), + ( + i3, + i4, + ), + ( + i5, + i6, + ), + ( + i7, + i8, + ), + ( + i9, + i10, + ), + ( + i11, + i12, + ), + ( + i13, + i14, + ), + ( + i15, + i16, + ), + ( + i17, + i18, + ), + ( + i19, + i20, + ), + ( + i21, + i22, + ), + ( + i23, + i24, + ), + ( + i25, + i26, + ), + ( + i27, + i28, + ), + ( + i29, + i30, + ), + ( + i31, + i32, + ), + ( + i33, + i34, + ), + ( + i35, + i36, + ), + ( + i37, + i38, + ), + ( + i39, + i40, + ), + ( + i41, + i42, + ), + ( + i43, + i44, + ), + ( + i45, + i46, + ), + ( + i47, + i48, + ), + ( + i49, + i50, + ), + ( + i51, + i52, + ), + ( + i53, + i54, + ), + ( + i55, + i56, + ), + ( + i57, + i58, + ), + ( + i59, + i60, + ), + ( + i61, + i62, + ), + ( + i63, + i64, + ), + ), + "use_cache": True, + "attention_mask": attention_mask, + } + output = self.model( + **config, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + return_vals = [] + return_vals.append(output.logits) + temp_past_key_values = output.past_key_values + for item in temp_past_key_values: + return_vals.append(item[0]) + return_vals.append(item[1]) + return tuple(return_vals) + + +class SeparatorStyle(Enum): + """Different separator style.""" + + SINGLE = auto() + TWO = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + + system: str + roles: List[str] + 
messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.SINGLE + sep: str = "###" + sep2: str = None + + skip_next: bool = False + conv_id: Any = None + + def get_prompt(self): + if self.sep_style == SeparatorStyle.SINGLE: + ret = self.system + self.sep + for role, message in self.messages: + if message: + ret += role + ": " + message + self.sep + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + def append_message(self, role, message): + self.messages.append([role, message]) + + def to_gradio_chatbot(self): + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + conv_id=self.conv_id, + ) + + def dict(self): + return { + "system": self.system, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + "conv_id": self.conv_id, + } + + +class StoppingCriteriaSub(StoppingCriteria): + def __init__(self, stops=[], encounters=1): + super().__init__() + self.stops = stops + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): + for stop in self.stops: + if torch.all((stop == input_ids[0][-len(stop) :])).item(): + return True + + return False + + +CONV_VISION = Conversation( + system="Give the following image: ImageContent. " + "You will be able to see the image once I provide it to you. 
Please answer my questions.", + roles=("Human", "Assistant"), + messages=[], + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) diff --git a/apps/language_models/src/pipelines/minigpt4_pipeline.py b/apps/language_models/src/pipelines/minigpt4_pipeline.py new file mode 100644 index 0000000000..13b2e25824 --- /dev/null +++ b/apps/language_models/src/pipelines/minigpt4_pipeline.py @@ -0,0 +1,1180 @@ +from apps.language_models.src.model_wrappers.minigpt4 import ( + LayerNorm, + VisionModel, + QformerBertModel, + FirstLlamaModel, + SecondLlamaModel, + StoppingCriteriaSub, + CONV_VISION, +) +from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase +from apps.language_models.utils import ( + get_vmfb_from_path, + get_vmfb_from_config, +) +from omegaconf import OmegaConf +from pathlib import Path +from shark.shark_downloader import download_public_file +from transformers import LlamaTokenizer, LlamaForCausalLM +from transformers import StoppingCriteriaList +from transformers.generation import GenerationConfig, LogitsProcessorList + +import re +import torch +import os +from PIL import Image +import sys + +# SHARK dependencies +from shark.shark_compile import ( + shark_compile_through_fx, +) +import random +import contextlib +from transformers import BertTokenizer +from transformers import LlamaTokenizer, LlamaForCausalLM +from transformers.generation import GenerationConfig, LogitsProcessorList +import copy +import tempfile + +# QFormer, eva_vit, blip_processor, dist_utils +from apps.language_models.src.pipelines.minigpt4_utils.Qformer import ( + BertConfig, + BertLMHeadModel, +) +from apps.language_models.src.pipelines.minigpt4_utils.dist_utils import ( + download_cached_file, +) +from apps.language_models.src.pipelines.minigpt4_utils.eva_vit import ( + create_eva_vit_g, +) +from apps.language_models.src.pipelines.minigpt4_utils.blip_processors import ( + Blip2ImageEvalProcessor, +) + +import argparse + +parser = argparse.ArgumentParser( + prog="MiniGPT4 runner", + description="runs MiniGPT4", +) + +parser.add_argument( + "--precision", "-p", default="fp16", help="fp32, fp16, int8, int4" +) +parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda") +parser.add_argument( + "--vision_model_vmfb_path", + default=None, + help="path to vision model's vmfb", +) +parser.add_argument( + "--qformer_vmfb_path", + default=None, + help="path to qformer model's vmfb", +) +parser.add_argument( + "--image_path", + type=str, + default="", + help="path to the input image", +) +parser.add_argument( + "--load_mlir_from_shark_tank", + default=False, + action=argparse.BooleanOptionalAction, + help="download precompile mlir from shark tank", +) +parser.add_argument( + "--cli", + default=True, + action=argparse.BooleanOptionalAction, + help="Run model in cli mode", +) +parser.add_argument( + "--compile", + default=False, + action=argparse.BooleanOptionalAction, + help="Compile all models", +) +parser.add_argument( + "--max_length", + type=int, + default=2000, + help="Max length of the entire conversation", +) +parser.add_argument( + "--max_new_tokens", + type=int, + default=300, + help="Maximum no. of new tokens that can be generated for a query", +) + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def is_url(input_url): + """ + Check if an input string is a url. 
look for http(s):// and ignoring the case + """ + is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None + return is_url + + +class MiniGPT4BaseModel(torch.nn.Module): + @classmethod + def from_config(cls, cfg): + vit_model = cfg.get("vit_model", "eva_clip_g") + q_former_model = cfg.get( + "q_former_model", + "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth", + ) + img_size = cfg.get("image_size") + num_query_token = cfg.get("num_query_token") + llama_model = cfg.get("llama_model") + + drop_path_rate = cfg.get("drop_path_rate", 0) + use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) + vit_precision = cfg.get("vit_precision", "fp16") + freeze_vit = cfg.get("freeze_vit", True) + freeze_qformer = cfg.get("freeze_qformer", True) + low_resource = cfg.get("low_resource", False) + device_8bit = cfg.get("device_8bit", 0) + + prompt_path = cfg.get("prompt_path", "") + prompt_template = cfg.get("prompt_template", "") + max_txt_len = cfg.get("max_txt_len", 32) + end_sym = cfg.get("end_sym", "\n") + + model = cls( + vit_model=vit_model, + q_former_model=q_former_model, + img_size=img_size, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + freeze_qformer=freeze_qformer, + num_query_token=num_query_token, + llama_model=llama_model, + prompt_path=prompt_path, + prompt_template=prompt_template, + max_txt_len=max_txt_len, + end_sym=end_sym, + low_resource=low_resource, + device_8bit=device_8bit, + ) + + ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4 + if ckpt_path: + print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path)) + ckpt = torch.load(ckpt_path, map_location="cpu") + model.load_state_dict(ckpt["model"], strict=False) + + return model + + PRETRAINED_MODEL_CONFIG_DICT = { + "pretrain_vicuna": "configs/minigpt4.yaml", + } + + def maybe_autocast(self, dtype=torch.float32): + # if on cpu, don't use autocast + # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 + # enable_autocast = self.device != torch.device("cpu") + enable_autocast = True + + if enable_autocast: + return torch.cuda.amp.autocast(dtype=dtype) + else: + return contextlib.nullcontext() + + def init_tokenizer(cls): + tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + tokenizer.add_special_tokens({"bos_token": "[DEC]"}) + return tokenizer + + def init_vision_encoder( + self, + model_name, + img_size, + drop_path_rate, + use_grad_checkpoint, + precision, + ): + assert ( + model_name == "eva_clip_g" + ), "vit model must be eva_clip_g for current version of MiniGPT-4" + visual_encoder = create_eva_vit_g( + img_size, drop_path_rate, use_grad_checkpoint, precision + ) + + ln_vision = LayerNorm(visual_encoder.num_features) + return visual_encoder, ln_vision + + def init_Qformer( + cls, num_query_token, vision_width, cross_attention_freq=2 + ): + encoder_config = BertConfig.from_pretrained("bert-base-uncased") + encoder_config.encoder_width = vision_width + # insert cross-attention layer every other block + encoder_config.add_cross_attention = True + encoder_config.cross_attention_freq = cross_attention_freq + encoder_config.query_length = num_query_token + Qformer = BertLMHeadModel(config=encoder_config) + query_tokens = torch.nn.Parameter( + torch.zeros(1, num_query_token, encoder_config.hidden_size) + ) + query_tokens.data.normal_( + mean=0.0, std=encoder_config.initializer_range + ) + return Qformer, query_tokens + + 
def load_from_pretrained(self, url_or_filename): + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location="cpu") + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location="cpu") + else: + raise RuntimeError("checkpoint url or path is invalid") + + state_dict = checkpoint["model"] + + self.load_state_dict(state_dict, strict=False) + + def __init__( + self, + vit_model="eva_clip_g", + q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth", + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp16", + freeze_vit=True, + freeze_qformer=True, + num_query_token=32, + llama_model="", + prompt_path="", + prompt_template="", + max_txt_len=32, + end_sym="\n", + low_resource=False, # use 8 bit and put vit in cpu + device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore. + ): + super().__init__() + self.tokenizer = self.init_tokenizer() + self.low_resource = low_resource + + print("Loading VIT") + self.visual_encoder, self.ln_vision = self.init_vision_encoder( + vit_model, + img_size, + drop_path_rate, + use_grad_checkpoint, + vit_precision, + ) + if freeze_vit: + for _, param in self.visual_encoder.named_parameters(): + param.requires_grad = False + self.visual_encoder = self.visual_encoder.eval() + self.visual_encoder.train = disabled_train + for _, param in self.ln_vision.named_parameters(): + param.requires_grad = False + self.ln_vision = self.ln_vision.eval() + self.ln_vision.train = disabled_train + # logging.info("freeze vision encoder") + print("Loading VIT Done") + + print("Loading Q-Former") + self.Qformer, self.query_tokens = self.init_Qformer( + num_query_token, self.visual_encoder.num_features + ) + self.Qformer.cls = None + self.Qformer.bert.embeddings.word_embeddings = None + self.Qformer.bert.embeddings.position_embeddings = None + for layer in self.Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + self.load_from_pretrained(url_or_filename=q_former_model) + + if freeze_qformer: + for _, param in self.Qformer.named_parameters(): + param.requires_grad = False + self.Qformer = self.Qformer.eval() + self.Qformer.train = disabled_train + self.query_tokens.requires_grad = False + # logging.info("freeze Qformer") + print("Loading Q-Former Done") + + print(f"Loading Llama model from {llama_model}") + self.llama_tokenizer = LlamaTokenizer.from_pretrained( + llama_model, use_fast=False + ) + # self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token + + if self.low_resource: + self.llama_model = LlamaForCausalLM.from_pretrained( + llama_model, + torch_dtype=torch.float16, + load_in_8bit=True, + device_map={"": device_8bit}, + ) + else: + self.llama_model = LlamaForCausalLM.from_pretrained( + llama_model, + torch_dtype=torch.float32, + ) + + print( + "During init :-\nLlama model pad token : ", + self.llama_model.config.pad_token_id, + ) + print( + "Llama tokenizer pad token : ", self.llama_tokenizer.pad_token_id + ) + + for _, param in self.llama_model.named_parameters(): + param.requires_grad = False + print("Loading Llama Done") + + self.llama_proj = torch.nn.Linear( + self.Qformer.config.hidden_size, + self.llama_model.config.hidden_size, + ) + self.max_txt_len = max_txt_len + self.end_sym = end_sym + + if prompt_path: + with open(prompt_path, "r") as f: + 
raw_prompts = f.read().splitlines() + filted_prompts = [ + raw_prompt + for raw_prompt in raw_prompts + if "" in raw_prompt + ] + self.prompt_list = [ + prompt_template.format(p) for p in filted_prompts + ] + print("Load {} training prompts".format(len(self.prompt_list))) + print( + "Prompt Example \n{}".format(random.choice(self.prompt_list)) + ) + else: + self.prompt_list = [] + + +class MiniGPT4(SharkLLMBase): + def __init__( + self, + model_name, + hf_model_path=None, + max_new_tokens=300, + device="cuda", + precision="fp16", + _compile=False, + vision_model_vmfb_path=Path("vision_model_fp16_cuda.vmfb"), + qformer_vmfb_path=Path("qformer_fp32_cuda.vmfb"), + ) -> None: + self.model_name = model_name + self.shark_model = None + super().__init__(model_name, hf_model_path, max_new_tokens) + self.download_dependencies() + self.device = device + self.precision = precision + self._compile = _compile + + self.vision_model_vmfb_path = vision_model_vmfb_path + self.qformer_vmfb_path = qformer_vmfb_path + self.first_llama_vmfb_path = None + self.second_llama_vmfb_path = None + + print("Initializing Chat") + config = OmegaConf.load( + "apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4_eval.yaml" + ) + model_config = OmegaConf.create() + model_config = OmegaConf.merge( + model_config, + OmegaConf.load( + "apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4.yaml" + ), + {"model": config["model"]}, + ) + model_config = model_config["model"] + model_config.device_8bit = 0 + model = MiniGPT4BaseModel.from_config(model_config).to("cpu") + datasets = config.get("datasets", None) + dataset_config = OmegaConf.create() + for dataset_name in datasets: + dataset_config_path = "apps/language_models/src/pipelines/minigpt4_utils/configs/cc_sbu_align.yaml" + dataset_config = OmegaConf.merge( + dataset_config, + OmegaConf.load(dataset_config_path), + {"datasets": {dataset_name: config["datasets"][dataset_name]}}, + ) + dataset_config = dataset_config["datasets"] + vis_processor_cfg = dataset_config.cc_sbu_align.vis_processor.train + vis_processor = Blip2ImageEvalProcessor.from_config(vis_processor_cfg) + print("Initialization complete") + + self.model = model + self.vis_processor = vis_processor + stop_words_ids = [ + torch.tensor([835]).to("cpu"), + torch.tensor([2277, 29937]).to("cpu"), + ] # '###' can be encoded in two different ways. + self.stopping_criteria = StoppingCriteriaList( + [StoppingCriteriaSub(stops=stop_words_ids)] + ) + + self.first_llama = None + self.second_llama = None + + def download_dependencies(self): + pretrained_file = "prerained_minigpt4_7b.pth" + pretrained_file_url = f"gs://shark_tank/MiniGPT4/{pretrained_file}" + if not os.path.isfile(pretrained_file): + download_public_file( + pretrained_file_url, + Path("prerained_minigpt4_7b.pth").absolute(), + single_file=True, + ) + + if os.path.isfile(pretrained_file): + print(f"File downloaded successfully: {pretrained_file}") + else: + print(f"Error downloading {pretrained_file}") + sys.exit() + + # Currently we're compiling VisionModel for fp32/cuda. 
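+    # Each compile_* helper below follows the same lookup order: unless
+    # --compile is passed, first try to load an existing .vmfb from disk,
+    # then try a precompiled artifact via get_vmfb_from_config, and only
+    # fall back to compiling the wrapped module with shark_compile_through_fx.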
+ def compile_vision_model(self): + if not self._compile: + vmfb = get_vmfb_from_path( + self.vision_model_vmfb_path, self.device, "tm_tensor" + ) + if vmfb is not None: + return vmfb + else: + vmfb = get_vmfb_from_config( + self.model_name, + "vision_model", + self.precision, + self.device, + self.vision_model_vmfb_path, + ) + if vmfb is not None: + return vmfb + + visionModel = VisionModel( + self.model.ln_vision, self.model.visual_encoder + ) + extended_model_name = f"vision_model_{self.precision}_{self.device}" + print(f"Going to compile {extended_model_name}") + # Inputs for VisionModel. + inputs = [torch.randint(3, (1, 3, 224, 224), dtype=torch.float32)] + is_f16 = False + if self.precision == "fp16": + is_f16 = True + shark_visionModel, _ = shark_compile_through_fx( + visionModel, + inputs, + extended_model_name=extended_model_name, + is_f16=is_f16, + f16_input_mask=None, + save_dir=tempfile.gettempdir(), + debug=False, + generate_or_load_vmfb=True, + extra_args=[], + device=self.device, + mlir_dialect="tm_tensor", + ) + print(f"Generated {extended_model_name}.vmfb") + return shark_visionModel + + def compile_qformer_model(self): + if not self._compile: + vmfb = get_vmfb_from_path( + self.qformer_vmfb_path, self.device, "tm_tensor" + ) + if vmfb is not None: + return vmfb + else: + vmfb = get_vmfb_from_config( + self.model_name, + "qformer", + "fp32", + self.device, + self.qformer_vmfb_path, + ) + if vmfb is not None: + return vmfb + + qformerBertModel = QformerBertModel(self.model.Qformer.bert) + extended_model_name = f"qformer_fp32_{self.device}" + print(f"Going to compile {extended_model_name}") + # Inputs for QFormer. + inputs = [ + torch.randint(3, (1, 32, 768), dtype=torch.float32), + torch.randint(3, (1, 257, 1408), dtype=torch.float32), + torch.randint(3, (1, 257), dtype=torch.int64), + ] + is_f16 = False + f16_input_mask = [] + shark_QformerBertModel, _ = shark_compile_through_fx( + qformerBertModel, + inputs, + extended_model_name=extended_model_name, + is_f16=is_f16, + f16_input_mask=f16_input_mask, + save_dir=tempfile.gettempdir(), + debug=False, + generate_or_load_vmfb=True, + extra_args=[], + device=self.device, + mlir_dialect="tm_tensor", + ) + print(f"Generated {extended_model_name}.vmfb") + return shark_QformerBertModel + + def compile_first_llama(self, padding): + self.first_llama_vmfb_path = Path( + f"first_llama_{self.precision}_{self.device}_{padding}.vmfb" + ) + if not self._compile: + vmfb = get_vmfb_from_path( + self.first_llama_vmfb_path, self.device, "tm_tensor" + ) + if vmfb is not None: + self.first_llama = vmfb + return vmfb + else: + vmfb = get_vmfb_from_config( + self.model_name, + "first_llama", + self.precision, + self.device, + self.first_llama_vmfb_path, + padding, + ) + if vmfb is not None: + self.first_llama = vmfb + return vmfb + + firstLlamaModel = FirstLlamaModel( + copy.deepcopy(self.model.llama_model) + ) + extended_model_name = ( + f"first_llama_{self.precision}_{self.device}_{padding}" + ) + print(f"Going to compile {extended_model_name}") + # Inputs for FirstLlama. 
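+        # Static shapes for the prefill graph: the prompt embedding sequence
+        # is padded to `padding` (= max_length - max_new_tokens) tokens, so
+        # the compiled module always sees (1, padding, 4096) embeddings (4096
+        # is the LLaMA-7B hidden size). The tensors built here are only shape
+        # placeholders; the real attention mask computed in answer() zeroes
+        # out the padded positions at run time.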
+ inputs_embeds = torch.ones((1, padding, 4096), dtype=torch.float32) + position_ids = torch.ones((1, padding), dtype=torch.int64) + attention_mask = torch.ones((1, padding), dtype=torch.int32) + inputs = [inputs_embeds, position_ids, attention_mask] + is_f16 = False + f16_input_mask = [] + if self.precision == "fp16": + is_f16 = True + f16_input_mask = [True, False, False] + shark_firstLlamaModel, _ = shark_compile_through_fx( + firstLlamaModel, + inputs, + extended_model_name=extended_model_name, + is_f16=is_f16, + f16_input_mask=f16_input_mask, + save_dir=tempfile.gettempdir(), + debug=False, + generate_or_load_vmfb=True, + extra_args=[], + device=self.device, + mlir_dialect="tm_tensor", + ) + print(f"Generated {extended_model_name}.vmfb") + self.first_llama = shark_firstLlamaModel + return shark_firstLlamaModel + + def compile_second_llama(self, padding): + self.second_llama_vmfb_path = Path( + f"second_llama_{self.precision}_{self.device}_{padding}.vmfb" + ) + if not self._compile: + vmfb = get_vmfb_from_path( + self.second_llama_vmfb_path, self.device, "tm_tensor" + ) + if vmfb is not None: + self.second_llama = vmfb + return vmfb + else: + vmfb = get_vmfb_from_config( + self.model_name, + "second_llama", + self.precision, + self.device, + self.second_llama_vmfb_path, + padding, + ) + if vmfb is not None: + self.second_llama = vmfb + return vmfb + + secondLlamaModel = SecondLlamaModel( + copy.deepcopy(self.model.llama_model) + ) + extended_model_name = ( + f"second_llama_{self.precision}_{self.device}_{padding}" + ) + print(f"Going to compile {extended_model_name}") + # Inputs for SecondLlama. + input_ids = torch.zeros((1, 1), dtype=torch.int64) + position_ids = torch.zeros((1, 1), dtype=torch.int64) + attention_mask = torch.zeros((1, padding + 1), dtype=torch.int32) + past_key_value = [] + for i in range(64): + past_key_value.append( + torch.zeros(1, 32, padding, 128, dtype=torch.float32) + ) + inputs = [input_ids, position_ids, attention_mask, *past_key_value] + is_f16 = False + f16_input_mask = [] + if self.precision == "fp16": + is_f16 = True + f16_input_mask = [False, False, False] + for i in past_key_value: + f16_input_mask.append(True) + + shark_secondLlamaModel, _ = shark_compile_through_fx( + secondLlamaModel, + inputs, + extended_model_name=extended_model_name, + is_f16=is_f16, + f16_input_mask=f16_input_mask, + save_dir=tempfile.gettempdir(), + debug=False, + generate_or_load_vmfb=True, + extra_args=[], + device=self.device, + mlir_dialect="tm_tensor", + ) + print(f"Generated {extended_model_name}.vmfb") + self.second_llama = shark_secondLlamaModel + return shark_secondLlamaModel + + # Not yet sure why to use this. + def compile(self): + pass + + # Going to use `answer` instead. + def generate(self, prompt): + pass + + # Might use within `answer`, if needed. + def generate_new_token(self, params): + pass + + # Not needed yet because MiniGPT4BaseModel already loads this - will revisit later, + # if required. + def get_tokenizer(self): + pass + + # DumDum func - doing the intended stuff already at MiniGPT4BaseModel, + # i.e load llama, etc. + def get_src_model(self): + pass + + def ask(self, text, conv): + if ( + len(conv.messages) > 0 + and conv.messages[-1][0] == conv.roles[0] + and conv.messages[-1][1][-6:] == "" + ): # last message is image. 
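+            # The previous user turn still carries the image placeholder, so
+            # fold this question into it rather than appending a new turn.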
+ conv.messages[-1][1] = " ".join([conv.messages[-1][1], text]) + else: + conv.append_message(conv.roles[0], text) + + def answer( + self, + conv, + img_list, + max_new_tokens=300, + num_beams=1, + min_length=1, + top_p=0.9, + repetition_penalty=1.0, + length_penalty=1, + temperature=1.0, + max_length=2000, + ): + conv.append_message(conv.roles[1], None) + embs = self.get_context_emb( + conv, img_list, max_length - max_new_tokens + ) + padding = max_length - max_new_tokens + + current_max_len = embs.shape[1] + max_new_tokens + + if current_max_len - max_length > 0: + print( + "Warning: The number of tokens in current conversation exceeds the max length. " + "The model will not see the contexts outside the range." + ) + begin_idx = max(0, current_max_len - max_length) + + embs = embs[:, begin_idx:] + + ######################################################################################################### + + generation_config = GenerationConfig.from_model_config( + self.model.llama_model.config + ) + kwargs = { + "inputs_embeds": embs, + "max_new_tokens": max_new_tokens, + "num_beams": num_beams, + "do_sample": True, + "min_length": min_length, + "top_p": top_p, + "repetition_penalty": repetition_penalty, + "length_penalty": length_penalty, + "temperature": temperature, + } + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + logits_processor = LogitsProcessorList() + stopping_criteria = self.stopping_criteria + inputs = None + ( + inputs_tensor, + model_input_name, + model_kwargs, + ) = self.model.llama_model._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + model_kwargs["output_attentions"] = generation_config.output_attentions + model_kwargs[ + "output_hidden_states" + ] = generation_config.output_hidden_states + model_kwargs["use_cache"] = generation_config.use_cache + generation_config.pad_token_id = ( + self.model.llama_tokenizer.pad_token_id + ) + pad_token_id = generation_config.pad_token_id + embs_for_pad_token_id = self.model.llama_model.model.embed_tokens( + torch.tensor([pad_token_id]) + ) + model_kwargs["attention_mask"] = torch.logical_not( + torch.tensor( + [ + torch.all( + torch.eq(inputs_tensor[:, d, :], embs_for_pad_token_id) + ).int() + for d in range(inputs_tensor.shape[1]) + ] + ).unsqueeze(0) + ).int() + attention_meta_data = (model_kwargs["attention_mask"][0] == 0).nonzero( + as_tuple=True + )[0] + first_zero = attention_meta_data[0].item() + last_zero = attention_meta_data[-1].item() + input_ids = ( + inputs_tensor + if model_input_name == "input_ids" + else model_kwargs.pop("input_ids") + ) + input_ids_seq_length = input_ids.shape[-1] + generation_config.max_length = ( + generation_config.max_new_tokens + input_ids_seq_length + ) + logits_warper = self.model.llama_model._get_logits_warper( + generation_config + ) + ( + input_ids, + model_kwargs, + ) = self.model.llama_model._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=False, + **model_kwargs, + ) + # DOUBT: stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + logits_warper = ( + logits_warper + if logits_warper is not None + else LogitsProcessorList() + ) + pad_token_id = generation_config.pad_token_id + eos_token_id = generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id_tensor = ( + torch.tensor(eos_token_id).to(input_ids.device) + if eos_token_id is not 
None + else None + ) + scores = None + + # keep track of which sequences are already finished + unfinished_sequences = torch.ones( + input_ids.shape[0], dtype=torch.long, device=input_ids.device + ) + i = 0 + timesRan = 0 + is_fp16 = True + llama_list = [] + isPyTorchVariant = False + while True: + print("****** Iteration %d ******" % (i)) + # prepare model inputs + model_inputs = ( + self.model.llama_model.prepare_inputs_for_generation( + input_ids, **model_kwargs + ) + ) + + # forward pass to get next token + if i == 0: + shark_inputs = [] + if is_fp16: + model_inputs["inputs_embeds"] = model_inputs[ + "inputs_embeds" + ].to(torch.float16) + shark_inputs.append(model_inputs["inputs_embeds"].detach()) + shark_inputs.append(model_inputs["position_ids"].detach()) + shark_inputs.append(model_inputs["attention_mask"].detach()) + + if self.first_llama is None: + self.compile_first_llama(padding) + outputs_shark = self.first_llama("forward", shark_inputs) + outputs = [] + for out_shark in outputs_shark: + outputs.append(torch.from_numpy(out_shark)) + del outputs_shark + else: + shark_inputs = [] + shark_inputs.append(model_inputs["input_ids"].detach()) + shark_inputs.append(model_inputs["position_ids"].detach()) + shark_inputs.append(model_inputs["attention_mask"].detach()) + for pkv in list(model_inputs["past_key_values"]): + shark_inputs.append(pkv.detach()) + if self.second_llama is None: + self.compile_second_llama(padding) + outputs_shark = self.second_llama("forward", shark_inputs) + outputs = [] + for out_shark in outputs_shark: + outputs.append(torch.from_numpy(out_shark)) + del outputs_shark + + outputs_logits = outputs[0] + next_token_logits = outputs_logits[:, -1, :] + if is_fp16: + next_token_logits = next_token_logits.to(torch.float32) + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + probs = torch.nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError( + "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." 
+ ) + next_tokens = ( + next_tokens * unfinished_sequences + + pad_token_id * (1 - unfinished_sequences) + ) + + # update generated ids, model inputs, and length for next step + outputs_for_update_func = {} + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = ( + self.model.llama_model._update_model_kwargs_for_generation( + outputs_for_update_func, + model_kwargs, + is_encoder_decoder=False, + ) + ) + model_kwargs["past_key_values"] = outputs[1:] + if timesRan >= 1: + tmp_attention_mask = torch.cat( + ( + model_kwargs["attention_mask"][:, :first_zero], + model_kwargs["attention_mask"][:, first_zero + 1 :], + ), + dim=1, + ) + model_kwargs["attention_mask"] = tmp_attention_mask + pkv_list = [] + for pkv_pair_tuple in model_kwargs["past_key_values"]: + x = torch.cat( + ( + pkv_pair_tuple[:, :, :first_zero, :], + pkv_pair_tuple[:, :, first_zero + 1 :, :], + ), + dim=2, + ) + if is_fp16: + x = x.to(torch.float16) + pkv_list.append(x) + model_kwargs["past_key_values"] = tuple(pkv_list) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id_tensor is not None: + unfinished_sequences = unfinished_sequences.mul( + next_tokens.tile(eos_token_id_tensor.shape[0], 1) + .ne(eos_token_id_tensor.unsqueeze(1)) + .prod(dim=0) + ) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria( + input_ids, scores + ): + break + + i = i + 1 + timesRan += 1 + llama_list.clear() + output_token = input_ids[0] + + if ( + output_token[0] == 0 + ): # the model might output a unknow token at the beginning. remove it + output_token = output_token[1:] + if ( + output_token[0] == 1 + ): # some users find that there is a start token at the beginning. 
remove it + output_token = output_token[1:] + output_text = self.model.llama_tokenizer.decode( + output_token, add_special_tokens=False + ) + output_text = output_text.split("###")[0] # remove the stop sign '###' + output_text = output_text.split("Assistant:")[-1].strip() + conv.messages[-1][1] = output_text + return output_text, output_token.cpu().numpy() + + def upload_img(self, image, conv, img_list): + if isinstance(image, str): # is a image path + raw_image = Image.open(image).convert("RGB") + image = self.vis_processor(raw_image).unsqueeze(0).to("cpu") + elif isinstance(image, Image.Image): + raw_image = image + image = self.vis_processor(raw_image).unsqueeze(0).to("cpu") + elif isinstance(image, torch.Tensor): + if len(image.shape) == 3: + image = image.unsqueeze(0) + image = image.to("cpu") + + device = image.device + if self.model.low_resource: + self.model.vit_to_cpu() + image = image.to("cpu") + + with self.model.maybe_autocast(): + shark_visionModel = self.compile_vision_model() + if self.precision == "fp16": + image = image.to(torch.float16) + image_embeds = shark_visionModel("forward", (image,)) + # image_embeds = shark_visionModel.forward(image) + image_embeds = torch.from_numpy(image_embeds) + image_embeds = image_embeds.to(device).to(torch.float32) + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long + ).to(device) + + query_tokens = self.model.query_tokens.expand( + image_embeds.shape[0], -1, -1 + ).to(device) + # if self.precision == "fp16": + # query_tokens = query_tokens.to(torch.float16) + shark_QformerBertModel = self.compile_qformer_model() + query_output = shark_QformerBertModel( + "forward", + ( + query_tokens, + image_embeds, + image_atts, + ), + ) + # query_output = shark_QformerBertModel.forward(query_tokens, image_embeds, image_atts) + query_output = torch.from_numpy(query_output) + + inputs_llama = self.model.llama_proj(query_output) + image_emb = inputs_llama + img_list.append(image_emb) + conv.append_message(conv.roles[0], "") + msg = "Received." + return msg + + # """ + def get_context_emb(self, conv, img_list, max_allowed_tokens=200): + self.model.llama_tokenizer.padding_side = "left" + prompt = conv.get_prompt() + prompt_segs = prompt.split("") + assert ( + len(prompt_segs) == len(img_list) + 1 + ), "Unmatched numbers of image placeholders and images." 
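+        # The segments before the last image placeholder are tokenized and
+        # interleaved with the projected image embeddings; the trailing text
+        # segment is then left-padded (padding_side is "left" above) to the
+        # remainder of max_allowed_tokens, so the prefill graph sees the
+        # fixed sequence length it was compiled for.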
+ prompt_segs_pre = prompt_segs[:-1] + seg_tokens_pre = [] + for i, seg in enumerate(prompt_segs_pre): + # only add bos to the first seg + if i == 0: + add_special_tokens = True + else: + add_special_tokens = False + stp = ( + self.model.llama_tokenizer( + seg, + return_tensors="pt", + add_special_tokens=add_special_tokens, + ) + .to("cpu") + .input_ids + ) + seg_tokens_pre.append(stp) + # seg_tokens_pre = [ + # self.model.llama_tokenizer( + # seg, return_tensors="pt", add_special_tokens=i == 0 + # ) + # .to("cpu") + # .input_ids + # for i, seg in enumerate(prompt_segs_pre) + # ] + print( + "Before :-\nLlama model pad token : ", + self.model.llama_model.config.pad_token_id, + ) + print( + "Llama tokenizer pad token : ", + self.model.llama_tokenizer.pad_token_id, + ) + self.model.llama_model.config.pad_token_id = ( + self.model.llama_tokenizer.pad_token_id + ) + print( + "After :-\nLlama model pad token : ", + self.model.llama_model.config.pad_token_id, + ) + print( + "Llama tokenizer pad token : ", + self.model.llama_tokenizer.pad_token_id, + ) + print("seg_t :", seg_tokens_pre[0]) + + seg_embs_pre = [ + self.model.llama_model.model.embed_tokens(seg_t) + for seg_t in seg_tokens_pre + ] + mixed_embs_pre = [ + emb.to("cpu") + for pair in zip(seg_embs_pre, img_list) + for emb in pair + ] + mixed_embs_pre = torch.cat(mixed_embs_pre, dim=1) + max_allowed_tokens = max_allowed_tokens - mixed_embs_pre.shape[1] + final_prompt = prompt_segs[-1] + seg_tokens_post = [ + self.model.llama_tokenizer( + seg, + return_tensors="pt", + padding="max_length", + max_length=max_allowed_tokens, + add_special_tokens=False, + ) + .to("cpu") + .input_ids + # only add bos to the first seg + for i, seg in enumerate([final_prompt]) + ] + seg_tokens_post = seg_tokens_post[0] + seg_embs_post = [ + self.model.llama_model.model.embed_tokens(seg_t) + for seg_t in seg_tokens_post + ] + mixed_embs_post = [seg_embs_post[0].to("cpu")] + mixed_embs_post = torch.unsqueeze(mixed_embs_post[0], 0) + mixed_embs = [mixed_embs_pre] + [mixed_embs_post] + mixed_embs = torch.cat(mixed_embs, dim=1) + return mixed_embs + + +if __name__ == "__main__": + args = parser.parse_args() + + device = args.device + precision = args.precision + _compile = args.compile + max_length = args.max_length + max_new_tokens = args.max_new_tokens + print("Will run SHARK MultiModal for the following paramters :-\n") + print( + f"Device={device} precision={precision} compile={_compile} max_length={max_length} max_new_tokens={max_new_tokens}" + ) + + padding = max_length - max_new_tokens + assert ( + padding > 0 + ), "max_length should be strictly greater than max_new_tokens" + + if args.image_path == "": + print( + "To run MiniGPT4 in CLI mode please provide an image's path using --image_path" + ) + sys.exit() + + vision_model_vmfb_path = ( + Path("vision_model_fp16_cuda.vmfb") + if args.vision_model_vmfb_path is None + else Path(args.vision_model_vmfb_path) + ) + qformer_vmfb_path = ( + Path("qformer_fp32_cuda.vmfb") + if args.qformer_vmfb_path is None + else Path(args.qformer_vmfb_path) + ) + chat = MiniGPT4( + model_name="MiniGPT4", + hf_model_path=None, + max_new_tokens=30, + device=device, + precision=precision, + _compile=_compile, + vision_model_vmfb_path=vision_model_vmfb_path, + qformer_vmfb_path=qformer_vmfb_path, + ) + + chat_state = CONV_VISION.copy() + img_list = [] + chat.upload_img(args.image_path, chat_state, img_list) + print( + "Uploaded image successfully to the bot. You may now start chatting with the bot. 
Enter 'END' without quotes to end the interaction" + ) + continue_execution = True + + while continue_execution: + user_message = input("User: ") + if user_message == "END": + print("Bot: Good bye.\n") + break + chat.ask(user_message, chat_state) + bot_message = chat.answer( + conv=chat_state, + img_list=img_list, + num_beams=1, + temperature=1.0, + max_new_tokens=max_new_tokens, + max_length=max_length, + )[0] + print("Bot: ", bot_message) + + del chat_state, img_list, chat diff --git a/apps/language_models/src/pipelines/minigpt4_utils/Qformer.py b/apps/language_models/src/pipelines/minigpt4_utils/Qformer.py new file mode 100644 index 0000000000..6944a9dd91 --- /dev/null +++ b/apps/language_models/src/pipelines/minigpt4_utils/Qformer.py @@ -0,0 +1,1297 @@ +""" + * Copyright (c) 2023, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +""" + +import math +from dataclasses import dataclass +from typing import Tuple, Dict, Any + +import torch +from torch import Tensor, device, nn +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + +logger = logging.get_logger(__name__) + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id, + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps, device="cpu" + ) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", + torch.arange(config.max_position_embeddings).expand((1, -1)), + ) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = 
torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if ( + config.hidden_size % config.num_attention_heads != 0 + and not hasattr(config, "embedding_size") + ): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int( + config.hidden_size / config.num_attention_heads + ) + self.all_head_size = ( + self.num_attention_heads * self.attention_head_size + ) + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, + self.attention_head_size, + ) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores( + self.key(encoder_hidden_states) + ) + value_layer = self.transpose_for_scores( + self.value(encoder_hidden_states) + ) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
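+        # query_layer and key_layer are (batch, heads, seq_len, head_dim), so
+        # this yields a (batch, heads, query_len, key_len) score matrix; it is
+        # scaled by sqrt(head_dim) further down before the softmax.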
+ attention_scores = torch.matmul( + query_layer, key_layer.transpose(-1, -2) + ) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt( + self.attention_head_size + ) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
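+        # The dropped-out probabilities are what weight the values:
+        # context_layer below is attention_probs_dropped @ value_layer, and
+        # the per-head outputs are then merged back into all_head_size.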
+ attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, + ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = ( + (context_layer, attention_probs) + if output_attentions + else (context_layer,) + ) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len( + heads + ) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + 
hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if ( + self.config.add_cross_attention + and layer_num % self.config.cross_attention_freq == 0 + ): + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention + ) + self.has_cross_attention = True + else: + self.has_cross_attention = False + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + self.intermediate_query = BertIntermediate(config) + self.output_query = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + assert ( + encoder_hidden_states is not None + ), "encoder_hidden_states must be given for cross-attention layers" + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + outputs = ( + outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat( + [layer_output, layer_output_text], dim=1 + ) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query(self, attention_output): + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in 
range(config.num_hidden_layers)] + ) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () + if output_attentions and self.config.add_cross_attention + else None + ) + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = ( + past_key_values[i] if past_key_values is not None else None + ) + + if ( + getattr(self.config, "gradient_checkpointing", False) + and self.training + ): + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module( + *inputs, + past_key_value, + output_attentions, + query_length + ) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_cross_attentions = all_cross_attentions + ( + layer_outputs[2], + ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear( + config.hidden_size, config.vocab_size, bias=False + ) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range + ) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. 
+ """ + + def __init__(self, config, add_pooling_layer=False): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + has_query: bool = False, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = ( + attention_mask.shape[1] - causal_mask.shape[1] + ) + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + ( + batch_size, + causal_mask.shape[1], + prefix_seq_len, + ), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] + * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. 
+ # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
+ """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict + if return_dict is not None + else self.config.use_return_dict + ) + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is None: + assert ( + query_embeds is not None + ), "You have to specify query_embeds when input_ids is None" + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.query_length + if past_key_values is not None + else 0 + ) + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), + device=device, + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states[0].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = ( + encoder_batch_size, + encoder_sequence_length, + ) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) + for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=device + ) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask( + head_mask, self.config.num_hidden_layers + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + 
return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertLMHeadModel(BertPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [ + r"position_ids", + r"predictions.decoder.bias", + ] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction="mean", + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
+ Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict + if return_dict is not None + else self.config.use_return_dict + ) + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[ + :, :-1, : + ].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss( + reduction=reduction, label_smoothing=0.1 + ) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == "none": + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + query_embeds, + past=None, + attention_mask=None, + **model_kwargs + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "query_embeds": query_embeds, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get( + "encoder_hidden_states", None + ), + "encoder_attention_mask": model_kwargs.get( + "encoder_attention_mask", None + ), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past + ), + ) + return reordered_past + + +class BertForMaskedLM(BertPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [ + r"position_ids", + 
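+    # Rough usage sketch for the BertLMHeadModel above (shapes, names, and the
+    # learned query tokens are illustrative assumptions; it also assumes the
+    # Hugging Face generate() loop forwards query_embeds through
+    # prepare_inputs_for_generation as written):
+    #
+    #   query_tokens = torch.zeros(1, 32, config.hidden_size)
+    #   generated = lm_model.generate(
+    #       input_ids=bos_ids,                    # e.g. [[bos_token_id]]
+    #       query_embeds=query_tokens,
+    #       encoder_hidden_states=image_embeds,   # from the vision encoder
+    #       encoder_attention_mask=image_atts,
+    #       max_length=32,
+    #   )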
r"predictions.decoder.bias", + ] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = ( + return_dict + if return_dict is not None + else self.config.use_return_dict + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ( + ((masked_lm_loss,) + output) + if masked_lm_loss is not None + else output + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/apps/language_models/src/pipelines/minigpt4_utils/blip_processors.py b/apps/language_models/src/pipelines/minigpt4_utils/blip_processors.py new file mode 100644 index 0000000000..8c10c65916 --- /dev/null +++ b/apps/language_models/src/pipelines/minigpt4_utils/blip_processors.py @@ -0,0 +1,68 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. 
+ SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" +from omegaconf import OmegaConf +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode + + +class BaseProcessor: + def __init__(self): + self.transform = lambda x: x + return + + def __call__(self, item): + return self.transform(item) + + @classmethod + def from_config(cls, cfg=None): + return cls() + + def build(self, **kwargs): + cfg = OmegaConf.create(kwargs) + + return self.from_config(cfg) + + +class BlipImageBaseProcessor(BaseProcessor): + def __init__(self, mean=None, std=None): + if mean is None: + mean = (0.48145466, 0.4578275, 0.40821073) + if std is None: + std = (0.26862954, 0.26130258, 0.27577711) + + self.normalize = transforms.Normalize(mean, std) + + +class Blip2ImageEvalProcessor(BlipImageBaseProcessor): + def __init__(self, image_size=224, mean=None, std=None): + super().__init__(mean=mean, std=std) + + self.transform = transforms.Compose( + [ + transforms.Resize( + (image_size, image_size), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + self.normalize, + ] + ) + + def __call__(self, item): + return self.transform(item) + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + image_size = cfg.get("image_size", 224) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + return cls(image_size=image_size, mean=mean, std=std) diff --git a/apps/language_models/src/pipelines/minigpt4_utils/configs/cc_sbu_align.yaml b/apps/language_models/src/pipelines/minigpt4_utils/configs/cc_sbu_align.yaml new file mode 100644 index 0000000000..5710834200 --- /dev/null +++ b/apps/language_models/src/pipelines/minigpt4_utils/configs/cc_sbu_align.yaml @@ -0,0 +1,5 @@ +datasets: + cc_sbu_align: + data_type: images + build_info: + storage: /path/to/cc_sbu_align/ diff --git a/apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4.yaml b/apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4.yaml new file mode 100644 index 0000000000..803d0a7ee4 --- /dev/null +++ b/apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4.yaml @@ -0,0 +1,33 @@ +model: + arch: mini_gpt4 + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + freeze_vit: True + freeze_qformer: True + + # Q-Former + num_query_token: 32 + + # Vicuna + llama_model: "lmsys/vicuna-7b-v1.3" + + # generation configs + prompt: "" + +preprocess: + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4_eval.yaml b/apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4_eval.yaml new file mode 100644 index 0000000000..e54f44cc64 --- /dev/null +++ b/apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4_eval.yaml @@ -0,0 +1,25 @@ +model: + arch: mini_gpt4 + model_type: pretrain_vicuna + freeze_vit: True + freeze_qformer: True + max_txt_len: 160 + end_sym: "###" + low_resource: False + prompt_path: "apps/language_models/src/pipelines/minigpt4_utils/prompts/alignment.txt" + prompt_template: '###Human: {} ###Assistant: ' + ckpt: 'prerained_minigpt4_7b.pth' + + +datasets: + cc_sbu_align: + vis_processor: + train: + name: 
"blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + +run: + task: image_text_pretrain diff --git a/apps/language_models/src/pipelines/minigpt4_utils/dist_utils.py b/apps/language_models/src/pipelines/minigpt4_utils/dist_utils.py new file mode 100644 index 0000000000..8310a363ea --- /dev/null +++ b/apps/language_models/src/pipelines/minigpt4_utils/dist_utils.py @@ -0,0 +1,53 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os + +import torch +import torch.distributed as dist +import timm.models.hub as timm_hub + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def download_cached_file(url, check_hash=True, progress=False): + """ + Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again. + If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded. + """ + + def get_cached_file_path(): + # a hack to sync the file path across processes + parts = torch.hub.urlparse(url) + filename = os.path.basename(parts.path) + cached_file = os.path.join(timm_hub.get_cache_dir(), filename) + + return cached_file + + if is_main_process(): + timm_hub.download_cached_file(url, check_hash, progress) + + if is_dist_avail_and_initialized(): + dist.barrier() + + return get_cached_file_path() diff --git a/apps/language_models/src/pipelines/minigpt4_utils/eva_vit.py b/apps/language_models/src/pipelines/minigpt4_utils/eva_vit.py new file mode 100644 index 0000000000..d3782e0d10 --- /dev/null +++ b/apps/language_models/src/pipelines/minigpt4_utils/eva_vit.py @@ -0,0 +1,627 @@ +# Based on EVA, BEIT, timm and DeiT code bases +# https://github.com/baaivision/EVA +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/microsoft/unilm/tree/master/beit +# https://github.com/facebookresearch/deit/ +# https://github.com/facebookresearch/dino +# --------------------------------------------------------' +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import drop_path, to_2tuple, trunc_normal_ + +from apps.language_models.src.pipelines.minigpt4_utils.dist_utils import ( + download_cached_file, +) + + +def _cfg(url="", **kwargs): + return { + "url": url, + "num_classes": 1000, + "input_size": (3, 224, 224), + "pool_size": None, + "crop_pct": 0.9, + "interpolation": "bicubic", + "mean": (0.5, 0.5, 0.5), + "std": (0.5, 0.5, 0.5), + **kwargs, + } + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + 
super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + # x = self.drop(x) + # commit this for the orignal BERT implement + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + window_size=None, + attn_head_dim=None, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.v_bias = None + + if window_size: + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * ( + 2 * window_size[1] - 1 + ) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack( + torch.meshgrid([coords_h, coords_w]) + ) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0 + ).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += ( + window_size[0] - 1 + ) # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = torch.zeros( + size=(window_size[0] * window_size[1] + 1,) * 2, + dtype=relative_coords.dtype, + ) + relative_position_index[1:, 1:] = relative_coords.sum( + -1 + ) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer( + "relative_position_index", relative_position_index + ) + else: + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, rel_pos_bias=None): + B, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat( + ( + self.q_bias, + torch.zeros_like(self.v_bias, requires_grad=False), + self.v_bias, + ) + ) + # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + if self.relative_position_bias_table is not None: + relative_position_bias = 
self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, + -1, + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + init_values=None, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + window_size=None, + attn_head_dim=None, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + window_size=window_size, + attn_head_dim=attn_head_dim, + ) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = ( + DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + ) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + if init_values is not None and init_values > 0: + self.gamma_1 = nn.Parameter( + init_values * torch.ones((dim)), requires_grad=True + ) + self.gamma_2 = nn.Parameter( + init_values * torch.ones((dim)), requires_grad=True + ) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward(self, x, rel_pos_bias=None): + if self.gamma_1 is None: + x = x + self.drop_path( + self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias) + ) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path( + self.gamma_1 + * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias) + ) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * ( + img_size[0] // patch_size[0] + ) + self.patch_shape = ( + img_size[0] // patch_size[0], + img_size[1] // patch_size[1], + ) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
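+        # Shape walk-through for the EVA ViT-g/14 settings used below
+        # (illustrative): x [B, 3, 224, 224] -> proj (Conv2d, stride 14) ->
+        # [B, 1408, 16, 16] -> flatten(2) -> [B, 1408, 256] ->
+        # transpose(1, 2) -> [B, 256, 1408].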
+ x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class RelativePositionBias(nn.Module): + def __init__(self, window_size, num_heads): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * ( + 2 * window_size[1] - 1 + ) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0 + ).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = torch.zeros( + size=(window_size[0] * window_size[1] + 1,) * 2, + dtype=relative_coords.dtype, + ) + relative_position_index[1:, 1:] = relative_coords.sum( + -1 + ) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer( + "relative_position_index", relative_position_index + ) + + # trunc_normal_(self.relative_position_bias_table, std=.02) + + def forward(self): + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, + -1, + ) # Wh*Ww,Wh*Ww,nH + return relative_position_bias.permute( + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww + + +class VisionTransformer(nn.Module): + """Vision Transformer with support for patch or hybrid CNN input stage""" + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + init_values=None, + use_abs_pos_emb=True, + use_rel_pos_bias=False, + use_shared_rel_pos_bias=False, + use_mean_pooling=True, + init_scale=0.001, + use_checkpoint=False, + ): + super().__init__() + self.image_size = img_size + self.num_classes = num_classes + self.num_features = ( + self.embed_dim + ) = embed_dim # num_features for consistency with other models + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + if use_abs_pos_emb: + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, embed_dim) + ) + else: + self.pos_embed = None + self.pos_drop = nn.Dropout(p=drop_rate) + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBias( + window_size=self.patch_embed.patch_shape, num_heads=num_heads + ) + else: + self.rel_pos_bias = None + self.use_checkpoint = use_checkpoint + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.use_rel_pos_bias = 
use_rel_pos_bias + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + init_values=init_values, + window_size=self.patch_embed.patch_shape + if use_rel_pos_bias + else None, + ) + for i in range(depth) + ] + ) + # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) + # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None + # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=0.02) + trunc_normal_(self.cls_token, std=0.02) + # trunc_normal_(self.mask_token, std=.02) + # if isinstance(self.head, nn.Linear): + # trunc_normal_(self.head.weight, std=.02) + self.apply(self._init_weights) + self.fix_init_weight() + + # if isinstance(self.head, nn.Linear): + # self.head.weight.data.mul_(init_scale) + # self.head.bias.data.mul_(init_scale) + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=""): + self.num_classes = num_classes + self.head = ( + nn.Linear(self.embed_dim, num_classes) + if num_classes > 0 + else nn.Identity() + ) + + def forward_features(self, x): + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + cls_tokens = self.cls_token.expand( + batch_size, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + rel_pos_bias = ( + self.rel_pos_bias() if self.rel_pos_bias is not None else None + ) + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, rel_pos_bias) + else: + x = blk(x, rel_pos_bias) + return x + + # x = self.norm(x) + + # if self.fc_norm is not None: + # t = x[:, 1:, :] + # return self.fc_norm(t.mean(1)) + # else: + # return x[:, 0] + + def forward(self, x): + x = self.forward_features(x) + # x = self.head(x) + return x + + def get_intermediate_layers(self, x): + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + cls_tokens = self.cls_token.expand( + batch_size, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + features = [] + rel_pos_bias = ( + self.rel_pos_bias() if self.rel_pos_bias is not None else None + ) + for blk in self.blocks: + x = blk(x, rel_pos_bias) + features.append(x) + + return features + + +def interpolate_pos_embed(model, checkpoint_model): + if "pos_embed" in checkpoint_model: + pos_embed_checkpoint = checkpoint_model["pos_embed"].float() + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint 
position embedding + orig_size = int( + (pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5 + ) + # height (== width) for the new position embedding + new_size = int(num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print( + "Position interpolate from %dx%d to %dx%d" + % (orig_size, orig_size, new_size, new_size) + ) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape( + -1, orig_size, orig_size, embedding_size + ).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, + size=(new_size, new_size), + mode="bicubic", + align_corners=False, + ) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model["pos_embed"] = new_pos_embed + + +def convert_weights_to_fp16(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + # l.weight.data = l.weight.data.half() + l.weight.data = l.weight.data + if l.bias is not None: + # l.bias.data = l.bias.data.half() + l.bias.data = l.bias.data + + # if isinstance(l, (nn.MultiheadAttention, Attention)): + # for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + # tensor = getattr(l, attr) + # if tensor is not None: + # tensor.data = tensor.data.half() + + model.apply(_convert_weights_to_fp16) + + +def create_eva_vit_g( + img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16" +): + model = VisionTransformer( + img_size=img_size, + patch_size=14, + use_mean_pooling=False, + embed_dim=1408, + depth=39, + num_heads=1408 // 88, + mlp_ratio=4.3637, + qkv_bias=True, + drop_path_rate=drop_path_rate, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + use_checkpoint=use_checkpoint, + ) + url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth" + cached_file = download_cached_file(url, check_hash=False, progress=True) + state_dict = torch.load(cached_file, map_location="cpu") + interpolate_pos_embed(model, state_dict) + + incompatible_keys = model.load_state_dict(state_dict, strict=False) + # print(incompatible_keys) + + if precision == "fp16": + # model.to("cuda") + convert_weights_to_fp16(model) + return model diff --git a/apps/language_models/src/pipelines/minigpt4_utils/prompts/alignment.txt b/apps/language_models/src/pipelines/minigpt4_utils/prompts/alignment.txt new file mode 100644 index 0000000000..38ae75a9ce --- /dev/null +++ b/apps/language_models/src/pipelines/minigpt4_utils/prompts/alignment.txt @@ -0,0 +1,4 @@ + Describe this image in detail. + Take a look at this image and describe what you notice. + Please provide a detailed description of the picture. + Could you describe the contents of this image for me? 
\ No newline at end of file diff --git a/apps/language_models/utils.py b/apps/language_models/utils.py index 64fb25f387..c9892ed5dd 100644 --- a/apps/language_models/utils.py +++ b/apps/language_models/utils.py @@ -3,6 +3,7 @@ from torch._decomp import get_decompositions from typing import List from pathlib import Path +from shark.shark_downloader import download_public_file # expects a Path / str as arg @@ -17,9 +18,23 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect): return None print("Loading vmfb from: ", vmfb_path) + print("Device from get_vmfb_from_path - ", device) shark_module = SharkInference( None, device=device, mlir_dialect=mlir_dialect ) shark_module.load_module(vmfb_path) print("Successfully loaded vmfb") return shark_module + + +def get_vmfb_from_config( + shark_container, model, precision, device, vmfb_path, padding=None +): + vmfb_url = ( + f"gs://shark_tank/{shark_container}/{model}_{precision}_{device}" + ) + if padding: + vmfb_url = vmfb_url + f"_{padding}" + vmfb_url = vmfb_url + ".vmfb" + download_public_file(vmfb_url, vmfb_path.absolute(), single_file=True) + return get_vmfb_from_path(vmfb_path, device, "tm_tensor") diff --git a/apps/stable_diffusion/shark_studio_imports.py b/apps/stable_diffusion/shark_studio_imports.py index 7529148312..b0a3ff72b5 100644 --- a/apps/stable_diffusion/shark_studio_imports.py +++ b/apps/stable_diffusion/shark_studio_imports.py @@ -35,6 +35,7 @@ datas += collect_data_files("iree") datas += collect_data_files("google_cloud_storage") datas += collect_data_files("shark") +datas += collect_data_files("timm", include_py_files=True) datas += collect_data_files("tkinter") datas += collect_data_files("webview") datas += collect_data_files("sentencepiece") diff --git a/apps/stable_diffusion/web/index.py b/apps/stable_diffusion/web/index.py index 8b377ec23a..0e4e1983bc 100644 --- a/apps/stable_diffusion/web/index.py +++ b/apps/stable_diffusion/web/index.py @@ -160,6 +160,7 @@ def resource_path(relative_path): modelmanager_sendto_outpaint, modelmanager_sendto_upscaler, stablelm_chat, + minigpt4_web, outputgallery_web, outputgallery_tab_select, outputgallery_watch, @@ -225,8 +226,10 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): stablelm_chat.render() with gr.TabItem(label="LoRA Training(Experimental)", id=7): lora_train_web.render() + with gr.TabItem(label="MultiModal (Experimental)", id=8): + minigpt4_web.render() if args.output_gallery: - with gr.TabItem(label="Output Gallery", id=8) as og_tab: + with gr.TabItem(label="Output Gallery", id=9) as og_tab: outputgallery_web.render() # extra output gallery configuration diff --git a/apps/stable_diffusion/web/ui/__init__.py b/apps/stable_diffusion/web/ui/__init__.py index c983abbb5b..9b8e4ced67 100644 --- a/apps/stable_diffusion/web/ui/__init__.py +++ b/apps/stable_diffusion/web/ui/__init__.py @@ -78,6 +78,7 @@ stablelm_chat, llm_chat_api, ) +from apps.stable_diffusion.web.ui.minigpt4_ui import minigpt4_web from apps.stable_diffusion.web.ui.outputgallery_ui import ( outputgallery_web, outputgallery_tab_select, diff --git a/apps/stable_diffusion/web/ui/minigpt4_ui.py b/apps/stable_diffusion/web/ui/minigpt4_ui.py new file mode 100644 index 0000000000..1a49159cdc --- /dev/null +++ b/apps/stable_diffusion/web/ui/minigpt4_ui.py @@ -0,0 +1,165 @@ +# ======================================== +# Gradio Setting +# ======================================== +import gradio as gr +from apps.language_models.src.pipelines.minigpt4_pipeline import ( + MiniGPT4, + CONV_VISION, +) + 
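+# Minimal sketch of driving the chat pipeline without the UI, assuming the
+# MiniGPT4 constructor and chat helpers behave exactly as they are used in the
+# Gradio callbacks below (the image path is illustrative):
+#
+#   from PIL import Image
+#   pipeline = MiniGPT4(model_name="MiniGPT4", hf_model_path=None,
+#                       max_new_tokens=30, device="cuda", precision="fp16")
+#   chat_state, img_list = CONV_VISION.copy(), []
+#   pipeline.upload_img(Image.open("photo.jpg"), chat_state, img_list)
+#   pipeline.ask("What is shown in this picture?", chat_state)
+#   reply = pipeline.answer(conv=chat_state, img_list=img_list, num_beams=1,
+#                           temperature=1.0, max_new_tokens=300,
+#                           max_length=2000)[0]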
+chat = None + + +def gradio_reset(chat_state, img_list): + if chat_state is not None: + chat_state.messages = [] + if img_list is not None: + img_list = [] + return ( + None, + gr.update(value=None, interactive=True), + gr.update( + placeholder="Please upload your image first", interactive=False + ), + gr.update(value="Upload & Start Chat", interactive=True), + chat_state, + img_list, + ) + + +def upload_img(gr_img, text_input, chat_state, device): + global chat + if chat is None: + from apps.language_models.src.pipelines.minigpt4_pipeline import ( + MiniGPT4, + ) + + chat = MiniGPT4( + model_name="MiniGPT4", + hf_model_path=None, + max_new_tokens=30, + device=device, + precision="fp16", + ) + if gr_img is None: + return None, None, gr.update(interactive=True), chat_state, None + chat_state = CONV_VISION.copy() + img_list = [] + llm_message = chat.upload_img(gr_img, chat_state, img_list) + return ( + gr.update(interactive=False), + gr.update(interactive=True, placeholder="Type and press Enter"), + gr.update(value="Start Chatting", interactive=False), + chat_state, + img_list, + ) + + +def gradio_ask(user_message, chatbot, chat_state): + if len(user_message) == 0: + return ( + gr.update( + interactive=True, placeholder="Input should not be empty!" + ), + chatbot, + chat_state, + ) + chat.ask(user_message, chat_state) + chatbot = chatbot + [[user_message, None]] + return "", chatbot, chat_state + + +def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature): + llm_message = chat.answer( + conv=chat_state, + img_list=img_list, + num_beams=num_beams, + temperature=temperature, + max_new_tokens=300, + max_length=2000, + )[0] + print(llm_message) + print("************") + chatbot[-1][1] = llm_message + return chatbot, chat_state, img_list + + +title = """

+MultiModal SHARK (experimental)
+""" +description = """
+Upload your images and start chatting!
+""" +article = """
+""" + +# TODO show examples below + +with gr.Blocks() as minigpt4_web: + gr.Markdown(title) + gr.Markdown(description) + + with gr.Row(): + with gr.Column(scale=0.5): + image = gr.Image(type="pil") + upload_button = gr.Button( + value="Upload & Start Chat", + interactive=True, + variant="primary", + ) + clear = gr.Button("Restart") + + num_beams = gr.Slider( + minimum=1, + maximum=10, + value=1, + step=1, + interactive=True, + label="beam search numbers)", + ) + + temperature = gr.Slider( + minimum=0.1, + maximum=2.0, + value=1.0, + step=0.1, + interactive=True, + label="Temperature", + ) + + device = gr.Dropdown( + label="Device", + value="cuda", + # if enabled + # else "Only CUDA Supported for now", + choices=["cuda"], + interactive=False, + ) + + with gr.Column(): + chat_state = gr.State() + img_list = gr.State() + chatbot = gr.Chatbot(label="MiniGPT-4") + text_input = gr.Textbox( + label="User", + placeholder="Please upload your image first", + interactive=False, + ) + + upload_button.click( + upload_img, + [image, text_input, chat_state, device], + [image, text_input, upload_button, chat_state, img_list], + ) + + text_input.submit( + gradio_ask, + [text_input, chatbot, chat_state], + [text_input, chatbot, chat_state], + ).then( + gradio_answer, + [chatbot, chat_state, img_list, num_beams, temperature], + [chatbot, chat_state, img_list], + ) + clear.click( + gradio_reset, + [chat_state, img_list], + [chatbot, image, text_input, upload_button, chat_state, img_list], + queue=False, + ) diff --git a/process_skipfiles.py b/process_skipfiles.py index 4c7ec1eede..0ae7687d95 100644 --- a/process_skipfiles.py +++ b/process_skipfiles.py @@ -56,3 +56,14 @@ ) else: print(line, end="") + +# For getting around timm's packaging. +# Refer: https://github.com/pyinstaller/pyinstaller/issues/5673#issuecomment-808731505 +path_to_timm_activations = Path( + get_python_lib() + "/timm/layers/activations_jit.py" +) +for line in fileinput.input(path_to_timm_activations, inplace=True): + if "@torch.jit.script" in line: + print("@torch.jit._script_if_tracing", end="\n") + else: + print(line, end="") diff --git a/requirements.txt b/requirements.txt index d3979161b7..399d5f163a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,6 +33,7 @@ pywebview sentencepiece py-cpuinfo tiktoken # for codegen +timm # for MiniGPT4 # Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors pefile