Add int8 support for MiniGPT4
-- This commit adds int8 support for MiniGPT4.

Signed-off-by: Abhishek Varma <[email protected]>
Abhishek-Varma committed Jul 25, 2023
1 parent 03ca471 commit 6c4e22d
Showing 4 changed files with 491 additions and 78 deletions.
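For orientation: the diff below threads two new constructor arguments, precision and weight_group_size, through the MiniGPT4 model wrappers, and when precision is "int4" or "int8" it applies Brevitas' quantize_model to the wrapped weights. A minimal sketch of opting into int8 from the caller's side (the toy modules, sizes, and import path are illustrative assumptions, not part of this commit):

import torch

# Assumes the repository root is on sys.path; requires torch, transformers,
# brevitas, and brevitas_examples to be installed.
from apps.language_models.src.model_wrappers.minigpt4 import VisionModel

# Toy stand-ins for MiniGPT4's ln_vision and ViT visual encoder, sized so the
# per-group block of 128 divides the Linear feature dimension (1408 = 11 * 128).
ln_vision = torch.nn.LayerNorm(1408)
visual_encoder = torch.nn.Sequential(torch.nn.Linear(1408, 1408), torch.nn.GELU())

vision = VisionModel(
    ln_vision,
    visual_encoder,
    precision="int8",       # "fp32" (default) skips quantization; "int4" is also accepted
    weight_group_size=128,  # per-group quantization block size
)

FirstLlamaModel and SecondLlamaModel gain the same two arguments and apply the identical weight-quantization recipe to the wrapped LLaMA model.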
69 changes: 65 additions & 4 deletions apps/language_models/src/model_wrappers/minigpt4.py
@@ -2,9 +2,12 @@
 import dataclasses
 from enum import auto, Enum
 from typing import List, Any
-from transformers import StoppingCriteria, StoppingCriteriaList
+from transformers import StoppingCriteria


+from brevitas_examples.llm.llm_quant.quantize import quantize_model
+from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
+
 class LayerNorm(torch.nn.LayerNorm):
     """Subclass torch's LayerNorm to handle fp16."""

@@ -15,10 +18,38 @@ def forward(self, x: torch.Tensor):


 class VisionModel(torch.nn.Module):
-    def __init__(self, ln_vision, visual_encoder):
+    def __init__(self, ln_vision, visual_encoder, precision="fp32", weight_group_size=128):
         super().__init__()
         self.ln_vision = ln_vision
         self.visual_encoder = visual_encoder
+        if precision in ["int4", "int8"]:
+            print("Vision Model applying weight quantization to ln_vision")
+            weight_bit_width = 4 if precision == "int4" else 8
+            quantize_model(
+                self.ln_vision,
+                dtype=torch.float32,
+                weight_bit_width=weight_bit_width,
+                weight_param_method="stats",
+                weight_scale_precision="float",
+                weight_quant_type="asym",
+                weight_quant_granularity="per_group",
+                weight_group_size=weight_group_size,
+                quantize_weight_zero_point=False,
+            )
+            print("Weight quantization applied.")
+            print("Vision Model applying weight quantization to visual_encoder")
+            quantize_model(
+                self.visual_encoder,
+                dtype=torch.float32,
+                weight_bit_width=weight_bit_width,
+                weight_param_method="stats",
+                weight_scale_precision="float",
+                weight_quant_type="asym",
+                weight_quant_granularity="per_group",
+                weight_group_size=weight_group_size,
+                quantize_weight_zero_point=False,
+            )
+            print("Weight quantization applied.")

     def forward(self, image):
         image_embeds = self.ln_vision(self.visual_encoder(image))
@@ -41,10 +72,25 @@ def forward(self, query_tokens, image_embeds, image_atts):


 class FirstLlamaModel(torch.nn.Module):
-    def __init__(self, model):
+    def __init__(self, model, precision="fp32", weight_group_size=128):
         super().__init__()
         self.model = model
         print("SHARK: Loading LLAMA Done")
+        if precision in ["int4", "int8"]:
+            print("First Llama applying weight quantization")
+            weight_bit_width = 4 if precision == "int4" else 8
+            quantize_model(
+                self.model,
+                dtype=torch.float32,
+                weight_bit_width=weight_bit_width,
+                weight_param_method="stats",
+                weight_scale_precision="float",
+                weight_quant_type="asym",
+                weight_quant_granularity="per_group",
+                weight_group_size=weight_group_size,
+                quantize_weight_zero_point=False,
+            )
+            print("Weight quantization applied.")

     def forward(self, inputs_embeds, position_ids, attention_mask):
         print("************************************")
@@ -90,10 +136,25 @@ def forward(self, inputs_embeds, position_ids, attention_mask):


 class SecondLlamaModel(torch.nn.Module):
-    def __init__(self, model):
+    def __init__(self, model, precision="fp32", weight_group_size=128):
         super().__init__()
         self.model = model
         print("SHARK: Loading LLAMA Done")
+        if precision in ["int4", "int8"]:
+            print("Second Llama applying weight quantization")
+            weight_bit_width = 4 if precision == "int4" else 8
+            quantize_model(
+                self.model,
+                dtype=torch.float32,
+                weight_bit_width=weight_bit_width,
+                weight_param_method="stats",
+                weight_scale_precision="float",
+                weight_quant_type="asym",
+                weight_quant_granularity="per_group",
+                weight_group_size=weight_group_size,
+                quantize_weight_zero_point=False,
+            )
+            print("Weight quantization applied.")

     def forward(
         self,
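Each wrapper applies the same recipe: weight-only, asymmetric, per-group quantization with stats-derived floating-point scales and no quantized zero point. A standalone sketch of that call on a toy module (the module and its sizes are illustrative; brevitas and brevitas_examples must be installed):

import torch
from brevitas_examples.llm.llm_quant.quantize import quantize_model

# Toy stand-in for ln_vision / visual_encoder / the LLaMA model.
model = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.ReLU())

quantize_model(
    model,                                 # modified in place, as in the wrappers above
    dtype=torch.float32,
    weight_bit_width=8,                    # 8 for int8; the commit also wires up 4 for int4
    weight_param_method="stats",           # scales derived from weight statistics
    weight_scale_precision="float",        # keep scales in floating point
    weight_quant_type="asym",              # asymmetric weight quantization
    weight_quant_granularity="per_group",
    weight_group_size=128,                 # 256 input features split evenly into groups of 128
    quantize_weight_zero_point=False,
)

out = model(torch.randn(2, 256))           # the quantized Linear layers still run a normal forward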