black format MiniGPT4
Abhishek-Varma committed Jul 25, 2023
1 parent 681a7ea commit ebc974c
Showing 7 changed files with 124 additions and 31 deletions.
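The hunks below are mechanical reformatting; the commit message and the uniform wrapping suggest they were produced by running black over the touched files, but the exact invocation and configured line length are not recorded in the commit. A minimal, hypothetical sketch of the same transformation via black's Python API, assuming a 79-character line limit for illustration:

import black

# One of the signatures touched in this commit, as it appeared before formatting.
# The source is just a string here; black does not import or execute it.
SRC = (
    "class VisionModel(torch.nn.Module):\n"
    "    def __init__(self, ln_vision, visual_encoder,"
    ' precision="fp32", weight_group_size=128):\n'
    "        pass\n"
)

# black.format_str() applies the same rules the `black` CLI applies to files;
# with the assumed 79-character limit, the over-long signature is wrapped onto
# one parameter per line, similar to the hunks shown below.
print(black.format_str(SRC, mode=black.Mode(line_length=79)))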
13 changes: 11 additions & 2 deletions apps/language_models/src/model_wrappers/minigpt4.py
@@ -8,6 +8,7 @@
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl


class LayerNorm(torch.nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""

@@ -18,7 +19,13 @@ def forward(self, x: torch.Tensor):


class VisionModel(torch.nn.Module):
def __init__(self, ln_vision, visual_encoder, precision="fp32", weight_group_size=128):
def __init__(
self,
ln_vision,
visual_encoder,
precision="fp32",
weight_group_size=128,
):
super().__init__()
self.ln_vision = ln_vision
self.visual_encoder = visual_encoder
@@ -37,7 +44,9 @@ def __init__(self, ln_vision, visual_encoder, precision="fp32", weight_group_siz
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
print("Vision Model applying weight quantization to visual_encoder")
print(
"Vision Model applying weight quantization to visual_encoder"
)
quantize_model(
self.visual_encoder,
dtype=torch.float32,
59 changes: 47 additions & 12 deletions apps/language_models/src/pipelines/minigpt4_pipeline.py
@@ -26,6 +26,7 @@
from PIL import Image
import sys
import requests

# SHARK dependencies
from shark.shark_compile import (
shark_compile_through_fx,
@@ -107,6 +108,7 @@
help="Maximum no. of new tokens that can be generated for a query",
)


def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
@@ -120,6 +122,7 @@ def is_url(input_url):
is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
return is_url


import os
import tempfile
from shark.shark_inference import SharkInference
@@ -132,6 +135,7 @@ def is_url(input_url):
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl


def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
if len(lhs) == 3 and len(rhs) == 2:
return [lhs[0], lhs[1], rhs[0]]
@@ -170,6 +174,7 @@ def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
shark_module.load_module(vmfb_path, extra_args=extra_args)
return shark_module


def compile_module(
shark_module, extended_model_name, generate_vmfb, extra_args=[]
):
@@ -191,7 +196,9 @@ def compile_module(
return shark_module


def compile_int_precision(model, inputs, precision, device, generate_vmfb, extended_model_name):
def compile_int_precision(
model, inputs, precision, device, generate_vmfb, extended_model_name
):
torchscript_module = import_with_fx(
model,
inputs,
@@ -215,8 +222,10 @@ def compile_int_precision(model, inputs, precision, device, generate_vmfb, exten
)
from contextlib import redirect_stdout

mlir_file_path = os.path.join(os.getcwd(), f"{extended_model_name}_linalg.mlir")
with open(mlir_file_path, 'w') as f:
mlir_file_path = os.path.join(
os.getcwd(), f"{extended_model_name}_linalg.mlir"
)
with open(mlir_file_path, "w") as f:
with redirect_stdout(f):
print(mlir_module.operation.get_asm())
mlir_module = str(mlir_module)
@@ -235,10 +244,16 @@ def compile_int_precision(model, inputs, precision, device, generate_vmfb, exten
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
]
return (
compile_module(shark_module, extended_model_name=extended_model_name, generate_vmfb=generate_vmfb, extra_args=extra_args),
bytecode
compile_module(
shark_module,
extended_model_name=extended_model_name,
generate_vmfb=generate_vmfb,
extra_args=extra_args,
),
bytecode,
)


def shark_compile_through_fx_int(
model,
inputs,
@@ -270,7 +285,14 @@ def shark_compile_through_fx_int(
if "cuda" in device:
shark_args.enable_tf32 = True

mlir_module = compile_int_precision(model, inputs, precision, device, generate_or_load_vmfb, extended_model_name)
mlir_module = compile_int_precision(
model,
inputs,
precision,
device,
generate_or_load_vmfb,
extended_model_name,
)
extra_args = [
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
@@ -292,7 +314,8 @@ def shark_compile_through_fx_int(
),
mlir_module,
)



class MiniGPT4BaseModel(torch.nn.Module):
@classmethod
def from_config(cls, cfg):
@@ -539,13 +562,15 @@ def __init__(
else:
self.prompt_list = []


def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)


class MiniGPT4(SharkLLMBase):
def __init__(
self,
@@ -572,11 +597,15 @@ def __init__(
self.second_llama_vmfb_path = None

print("Initializing Chat")
config = OmegaConf.load(resource_path("minigpt4_utils/configs/minigpt4_eval.yaml"))
config = OmegaConf.load(
resource_path("minigpt4_utils/configs/minigpt4_eval.yaml")
)
model_config = OmegaConf.create()
model_config = OmegaConf.merge(
model_config,
OmegaConf.load(resource_path("minigpt4_utils/configs/minigpt4.yaml")),
OmegaConf.load(
resource_path("minigpt4_utils/configs/minigpt4.yaml")
),
{"model": config["model"]},
)
model_config = model_config["model"]
@@ -585,7 +614,9 @@ def __init__(
datasets = config.get("datasets", None)
dataset_config = OmegaConf.create()
for dataset_name in datasets:
dataset_config_path = resource_path("minigpt4_utils/configs/cc_sbu_align.yaml")
dataset_config_path = resource_path(
"minigpt4_utils/configs/cc_sbu_align.yaml"
)
dataset_config = OmegaConf.merge(
dataset_config,
OmegaConf.load(dataset_config_path),
@@ -651,9 +682,13 @@ def compile_vision_model(self):
return vmfb

visionModel = VisionModel(
copy.deepcopy(self.model.ln_vision), copy.deepcopy(self.model.visual_encoder), vision_model_precision
copy.deepcopy(self.model.ln_vision),
copy.deepcopy(self.model.visual_encoder),
vision_model_precision,
)
extended_model_name = (
f"vision_model_{vision_model_precision}_{self.device}"
)
extended_model_name = f"vision_model_{vision_model_precision}_{self.device}"
print(f"Going to compile {extended_model_name}")
# Inputs for VisionModel.
inputs = [torch.randint(3, (1, 3, 224, 224), dtype=torch.float32)]
4 changes: 2 additions & 2 deletions apps/language_models/src/pipelines/minigpt4_utils/eva_vit.py
@@ -593,7 +593,7 @@ def _convert_weights_to_fp16(l):

model.apply(_convert_weights_to_fp16)


def create_eva_vit_g(
img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16"
):
@@ -611,7 +611,7 @@ def create_eva_vit_g(
use_checkpoint=use_checkpoint,
)
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"

local_filename = "eva_vit_g.pth"
response = requests.get(url)
if response.status_code == 200:
16 changes: 13 additions & 3 deletions apps/stable_diffusion/shark_studio_imports.py
@@ -7,7 +7,11 @@
sys.setrecursionlimit(sys.getrecursionlimit() * 5)

# python path for pyinstaller
pathex = [".", "./apps/language_models/langchain", "./apps/language_models/src/pipelines/minigpt4_utils"]
pathex = [
".",
"./apps/language_models/langchain",
"./apps/language_models/src/pipelines/minigpt4_utils",
]

# datafiles for pyinstaller
datas = []
@@ -53,8 +57,14 @@
("src/utils/resources/base_model.json", "resources"),
("web/ui/css/*", "ui/css"),
("web/ui/logos/*", "logos"),
("../language_models/src/pipelines/minigpt4_utils/configs/*", "minigpt4_utils/configs"),
("../language_models/src/pipelines/minigpt4_utils/prompts/*", "minigpt4_utils/prompts"),
(
"../language_models/src/pipelines/minigpt4_utils/configs/*",
"minigpt4_utils/configs",
),
(
"../language_models/src/pipelines/minigpt4_utils/prompts/*",
"minigpt4_utils/prompts",
),
]


6 changes: 5 additions & 1 deletion apps/stable_diffusion/web/ui/minigpt4_ui.py
@@ -2,6 +2,7 @@
# Gradio Setting
# ========================================
import gradio as gr

# from apps.language_models.src.pipelines.minigpt4_pipeline import (
# # MiniGPT4,
# CONV_VISION,
@@ -35,10 +36,13 @@ def upload_img(gr_img, text_input, chat_state, device, precision, _compile):
MiniGPT4,
CONV_VISION,
)

vision_model_precision = precision
if precision in ["int4", "int8"]:
vision_model_precision = "fp16"
vision_model_vmfb_path = Path(f"vision_model_{vision_model_precision}_{device}.vmfb")
vision_model_vmfb_path = Path(
f"vision_model_{vision_model_precision}_{device}.vmfb"
)
qformer_vmfb_path = Path(f"qformer_fp32_{device}.vmfb")
chat = MiniGPT4(
model_name="MiniGPT4",
2 changes: 1 addition & 1 deletion process_skipfiles.py
@@ -66,4 +66,4 @@
if "@torch.jit.script" in line:
print("@torch.jit._script_if_tracing", end="\n")
else:
print(line, end="")
print(line, end="")
55 changes: 45 additions & 10 deletions shark/shark_compile.py
@@ -10,7 +10,15 @@
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl

def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:

def brevitas〇matmul_rhs_group_quant〡shape(
lhs: List[int],
rhs: List[int],
rhs_scale: List[int],
rhs_zero_point: List[int],
rhs_bit_width: int,
rhs_group_size: int,
) -> List[int]:
if len(lhs) == 3 and len(rhs) == 2:
return [lhs[0], lhs[1], rhs[0]]
elif len(lhs) == 2 and len(rhs) == 2:
@@ -19,20 +27,30 @@ def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rh
raise ValueError("Input shapes not supported.")


def brevitas〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
def brevitas〇matmul_rhs_group_quant〡dtype(
lhs_rank_dtype: Tuple[int, int],
rhs_rank_dtype: Tuple[int, int],
rhs_scale_rank_dtype: Tuple[int, int],
rhs_zero_point_rank_dtype: Tuple[int, int],
rhs_bit_width: int,
rhs_group_size: int,
) -> int:
# output dtype is the dtype of the lhs float input
lhs_rank, lhs_dtype = lhs_rank_dtype
return lhs_dtype


def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
def brevitas〇matmul_rhs_group_quant〡has_value_semantics(
lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size
) -> None:
return


brevitas_matmul_rhs_group_quant_library = [
brevitas〇matmul_rhs_group_quant〡shape,
brevitas〇matmul_rhs_group_quant〡dtype,
brevitas〇matmul_rhs_group_quant〡has_value_semantics]
brevitas〇matmul_rhs_group_quant〡has_value_semantics,
]


def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
@@ -70,7 +88,9 @@ def compile_module(
return shark_module


def compile_int_precision(model, inputs, precision, device, generate_vmfb, extended_model_name):
def compile_int_precision(
model, inputs, precision, device, generate_vmfb, extended_model_name
):
weight_bit_width = 4 if precision == "int4" else 8
weight_group_size = 128
quantize_model(
@@ -115,8 +135,10 @@ def compile_int_precision(model, inputs, precision, device, generate_vmfb, exten
)
from contextlib import redirect_stdout

mlir_file_path = os.path.join(os.getcwd(), f"{extended_model_name}_linalg.mlir")
with open(mlir_file_path, 'w') as f:
mlir_file_path = os.path.join(
os.getcwd(), f"{extended_model_name}_linalg.mlir"
)
with open(mlir_file_path, "w") as f:
with redirect_stdout(f):
print(mlir_module.operation.get_asm())
mlir_module = str(mlir_module)
@@ -135,10 +157,16 @@ def compile_int_precision(model, inputs, precision, device, generate_vmfb, exten
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
]
return (
compile_module(shark_module, extended_model_name=extended_model_name, generate_vmfb=generate_vmfb, extra_args=extra_args),
bytecode
compile_module(
shark_module,
extended_model_name=extended_model_name,
generate_vmfb=generate_vmfb,
extra_args=extra_args,
),
bytecode,
)


def shark_compile_through_fx(
model,
inputs,
@@ -172,7 +200,14 @@ def shark_compile_through_fx(
shark_args.enable_tf32 = True

if precision in ["int4", "int8"]:
mlir_module = compile_int_precision(model, inputs, precision, device, generate_or_load_vmfb, extended_model_name)
mlir_module = compile_int_precision(
model,
inputs,
precision,
device,
generate_or_load_vmfb,
extended_model_name,
)
extra_args = [
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
Expand Down

0 comments on commit ebc974c

Please sign in to comment.