Skip to content

Commit

Permalink
Merge pull request #1381 from wxywb/mmrag
Browse files Browse the repository at this point in the history
Add app of multimodal_rag_with_milvus.
  • Loading branch information
jaelgu authored Jul 19, 2024
2 parents 0c02d6c + 2516015 commit 3721e11
Show file tree
Hide file tree
Showing 20 changed files with 1,930 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Fill out the following if using OpenAI service
API_KEY=**************

# Fill out the following if using Azure OpenAI service
AZURE_OPENAI_API_KEY=**************
AZURE_OPENAI_ENDPOINT=https://*******.com
AZURE_DEPLOYMENT=******-***-**
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[theme]
base = "dark"
primaryColor = "#4fc4f9"
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Multimodal RAG with Milvus 🖼️

<div style="text-align: center;">
<figure>
<img src="./pics/cir_demo.jpg" alt="Description of Image" width="700"/>
</figure>
</div>

This multi-modal RAG (Retrieval-Augmented Generation) demo showcases the integration of Milvus with [MagicLens](https://open-vision-language.github.io/MagicLens/) and [GPT-4o](https://openai.com/index/hello-gpt-4o/) for advanced image searching based on user instructions. Users can upload an image and edit instructions, which are processed by MagicLens's composed retrieval model to search for candidate images. GPT-4o then acts as a reranker, selecting the most suitable image and providing the rationale behind the choice. This powerful combination enables a seamless and intuitive image search experience.

## Quick Deploy

Follow these steps to quickly deploy the application locally:

### Preparation

> Prerequisites: Python 3.8 or higher
**1. Download Codes**
```bash
$ git clone <https://github.com/milvus-io/bootcamp.git>
$ cd bootcamp/bootcamp/tutorials/quickstart/app/multimodal_rag_with_milvus
```

**2. Set Environment**

- Install dependencies

```bash
$ pip install -r requirements.txt
```

- Set environment variables

Modify the environment file [.env](./.env) to change environment variables for either OpenAI or Azure OpenAI service, and only keep the variables relevant to the service chosen:

```bash
# Fill out and keep the following if using OpenAI service
API_KEY=**************

# Fill out and keep the following if using Azure OpenAI service
AZURE_OPENAI_API_KEY=**************
AZURE_OPENAI_ENDPOINT=https://*******.com
AZURE_DEPLOYMENT=******-***-**
```

**3. Prepare MagicLens Model** <br>

More detailed information can be found at <https://github.com/google-deepmind/magiclens>

- Setup

```bash
conda create --name magic_lens python=3.9
conda activate magic_lens
git clone https://github.com/google-research/scenic.git
cd scenic
pip install .
pip install -r scenic/projects/baselines/clip/requirements.txt
# you may need to install corresponding GPU version of jax following https://jax.readthedocs.io/en/latest/installation.html
# e.g.,
# # CUDA 12 installation
# Note: wheels only available on linux.
# pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# # CUDA 11 installation
# Note: wheels only available on linux.
# pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

- Model Download

```bash
cd .. # in main folder of demo.
# you may need to use `gcloud auth login` for access, any gmail account should work.
gsutil cp -R gs://gresearch/magiclens/models ./
```

**4. Prepare Data**

We are using a subset of https://github.com/hyp1231/AmazonReviews2023 which includes approximately 5000 images in 33 different categories, such as applicances, beauty and personal care, clothing, sports and outdoors, etc.<br>

Download image set by running [download_images.py](./download_images.py).
```bash
$ python download_images.py
```

Create a collection and load image data from the dataset to get the knowledge ready by running [index.py](./index.py).

```bash
$ python index.py
```

### Start Service

Run the Streamlit application:

```bash
$ streamlit run ui.py
```

There have some options you can set in `cfg.py`.


### Example Usage:

**Step 1:** Choose an image file to upload (JPEG format), and give user instruction as a text input.

<div style="text-align: center;">
<figure>
<img src="./pics/step1.jpg" alt="Description of Image" width="700"/>
</figure>
</div>

**Step 2:** Click on the 'Search' button to see top 100 candidate images generated based on both query image and user instruction.

<div style="text-align: center;">
<figure>
<img src="./pics/step2.jpg" alt="Description of Image" width="700"/>
</figure>
</div>

**Step 3:** Click on the 'Ask GPT' button to get the best item chosen by GPT-4o after reranking along with detailed explanation.

<div style="text-align: center;">
<figure>
<img src="./pics/step3.jpg" alt="Description of Image" width="700"/>
</figure>
</div>
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
All_Beauty
Amazon_Fashion
Appliances
Arts_Crafts_and_Sewing
Automotive
Baby_Products
Beauty_and_Personal_Care
Books
CDs_and_Vinyl
Cell_Phones_and_Accessories
Clothing_Shoes_and_Jewelry
Digital_Music
Electronics
Gift_Cards
Grocery_and_Gourmet_Food
Handmade_Products
Health_and_Household
Health_and_Personal_Care
Home_and_Kitchen
Industrial_and_Scientific
Kindle_Store
Magazine_Subscriptions
Musical_Instruments
Office_Products
Patio_Lawn_and_Garden
Pet_Supplies
Software
Sports_and_Outdoors
Subscription_Boxes
Tools_and_Home_Improvement
Toys_and_Games
Video_Games
Unknown
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
class Config:
# Define a class named Config to hold configuration settings.
def __init__(self):
# Initialize method to set default values for configuration settings.
self.download_path = "./images"
# Set the path where images will be downloaded to "./images".
self.imgs_per_category = 300
# Define the number of images to download per category, set to 300.
self.milvus_uri = "milvus.db"
# Set the URI for the Milvus database, you can change to "http://localhost:19530" for a standard Milvus.
self.collection_name = "cir_demo_large"
# Define the name of the collection in the Milvus database, set to "cir_demo_large".
self.device = "gpu"
# Set the device to use for computations, in this case, "gpu", you can change it to "cpu".
self.model_type = "large"
# Specify the type of model to use, set to "large".
self.model_path = "./models/magic_lens_clip_large.pkl"
# Set the path to the model file, default is "./magic_lens_clip_large.pkl".
self.api_type = "openai"
# Define the type of API to use, set to "openai", you can change it to "azure_openai".
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from datasets import load_dataset
from cfg import Config
import os


def download_images():
config = Config()
with open("categories.txt") as fw:
lines = fw.readlines()
for line in lines:
l = line.strip()
meta_dataset = load_dataset(
"McAuley-Lab/Amazon-Reviews-2023", f"raw_meta_{l}", split="full"
)
for i in range(config.imgs_per_category):
if len(meta_dataset[i]["images"]["large"]) > 0:
img_name = meta_dataset[i]["images"]["large"][0]
basename = os.path.basename(img_name)
if os.path.exists(f"{config.download_path}/{basename}") is False:
os.system(
f"wget {img_name} -P {config.download_path} --no-check-certificate"
)


if __name__ == "__main__":
download_images()
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from datasets import load_dataset
from pymilvus import MilvusClient
import numpy as np
import os
from PIL import Image
import json
from magiclens.magiclens import MagicLensEmbeddding
from cfg import Config

from retrieve import Retriever

encoder = Retriever()


def insert_data():
config = Config()
image_folder = f"{config.download_path}" + "/{}"
client = MilvusClient(uri=config.milvus_uri)
client.create_collection(
collection_name=config.collection_name,
overwrite=True,
auto_id=True,
dimension=768,
enable_dynamic_field=True,
)
count = 0
with open("categories.txt") as fw:
lines = fw.readlines()
for line in lines:
l = line.strip()
meta_dataset = load_dataset(
"McAuley-Lab/Amazon-Reviews-2023", f"raw_meta_{l}", split="full"
)
for i in range(config.imgs_per_category):
if len(meta_dataset[i]["images"]["large"]) > 0:
print(count)
count = count + 1
img_name = meta_dataset[i]["images"]["large"][0]
name = os.path.basename(img_name)
if os.path.exists(image_folder.format(name)) is True:
feat = encoder.encode_query(image_folder.format(name), "")
spec = json.dumps(meta_dataset[i])
res = client.insert(
collection_name=config.collection_name,
data={
"vector": np.array(feat.flatten()),
"spec": spec,
"name": f"{l}_{i}",
},
)


insert_data()
Loading

0 comments on commit 3721e11

Please sign in to comment.