-
Notifications
You must be signed in to change notification settings - Fork 510
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: shiyu22 <[email protected]>
- Loading branch information
Showing
6 changed files
with
216 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,3 +34,4 @@ ignore: | |
- ".git" | ||
- "*.yml" | ||
- "*.md" | ||
- "**/minigpt4.py" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# ================================================================================ | ||
# This demo comes from [minigpt4](https://github.com/Vision-CAIR/MiniGPT-4) | ||
# and is integrated with [GPTCache](https://github.com/zilliztech/GPTCache) | ||
# for image Question Answering. | ||
# Please make sure you have successfully setup minigpt4. | ||
# Run `python vqa_demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0`. | ||
# ================================================================================= | ||
|
||
import argparse | ||
|
||
import gradio as gr | ||
|
||
from gptcache import cache | ||
from gptcache.processor.pre import get_image, get_image_question | ||
from gptcache.embedding import Timm | ||
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation | ||
from gptcache.manager.factory import manager_factory | ||
|
||
from gptcache.adapter.minigpt4 import MiniGPT4 | ||
|
||
|
||
def parse_args():
    """Parse the command-line options of the MiniGPT-4 / GPTCache demo.

    ``--map`` and ``--no-map`` both write to the ``map`` destination, whose
    default is ``True``.

    :return: the ``argparse.Namespace`` built from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description="Demo")

    # Model configuration and device selection.
    cli.add_argument("--cfg-path", required=True, help="path to configuration file.")
    cli.add_argument("--gpu-id", default=0, type=int, help="specify the gpu to load the model.")

    # Cache storage location and cache flavour.
    cli.add_argument("--dir", default=".", type=str, help="path for data storage.")
    cli.add_argument("--map", action="store_true", help="use map for exact match cache.")
    cli.add_argument(
        "--no-map",
        dest="map",
        action="store_false",
        help="use sqlite and faiss for similar search cache.",
    )
    cli.set_defaults(map=True)

    # Free-form config overrides forwarded to MiniGPT-4.
    options_help = (
        "override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead."
    )
    cli.add_argument("--options", nargs="+", help=options_help)

    return cli.parse_args()
|
||
|
||
# Parse the CLI once; everything below is driven by these options.
args = parse_args()

print("Initializing GPTCache")
if not args.map:
    # sqlite + faiss storage for the similar-search cache; the vector store
    # is sized from the Timm embedding dimension.
    timm = Timm()
    data_manager = manager_factory(
        "sqlite,faiss",
        args.dir,
        vector_params={"dimension": timm.dimension},
    )
    cache.init(
        pre_embedding_func=get_image,
        data_manager=data_manager,
        embedding_func=timm.to_embeddings,
        similarity_evaluation=SearchDistanceEvaluation(),
    )
else:
    # Map storage for the exact-match cache (the default).
    data_manager = manager_factory("map", args.dir)
    cache.init(
        pre_embedding_func=get_image_question,
        data_manager=data_manager,
    )
print("GPTCache Initialization Finished")

print("Initializing Chat")
pipeline = MiniGPT4.from_pretrained(
    cfg_path=args.cfg_path,
    gpu_id=args.gpu_id,
    options=args.options,
    return_hit=True,
)
print(" Chat Initialization Finished")
|
||
|
||
# ======================================== | ||
# Gradio Setting | ||
# ======================================== | ||
|
||
|
||
# Static HTML snippets rendered at the top of the page.
title = """<h1 align="center">Demo of MiniGPT-4 and GPTCache</h1>"""
description = """<h3>This is the demo of MiniGPT-4 and GPTCache. Upload your images and ask question, and it will be cached.</h3>"""
article = """<p><a href="https://github.com/zilliztech/GPTCache"><img src="https://img.shields.io/badge/Github-Code-blue"></a></p>"""

# show examples below

with gr.Blocks() as demo:
    # Header: title, description and repository badge, in that order.
    for _banner in (title, description, article):
        gr.Markdown(_banner)

    with gr.Row():
        # Left column: the image to ask about and the question text.
        with gr.Column():
            inp0 = gr.Image(source="upload", type="filepath")
            inp1 = gr.Textbox(label="Question")
        # Right column: the answer and the cache-hit flag returned by the pipeline.
        with gr.Column():
            out0 = gr.Textbox()
            out1 = gr.Textbox(label="is hit")

    btn = gr.Button("Submit")
    btn.click(fn=pipeline, inputs=[inp0, inp1], outputs=[out0, out1])

demo.launch(share=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from gptcache.adapter.adapter import adapt | ||
from gptcache.utils.error import CacheError | ||
from gptcache.manager.scalar_data.base import DataType, Question, Answer | ||
|
||
from argparse import Namespace | ||
|
||
from minigpt4.common.config import Config | ||
from minigpt4.common.registry import registry | ||
from minigpt4.conversation.conversation import Chat, CONV_VISION | ||
|
||
# pylint: disable=wildcard-import | ||
# imports modules for registration | ||
from minigpt4.datasets.builders import * | ||
from minigpt4.models import * | ||
from minigpt4.processors import * | ||
from minigpt4.runners import * | ||
from minigpt4.tasks import * | ||
|
||
|
||
class MiniGPT4:  # pragma: no cover
    """MiniGPT4 Wrapper

    Routes image question answering through GPTCache's ``adapt`` function,
    using a MiniGPT-4 ``Chat`` object to produce answers.

    :param chat: the ``Chat`` instance whose ``upload_img`` / ``ask`` /
        ``answer`` methods are used to produce answers.
    :param return_hit: when true, calling the wrapper returns
        ``(answer, is_hit)``; otherwise it returns only ``answer``.

    Example:
        .. code-block:: python

            from gptcache import cache
            from gptcache.processor.pre import get_image_question
            from gptcache.adapter.minigpt4 import MiniGPT4

            # init gptcache
            cache.init(pre_embedding_func=get_image_question)

            # run with gptcache
            pipe = MiniGPT4.from_pretrained(cfg_path='eval_configs/minigpt4_eval.yaml', gpu_id=3, options=None)
            question = "Which city is this photo taken?"
            image = "./merlion.png"
            answer = pipe(image, question)
    """

    def __init__(self, chat, return_hit):
        self.chat = chat
        self.return_hit = return_hit

    @classmethod
    def from_pretrained(cls, cfg_path, gpu_id=0, options=None, return_hit=False):
        """Build the wrapper from a MiniGPT-4 configuration file.

        :param cfg_path: path of the configuration file handed to ``Config``.
        :param gpu_id: index of the CUDA device the model is moved to.
        :param options: optional config overrides, forwarded to ``Config``.
        :param return_hit: stored on the instance, see the class docstring.
        :return: a new ``MiniGPT4`` instance.
        """
        # ``Config`` takes an argparse-style namespace.
        args = Namespace(cfg_path=cfg_path, gpu_id=gpu_id, options=options)
        cfg = Config(args)
        device = "cuda:{}".format(args.gpu_id)

        model_config = cfg.model_cfg
        model_config.device_8bit = args.gpu_id
        model_cls = registry.get_model_class(model_config.arch)
        model = model_cls.from_config(model_config).to(device)

        vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
        vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
        chat = Chat(model, vis_processor, device=device)
        return cls(chat, return_hit)

    def _with_hit_flag(self, data, is_hit):
        """Return ``(data, is_hit)`` when ``return_hit`` is set, else ``data`` alone."""
        # Explicit parentheses matter: ``a if c else b, flag`` parses as
        # ``(a if c else b), flag`` and would always build a tuple.
        return (data, is_hit) if self.return_hit else data

    def llm_handler(self, image, question):
        """Ask the chat model ``question`` about ``image`` on a fresh conversation.

        :param image: the image passed to ``Chat.upload_img``.
        :param question: the question passed to ``Chat.ask``.
        :return: the answer, or ``(answer, False)`` when ``return_hit`` is set.
        :raises CacheError: if any step of the chat interaction raises.
        """
        chat_state = CONV_VISION.copy()
        img_list = []
        try:
            self.chat.upload_img(image, chat_state, img_list)
            self.chat.ask(question, chat_state)
            answer = self.chat.answer(conv=chat_state, img_list=img_list)[0]
            return self._with_hit_flag(answer, False)
        except Exception as e:
            raise CacheError("minigpt4 error") from e

    def __call__(self, image, question, *args, **kwargs):
        """Answer ``question`` about ``image`` through ``adapt``.

        :param image: the image to ask about (recorded as a ``DataType.STR`` dependency).
        :param question: the question text.
        :return: the answer, or ``(answer, is_hit)`` when ``return_hit`` is set.
        """
        cache_context = {"deps": [
            {"name": "text", "data": question, "dep_type": DataType.STR},
            {"name": "image", "data": image, "dep_type": DataType.STR},
        ]}

        def cache_data_convert(cache_data):
            # Data coming back through this converter is reported as a hit.
            return self._with_hit_flag(cache_data, True)

        def update_cache_callback(llm_data, update_cache_func, *args, **kwargs):  # pylint: disable=unused-argument
            question_data = Question.from_dict({
                "content": "pre_embedding_data",
                "deps": [
                    {"name": "text", "data": kwargs["question"], "dep_type": DataType.STR},
                    {"name": "image", "data": kwargs["image"], "dep_type": DataType.STR},
                ]
            })
            # Only the answer text is stored; strip the hit flag if present.
            llm_data_cache = llm_data[0] if self.return_hit else llm_data
            update_cache_func(Answer(llm_data_cache, DataType.STR), question=question_data)
            return llm_data

        return adapt(
            self.llm_handler,
            cache_data_convert,
            update_cache_callback,
            *args,
            image=image,
            question=question,
            cache_context=cache_context,
            **kwargs,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters