Skip to content

Commit

Permalink
Update .spec for MiniGPT4's config files
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhishek-Varma committed Jul 25, 2023
1 parent 6c4e22d commit 681a7ea
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 89 deletions.
46 changes: 24 additions & 22 deletions apps/language_models/src/pipelines/minigpt4_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pathlib import Path
from shark.shark_downloader import download_public_file
from transformers import LlamaTokenizer, LlamaForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import StoppingCriteriaList
from transformers.generation import GenerationConfig, LogitsProcessorList

Expand All @@ -24,27 +25,23 @@
import os
from PIL import Image
import sys

import requests
# SHARK dependencies
from shark.shark_compile import (
shark_compile_through_fx,
)
import random
import contextlib
from transformers import BertTokenizer
from transformers import LlamaTokenizer, LlamaForCausalLM
from transformers.generation import GenerationConfig, LogitsProcessorList
import copy
import tempfile

# QFormer, eva_vit, blip_processor, dist_utils
# QFormer, eva_vit, blip_processor
from apps.language_models.src.pipelines.minigpt4_utils.Qformer import (
BertConfig,
BertLMHeadModel,
)
from apps.language_models.src.pipelines.minigpt4_utils.dist_utils import (
download_cached_file,
)
from apps.language_models.src.pipelines.minigpt4_utils.eva_vit import (
create_eva_vit_g,
)
Expand Down Expand Up @@ -349,7 +346,7 @@ def from_config(cls, cfg):
return model

PRETRAINED_MODEL_CONFIG_DICT = {
"pretrain_vicuna": "configs/minigpt4.yaml",
"pretrain_vicuna": "minigpt4_utils/configs/minigpt4.yaml",
}

def maybe_autocast(self, dtype=torch.float32):
Expand Down Expand Up @@ -406,10 +403,13 @@ def init_Qformer(

def load_from_pretrained(self, url_or_filename):
if is_url(url_or_filename):
cached_file = download_cached_file(
url_or_filename, check_hash=False, progress=True
)
checkpoint = torch.load(cached_file, map_location="cpu")
local_filename = "blip2_pretrained_flant5xxl.pth"
response = requests.get(url_or_filename)
if response.status_code == 200:
with open(local_filename, "wb") as f:
f.write(response.content)
print("File downloaded successfully.")
checkpoint = torch.load(local_filename, map_location="cpu")
elif os.path.isfile(url_or_filename):
checkpoint = torch.load(url_or_filename, map_location="cpu")
else:
Expand Down Expand Up @@ -484,20 +484,20 @@ def __init__(
print("Loading Q-Former Done")

print(f"Loading Llama model from {llama_model}")
self.llama_tokenizer = LlamaTokenizer.from_pretrained(
llama_model, use_fast=False
self.llama_tokenizer = AutoTokenizer.from_pretrained(
llama_model, use_fast=False, legacy=False
)
# self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token

if self.low_resource:
self.llama_model = LlamaForCausalLM.from_pretrained(
self.llama_model = AutoModelForCausalLM.from_pretrained(
llama_model,
torch_dtype=torch.float16,
load_in_8bit=True,
device_map={"": device_8bit},
)
else:
self.llama_model = LlamaForCausalLM.from_pretrained(
self.llama_model = AutoModelForCausalLM.from_pretrained(
llama_model,
torch_dtype=torch.float32,
)
Expand Down Expand Up @@ -539,6 +539,12 @@ def __init__(
else:
self.prompt_list = []

def resource_path(relative_path):
    """Return the absolute path to a bundled resource.

    Works both when running from a source checkout (dev) and when running
    from a PyInstaller bundle.

    Args:
        relative_path: Path of the resource relative to the base directory,
            e.g. ``"minigpt4_utils/configs/minigpt4.yaml"``.

    Returns:
        ``relative_path`` joined onto ``sys._MEIPASS`` when that attribute
        exists, otherwise joined onto the directory containing this module.
    """
    # ``sys._MEIPASS`` is only present in a PyInstaller build; in dev the
    # lookup falls back to the directory of this source file.
    base_path = getattr(
        sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
    )
    return os.path.join(base_path, relative_path)

class MiniGPT4(SharkLLMBase):
def __init__(
Expand Down Expand Up @@ -566,15 +572,11 @@ def __init__(
self.second_llama_vmfb_path = None

print("Initializing Chat")
config = OmegaConf.load(
"apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4_eval.yaml"
)
config = OmegaConf.load(resource_path("minigpt4_utils/configs/minigpt4_eval.yaml"))
model_config = OmegaConf.create()
model_config = OmegaConf.merge(
model_config,
OmegaConf.load(
"apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4.yaml"
),
OmegaConf.load(resource_path("minigpt4_utils/configs/minigpt4.yaml")),
{"model": config["model"]},
)
model_config = model_config["model"]
Expand All @@ -583,7 +585,7 @@ def __init__(
datasets = config.get("datasets", None)
dataset_config = OmegaConf.create()
for dataset_name in datasets:
dataset_config_path = "apps/language_models/src/pipelines/minigpt4_utils/configs/cc_sbu_align.yaml"
dataset_config_path = resource_path("minigpt4_utils/configs/cc_sbu_align.yaml")
dataset_config = OmegaConf.merge(
dataset_config,
OmegaConf.load(dataset_config_path),
Expand Down
53 changes: 0 additions & 53 deletions apps/language_models/src/pipelines/minigpt4_utils/dist_utils.py

This file was deleted.

18 changes: 10 additions & 8 deletions apps/language_models/src/pipelines/minigpt4_utils/eva_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# https://github.com/facebookresearch/dino
# --------------------------------------------------------'
import math
import requests
from functools import partial

import torch
Expand All @@ -14,10 +15,6 @@
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop_path, to_2tuple, trunc_normal_

from apps.language_models.src.pipelines.minigpt4_utils.dist_utils import (
download_cached_file,
)


def _cfg(url="", **kwargs):
return {
Expand Down Expand Up @@ -596,7 +593,7 @@ def _convert_weights_to_fp16(l):

model.apply(_convert_weights_to_fp16)


def create_eva_vit_g(
img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16"
):
Expand All @@ -614,12 +611,17 @@ def create_eva_vit_g(
use_checkpoint=use_checkpoint,
)
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
cached_file = download_cached_file(url, check_hash=False, progress=True)
state_dict = torch.load(cached_file, map_location="cpu")

local_filename = "eva_vit_g.pth"
response = requests.get(url)
if response.status_code == 200:
with open(local_filename, "wb") as f:
f.write(response.content)
print("File downloaded successfully.")
state_dict = torch.load(local_filename, map_location="cpu")
interpolate_pos_embed(model, state_dict)

incompatible_keys = model.load_state_dict(state_dict, strict=False)
# print(incompatible_keys)

if precision == "fp16":
# model.to("cuda")
Expand Down
4 changes: 3 additions & 1 deletion apps/stable_diffusion/shark_studio_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
sys.setrecursionlimit(sys.getrecursionlimit() * 5)

# python path for pyinstaller
pathex = [".", "./apps/language_models/langchain"]
pathex = [".", "./apps/language_models/langchain", "./apps/language_models/src/pipelines/minigpt4_utils"]

# datafiles for pyinstaller
datas = []
Expand Down Expand Up @@ -53,6 +53,8 @@
("src/utils/resources/base_model.json", "resources"),
("web/ui/css/*", "ui/css"),
("web/ui/logos/*", "logos"),
("../language_models/src/pipelines/minigpt4_utils/configs/*", "minigpt4_utils/configs"),
("../language_models/src/pipelines/minigpt4_utils/prompts/*", "minigpt4_utils/prompts"),
]


Expand Down
12 changes: 8 additions & 4 deletions apps/stable_diffusion/web/ui/minigpt4_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
# Gradio Setting
# ========================================
import gradio as gr
from apps.language_models.src.pipelines.minigpt4_pipeline import (
MiniGPT4,
CONV_VISION,
)
# from apps.language_models.src.pipelines.minigpt4_pipeline import (
# # MiniGPT4,
# CONV_VISION,
# )
from pathlib import Path

chat = None
Expand All @@ -31,6 +31,10 @@ def gradio_reset(chat_state, img_list):
def upload_img(gr_img, text_input, chat_state, device, precision, _compile):
global chat
if chat is None:
from apps.language_models.src.pipelines.minigpt4_pipeline import (
MiniGPT4,
CONV_VISION,
)
vision_model_precision = precision
if precision in ["int4", "int8"]:
vision_model_precision = "fp16"
Expand Down
2 changes: 1 addition & 1 deletion process_skipfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,4 @@
if "@torch.jit.script" in line:
print("@torch.jit._script_if_tracing", end="\n")
else:
print(line, end="")
print(line, end="")

0 comments on commit 681a7ea

Please sign in to comment.