Skip to content

Commit

Permalink
Add minigpt4 adapter (#274)
Browse files Browse the repository at this point in the history
Signed-off-by: shiyu22 <[email protected]>
  • Loading branch information
shiyu22 authored Apr 24, 2023
1 parent 49d18cc commit 9e45404
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 1 deletion.
1 change: 1 addition & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ ignore:
- ".git"
- "*.yml"
- "*.md"
- "**/minigpt4.py"
22 changes: 22 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,33 @@
# Example

- [How to run Visual Question Answering with MiniGPT-4](#How-to-run-Visual-Question-Answering-with-MiniGPT-4)
- [How to set the `embedding` function](#How-to-set-the-embedding-function)
- [How to set the `data manager` class](#How-to-set-the-data-manager-class)
- [How to set the `similarity evaluation` interface](#How-to-set-the-similarity-evaluation-interface)
- [Other cache init params](#Other-cache-init-params)
- [Benchmark](#Benchmark)

## How to run Visual Question Answering with MiniGPT-4

You can run [vqa_demo.py](./vqa_demo.py) to try image question answering: it uses MiniGPT-4 to generate the answers and GPTCache to cache them.

> Note that you need to make sure that [minigpt4](https://github.com/Vision-CAIR/MiniGPT-4) and [gptcache](https://gptcache.readthedocs.io/en/dev/index.html) are successfully installed, and move the **vqa_demo.py** file to the MiniGPT-4 directory.
```bash
$ python vqa_demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0
```

The above command uses the exact-match cache, i.e. the `map` cache management method. When you ask the same question about the same image again, it hits the cache directly and returns the answer quickly.

If you want to use a similarity-search cache instead, run the following command with `--no-map`, which sets `map` to `False`; the cache is then managed with sqlite3 and faiss so that similar images and questions can be searched for in the cache.

> You can also set `dir` to your workspace directory.
```bash
$ python vqa_demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0 --dir /path/to/workspace --no-map
```


## How to set the `embedding` function

> Please note that not all data managers are compatible with an embedding function.
Expand Down
92 changes: 92 additions & 0 deletions examples/vqa_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# ================================================================================
# This demo comes from [minigpt4](https://github.com/Vision-CAIR/MiniGPT-4)
# and is integrated with [gptcache](https://github.com/zilliztech/GPTCache)
# for image Question Answering.
# Please make sure you have successfully setup minigpt4.
# Run `python vqa_demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0`.
# =================================================================================

import argparse

import gradio as gr

from gptcache import cache
from gptcache.processor.pre import get_image, get_image_question
from gptcache.embedding import Timm
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation
from gptcache.manager.factory import manager_factory

from gptcache.adapter.minigpt4 import MiniGPT4


def parse_args(argv=None):
    """Parse the command-line options of the VQA demo.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so the
            zero-argument call ``parse_args()`` behaves as before.

    Returns:
        argparse.Namespace: with the attributes ``cfg_path``, ``gpu_id``,
        ``dir``, ``map`` and ``options``.
    """
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument("--dir", type=str, default=".", help="path for data storage.")
    # `--map` / `--no-map` write to the same destination; the default (True)
    # is set explicitly below so that neither flag has to be given.
    parser.add_argument("--map", action="store_true", help="use map for exact match cache.")
    parser.add_argument(
        "--no-map",
        dest="map",
        action="store_false",
        help="use sqlite and faiss for similar search cache.",
    )
    parser.set_defaults(map=True)
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )
    return parser.parse_args(argv)


args = parse_args()

print("Initializing GPTCache")
if not args.map:
    # Similar-search cache: sqlite + faiss storage, keyed on a Timm embedding
    # of whatever `get_image` extracts from the request.
    timm = Timm()
    data_manager = manager_factory(
        "sqlite,faiss",
        args.dir,
        vector_params={"dimension": timm.dimension},
    )
    cache.init(
        pre_embedding_func=get_image,
        data_manager=data_manager,
        embedding_func=timm.to_embeddings,
        similarity_evaluation=SearchDistanceEvaluation(),
    )
else:
    # Exact-match cache: the "map" data manager, keyed on the string built by
    # `get_image_question`.
    data_manager = manager_factory("map", args.dir)
    cache.init(pre_embedding_func=get_image_question, data_manager=data_manager)
print("GPTCache Initialization Finished")

print("Initializing Chat")
# return_hit=True makes the pipeline return two values, which feed the two
# output text boxes of the UI defined further down.
pipeline = MiniGPT4.from_pretrained(
    cfg_path=args.cfg_path,
    gpu_id=args.gpu_id,
    options=args.options,
    return_hit=True,
)
print(" Chat Initialization Finished")


# ========================================
# Gradio Setting
# ========================================


# HTML snippets rendered, in this order, at the top of the demo page.
title = """<h1 align="center">Demo of MiniGPT-4 and GPTCache</h1>"""
description = """<h3>This is the demo of MiniGPT-4 and GPTCache. Upload your images and ask question, and it will be cached.</h3>"""
article = """<p><a href="https://github.com/zilliztech/GPTCache"><img src="https://img.shields.io/badge/Github-Code-blue"></a></p>"""

# show examples below


with gr.Blocks() as demo:
    for header_html in (title, description, article):
        gr.Markdown(header_html)

    with gr.Row():
        # Left column: the user's inputs (image file path and question text).
        with gr.Column():
            image_input = gr.Image(source="upload", type="filepath")
            question_input = gr.Textbox(label="Question")
        # Right column: the two values returned by `pipeline`.
        with gr.Column():
            answer_output = gr.Textbox()
            hit_output = gr.Textbox(label="is hit")

    submit_button = gr.Button("Submit")
    submit_button.click(
        fn=pipeline,
        inputs=[image_input, question_input],
        outputs=[answer_output, hit_output],
    )

demo.launch(share=True)
90 changes: 90 additions & 0 deletions gptcache/adapter/minigpt4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from gptcache.adapter.adapter import adapt
from gptcache.utils.error import CacheError
from gptcache.manager.scalar_data.base import DataType, Question, Answer

from argparse import Namespace

from minigpt4.common.config import Config
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION

# pylint: disable=wildcard-import
# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *


class MiniGPT4:  # pragma: no cover
    """MiniGPT4 Wrapper

    Wraps a MiniGPT-4 ``Chat`` object so that calling the instance with an
    image and a question is routed through GPTCache's ``adapt``.

    :param chat: chat object providing ``upload_img``, ``ask`` and ``answer``.
    :param return_hit: when true, a call returns the tuple ``(answer, is_hit)``
        (``is_hit`` is ``True`` for data converted from the cache and ``False``
        for a freshly generated answer); otherwise only the answer is returned.

    Example:
        .. code-block:: python

            from gptcache import cache
            from gptcache.processor.pre import get_image_question
            from gptcache.adapter.minigpt4 import MiniGPT4

            # init gptcache
            cache.init(pre_embedding_func=get_image_question)

            # run with gptcache
            pipe = MiniGPT4.from_pretrained(cfg_path='eval_configs/minigpt4_eval.yaml', gpu_id=3, options=None)
            question = "Which city is this photo taken?"
            image = "./merlion.png"
            answer = pipe(image, question)
    """

    def __init__(self, chat, return_hit):
        self.chat = chat
        self.return_hit = return_hit

    @classmethod
    def from_pretrained(cls, cfg_path, gpu_id=0, options=None, return_hit=False):
        """Build the model, visual processor and ``Chat`` from a MiniGPT-4 config file.

        :param cfg_path: path handed to ``Config`` as ``cfg_path``.
        :param gpu_id: index of the CUDA device the model and chat are placed on.
        :param options: handed to ``Config`` as ``options``.
        :param return_hit: forwarded to the constructor.
        :return: a new ``MiniGPT4`` instance.
        """
        args = Namespace(cfg_path=cfg_path, gpu_id=gpu_id, options=options)
        cfg = Config(args)
        model_config = cfg.model_cfg
        model_config.device_8bit = args.gpu_id
        device = "cuda:{}".format(args.gpu_id)
        model_cls = registry.get_model_class(model_config.arch)
        model = model_cls.from_config(model_config).to(device)

        vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
        vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
        chat = Chat(model, vis_processor, device=device)
        return cls(chat, return_hit)

    def _wrap_result(self, answer, is_hit):
        """Return ``(answer, is_hit)`` when ``return_hit`` is set, else ``answer`` alone."""
        # The tuple must be parenthesised: `a if cond else b, flag` parses as
        # `(a if cond else b), flag` and would build a tuple unconditionally.
        return (answer, is_hit) if self.return_hit else answer

    def llm_handler(self, image, question):
        """Ask MiniGPT-4 ``question`` about ``image`` and return the generated answer.

        :raises CacheError: if any step of the chat interaction raises.
        """
        chat_state = CONV_VISION.copy()
        img_list = []
        try:
            self.chat.upload_img(image, chat_state, img_list)
            self.chat.ask(question, chat_state)
            answer = self.chat.answer(conv=chat_state, img_list=img_list)[0]
        except Exception as e:
            raise CacheError("minigpt4 error") from e
        # An answer produced here did not come from the cache.
        return self._wrap_result(answer, False)

    def __call__(self, image, question, *args, **kwargs):
        """Answer ``question`` about ``image`` via ``adapt``.

        :return: the answer, or ``(answer, is_hit)`` when ``return_hit`` is set.
        """
        cache_context = {"deps": [
            {"name": "text", "data": question, "dep_type": DataType.STR},
            {"name": "image", "data": image, "dep_type": DataType.STR},
        ]}

        def cache_data_convert(cache_data):
            # Data handed to this converter is tagged as a hit.
            return self._wrap_result(cache_data, True)

        def update_cache_callback(llm_data, update_cache_func, *args, **kwargs):  # pylint: disable=unused-argument
            question_data = Question.from_dict({
                "content": "pre_embedding_data",
                "deps": [
                    {"name": "text", "data": kwargs["question"], "dep_type": DataType.STR},
                    {"name": "image", "data": kwargs["image"], "dep_type": DataType.STR},
                ]
            })
            # Store only the answer text, never the (answer, is_hit) tuple.
            llm_data_cache = llm_data[0] if self.return_hit else llm_data
            update_cache_func(Answer(llm_data_cache, DataType.STR), question=question_data)
            return llm_data

        return adapt(
            self.llm_handler,
            cache_data_convert,
            update_cache_callback,
            *args,
            image=image,
            question=question,
            cache_context=cache_context,
            **kwargs
        )
2 changes: 1 addition & 1 deletion gptcache/manager/scalar_data/sql_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class AnswerTable(Base):
else:
id = Column(Integer, primary_key=True, autoincrement=True)
question_id = Column(Integer, nullable=False)
answer = Column(String(1000), nullable=False)
answer = Column(String(2000), nullable=False)
answer_type = Column(Integer, nullable=False)

class QuestionDepTable(Base):
Expand Down
10 changes: 10 additions & 0 deletions gptcache/processor/pre.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,13 @@ def get_input_str(data: Dict[str, Any], **_: Dict[str, Any]) -> str:
def get_input_image_file_name(data: Dict[str, Any], **_: Dict[str, Any]) -> str:
    """Return the ``name`` attribute of the object stored at ``data["input"]["image"]``."""
    return data.get("input")["image"].name


def get_image_question(data: Dict[str, Any], **_: Dict[str, Any]) -> str:  # pragma: no cover
    """Build the cache key string from the request's image and question.

    When ``data["image"]`` is a ``str`` it is treated as a file path: the file
    is opened in binary mode and the ``str()`` of the bytes returned by
    ``peek()`` is used. Any other image value is simply passed through
    ``str()``. The question text is appended to that string.

    :param data: request dict; the ``"image"`` and ``"question"`` keys are read.
    :return: stringified image data followed by the question.
    """
    img = data.get("image")
    if isinstance(img, str):
        # Close the handle deterministically instead of leaking it until
        # garbage collection.
        with open(img, "rb") as image_file:
            data_img = str(image_file.peek())
    else:
        data_img = str(img)
    return data_img + data.get("question")


def get_image(data: Dict[str, Any], **_: Dict[str, Any]) -> str:  # pragma: no cover
    """Return the value stored under the ``"image"`` key of ``data`` (``None`` when absent)."""
    image = data.get("image")
    return image

0 comments on commit 9e45404

Please sign in to comment.