# created image search demo #1373

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Image Similarity Search with Milvus 🖼️

This demo implements an image similarity search application using Streamlit, Milvus, and a pre-trained ResNet model. Users can upload an image, crop it to focus on the region of interest, and search for similar images from a pre-built database.

## Features
- Upload and crop images to define the region of interest.
- Extract features using a pre-trained ResNet model.
- Search for similar images using Milvus for efficient similarity search.
- Display search results along with similarity scores.

## Code Structure
```text
image_search_with_milvus/
├── app.py # Main Streamlit application
├── insert.py # Script to download and unzip image data
├── milvus_utils.py # Milvus-related operations
├── encoder.py # Feature extraction and model loading
├── requirements.txt # List of dependencies
```

- `app.py`: The main Streamlit application, which defines the user interface and performs the image similarity search.
- `insert.py`: Downloads and unzips the image data required for the application.
- `milvus_utils.py`: Functions for interacting with the Milvus database, such as setting up the Milvus client and inserting image embeddings.
- `encoder.py`: Contains the `FeatureExtractor` class, which extracts feature vectors from images using a pre-trained ResNet model.
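
For orientation, here is a minimal sketch of how these pieces fit together at query time, based on the code in this PR; the cropping, resizing, and Streamlit UI are omitted, and the image path is illustrative:

```python
# Minimal query-time sketch: embed an image and search Milvus for neighbors.
from PIL import Image

from encoder import load_model
from milvus_utils import get_db

extractor = load_model("resnet34")  # produces 512-dim embeddings for resnet34
client = get_db()                   # Milvus Lite client backed by example.db

query = Image.open("temp.jpg")      # illustrative path
results = client.search(
    "image_embeddings",
    data=[extractor(query)],
    limit=10,
    output_fields=["filename"],
    search_params={"metric_type": "COSINE"},
)
for hit in results[0]:
    print(hit["entity"]["filename"], hit["distance"])
```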

## Quick Deploy

Follow these steps to quickly deploy the application locally:

### Installation

#### Prerequisites
- Python 3.8 or higher

#### Clone the Repository
```sh
git clone https://github.com/milvus-io/bootcamp.git
cd bootcamp/bootcamp/tutorials/quickstart/apps/image_search_with_milvus
```

#### Install Dependencies
```sh
pip install -r requirements.txt
```

### Dataset Preparation
We use a diverse dataset covering roughly 200 categories (animals, objects, buildings, and more): <https://github.com/milvus-io/pymilvus-assets/releases/download/imagedata/reverse_image_search.zip>. <br>
Running `insert.py` downloads and unzips the dataset for further processing.
```sh
python3 insert.py
```

### Usage
#### Run the Streamlit application
```sh
streamlit run app.py
```
#### Steps:
<div style="text-align: center;">
<figure>
<img src="./pics/step1.png" alt="Description of Image" width="700"/>
<figcaption>Step 1: Choose an image file to upload (JPEG format).</figcaption>
</figure>
</div>

<div style="text-align: center;">
<figure>
<img src="./pics/step2_and_3.jpg" alt="Description of Image" width="700"/>
<figcaption>Step 2: Crop the image to focus on the region of interest.</figcaption>
<figcaption>Step 3: Set the desired number of top-k results to display using the slider.</figcaption>
</figure>
</div>

<div style="text-align: center;">
<figure>
<img src="./pics/step4.jpg" alt="Description of Image" width="700"/>
<figcaption>Step 4: View the search results along with similarity scores.</figcaption>
</figure>
</div>
**`bootcamp/tutorials/quickstart/apps/image_search_with_milvus/app.py`**

```python
import streamlit as st
from streamlit_cropper import st_cropper
import streamlit_cropper
from PIL import Image

# Must be the first Streamlit call in the script.
st.set_page_config(layout="wide")

from encoder import load_model
from milvus_utils import get_db


def _recommended_box2(img: Image.Image, aspect_ratio: tuple = None) -> dict:
    # Default the crop box to (almost) the full image instead of the
    # library's smaller recommended box.
    width, height = img.size
    return {
        "left": 0,
        "top": 0,
        "width": int(width - 2),
        "height": int(height - 2),
    }


# Monkey-patch streamlit_cropper so the initial crop box covers the image.
streamlit_cropper._recommended_box = _recommended_box2

extractor = load_model("resnet34")
client = get_db()

# Logo
st.sidebar.image("./pics/Milvus_Logo_Official.png", width=200)

# Title
st.title("Image Similarity Search :frame_with_picture: ")

cols = st.columns(5)

uploaded_file = st.sidebar.file_uploader("Choose an image...", type="jpeg")

if uploaded_file is not None:
    with open("temp.jpg", "wb") as f:
        f.write(uploaded_file.getbuffer())

    # Get a cropped image from the frontend.
    uploaded_img = Image.open(uploaded_file)
    width, height = uploaded_img.size

    new_width = 370
    new_height = int((new_width / width) * height)
    uploaded_img = uploaded_img.resize((new_width, new_height))

    st.sidebar.text(
        "Query Image",
        help="Edit the bounding box to change the ROI (Region of Interest).",
    )
    with st.sidebar.empty():
        cropped_img = st_cropper(
            uploaded_img,
            box_color="#4fc4f9",
            realtime_update=True,
            aspect_ratio=(16, 9),
        )

    show_distance = st.sidebar.toggle("Show Distance")

    # Top-k value slider
    value = st.sidebar.slider("Select top k results shown", 10, 100, 20, step=1)

    # Extract the embedding once and reuse it for the search.
    image_embedding = extractor(cropped_img)

    results = client.search(
        "image_embeddings",
        data=[image_embedding],
        limit=value,
        output_fields=["filename"],
        search_params={"metric_type": "COSINE"},
    )
    search_results = results[0]

    for i, info in enumerate(search_results):
        img_info = info["entity"]
        img_name = img_info["filename"]
        score = info["distance"]
        img = Image.open(img_name)
        cols[i % 5].image(img, use_column_width=True)
        if show_distance:
            cols[i % 5].write(f"Score: {score:.3f}")
```
**`encoder.py`**

```python
import os

import streamlit as st
import timm
import torch
from sklearn.preprocessing import normalize
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# Path to the cached model weights (may be subject to change depending on
# where you want to store the file).
documents_dir = os.path.expanduser("~/Documents")
folder_name = "streamlit_app_image"
MODEL_PATH = os.path.join(documents_dir, folder_name, "feature_extractor_model.pth")
# Ensure the directory exists.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)


class FeatureExtractor:
    def __init__(self, modelname):
        # Headless timm model with global average pooling: one vector per image.
        self.model = timm.create_model(
            modelname, pretrained=True, num_classes=0, global_pool="avg"
        )
        self.model.eval()
        self.input_size = self.model.default_cfg["input_size"]
        # Resolve the preprocessing config from the model instance (passing
        # the name string would silently fall back to default settings).
        config = resolve_data_config({}, model=self.model)
        self.preprocess = create_transform(**config)

    def save(self, path):
        torch.save(self.model.state_dict(), path)

    @staticmethod
    def load(path, modelname):
        model = FeatureExtractor(modelname)
        model.model.load_state_dict(torch.load(path))
        model.model.eval()
        return model

    def __call__(self, input):
        input_image = input.convert("RGB")
        input_image = self.preprocess(input_image)
        input_tensor = input_image.unsqueeze(0)
        with torch.no_grad():
            output = self.model(input_tensor)
        feature_vector = output.squeeze().numpy()
        # L2-normalize so COSINE similarity in Milvus behaves as expected.
        return normalize(feature_vector.reshape(1, -1), norm="l2").flatten()


@st.cache_resource
def load_model(modelname, path=MODEL_PATH):
    # Load cached weights if present; otherwise download via timm and cache.
    if os.path.exists(path):
        return FeatureExtractor.load(path, modelname)
    else:
        model = FeatureExtractor(modelname)
        model.save(path)
        return model
```
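
As a quick sanity check, the extractor can be exercised on its own; for `resnet34` the resulting vector has 512 dimensions, matching the collection dimension used in `milvus_utils.py` (the image path here is illustrative):

```python
from PIL import Image

from encoder import load_model

extractor = load_model("resnet34")
embedding = extractor(Image.open("temp.jpg"))  # any RGB-convertible image
print(embedding.shape)                 # (512,) for resnet34
print(float((embedding ** 2).sum()))   # ~1.0, since vectors are L2-normalized
```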
**`insert.py`**

```python
import os
import zipfile

import certifi
import requests
from PIL import Image

from encoder import load_model


def download_file(url, dest):
    response = requests.get(url, verify=certifi.where())
    with open(dest, "wb") as f:
        f.write(response.content)


# Download and unzip data if not already done.
zip_path = "reverse_image_search.zip"
if not os.path.exists(zip_path):
    url = "https://github.com/milvus-io/pymilvus-assets/releases/download/imagedata/reverse_image_search.zip"
    download_file(url, zip_path)
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(".")

extractor = load_model("resnet34")


def insert_embeddings(client):
    # Walk the extracted dataset and insert one embedding per .JPEG image.
    root = "./train"
    for dirpath, _dirnames, filenames in os.walk(root):
        for filename in filenames:
            if filename.endswith(".JPEG"):
                filepath = os.path.join(dirpath, filename)
                img = Image.open(filepath)
                image_embedding = extractor(img)
                client.insert(
                    "image_embeddings",
                    {"vector": image_embedding, "filename": filepath},
                )
```
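
After `insert_embeddings` has run, one way to sanity-check the ingest is to ask Milvus for the collection's row count; a small hypothetical snippet, assuming the `example.db` setup from `milvus_utils.py`:

```python
from pymilvus import MilvusClient

client = MilvusClient(uri="example.db")
stats = client.get_collection_stats("image_embeddings")
print(stats)  # e.g. {'row_count': ...}
```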
**`milvus_utils.py`**

```python
import os

import streamlit as st
from pymilvus import MilvusClient

from insert import insert_embeddings


@st.cache_resource
def get_milvus_client(uri):
    return MilvusClient(uri=uri)


@st.cache_resource
def get_db():
    # First run: create the collection and populate it with embeddings.
    if not os.path.exists("example.db"):
        client = get_milvus_client(uri="example.db")
        client.create_collection(
            collection_name="image_embeddings",
            vector_field_name="vector",
            dimension=512,
            auto_id=True,
            enable_dynamic_field=True,
            metric_type="COSINE",
        )
        insert_embeddings(client)
    else:
        client = get_milvus_client(uri="example.db")
    return client
```

> **Reviewer comment:** db_exists_check should not include the step of inserting embeddings.
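
One way to address that review comment is to key the first-run check on the collection rather than on the database file, and to keep ingestion as an explicit, separate step. This is a sketch of that refactor, not part of the PR:

```python
# Hypothetical refactor: creation and ingestion are separated, and the check
# uses the collection itself instead of the example.db file on disk.
@st.cache_resource
def get_db():
    client = get_milvus_client(uri="example.db")
    if not client.has_collection("image_embeddings"):
        client.create_collection(
            collection_name="image_embeddings",
            vector_field_name="vector",
            dimension=512,
            auto_id=True,
            enable_dynamic_field=True,
            metric_type="COSINE",
        )
    return client

# Ingestion then becomes an explicit step, e.g. run once from insert.py:
#   client = get_db()
#   insert_embeddings(client)
```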
**`requirements.txt`**

```text
streamlit
streamlit-cropper
torch
timm
Pillow
scikit-learn
pymilvus
certifi
requests
```
# Build RAG with Milvus

This demo shows you how to build a RAG (Retrieval-Augmented Generation) pipeline with Milvus.<br>

A RAG system pairs a retrieval system with a generative model: given a prompt, it first retrieves relevant documents from a corpus using Milvus, then feeds those documents to a generative model to produce the answer.
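
At its core the pipeline is embed, retrieve, then generate. The sketch below illustrates that flow with `pymilvus` and Azure OpenAI; the collection name, URI, embedding function, and API version are placeholders, not this app's actual code:

```python
import os

from openai import AzureOpenAI
from pymilvus import MilvusClient

milvus = MilvusClient(uri="example.db")  # placeholder URI
llm = AzureOpenAI(
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
    api_version="2024-02-01",  # assumed API version
)


def answer(question: str, embed) -> str:
    # 1) Embed the question and retrieve the top-3 most similar passages.
    hits = milvus.search(
        "text_embeddings",  # placeholder collection name
        data=[embed(question)],
        limit=3,
        output_fields=["text"],
    )
    context = "\n".join(h["entity"]["text"] for h in hits[0])
    # 2) Ask the LLM to answer using only the retrieved context.
    response = llm.chat.completions.create(
        model=os.environ["AZURE_DEPLOYMENT"],
        messages=[
            {"role": "system", "content": "Answer using the provided context."},
            {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {question}"},
        ],
    )
    return response.choices[0].message.content
```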

## Code Structure
```text
rag_search_with_milvus/
├── app.py # Main Streamlit application
├── data_prep.py # Script to retrieve text data
├── milvus_utils.py # Milvus-related operations
├── encoder.py # Text embeddings generation
├── requirements.txt # List of dependencies
```

- `app.py`: The main Streamlit application, which defines the user interface and presents the RAG chatbot.
- `data_prep.py`: Retrieves the text data required for the application.
- `milvus_utils.py`: Functions for interacting with the Milvus database, such as creating the collection and retrieving search results.
- `encoder.py`: Converts text input into text embeddings for further use.

## Quick Deploy

Follow these steps to quickly deploy the application locally:

### Preparation

#### Dependencies and Environment
```sh
pip install -r requirements.txt
```
We will use Azure OpenAI as the LLM in this demo. Set the API key `AZURE_OPENAI_API_KEY`, endpoint `AZURE_OPENAI_ENDPOINT`, and deployment name `AZURE_DEPLOYMENT` as environment variables:
```sh
export AZURE_OPENAI_API_KEY="***********"
export AZURE_OPENAI_ENDPOINT="https://***********"
export AZURE_DEPLOYMENT="****-***-**-*****"
```
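
Alternatively, as in the original snippet, the same variables can be set from within Python before the app starts:

```python
import os

# Masked values as in the original; substitute your own credentials.
os.environ["AZURE_OPENAI_API_KEY"] = "***********"
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://***********"
os.environ["AZURE_DEPLOYMENT"] = "****-***-**-*****"
```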

#### Clone the Repository
```sh
git clone https://github.com/milvus-io/bootcamp.git
cd bootcamp/bootcamp/tutorials/quickstart/apps/rag_search_with_milvus
```

### Usage
#### Run the Streamlit application
```sh
streamlit run app.py
```
#### Steps:
<div style="text-align: center;">
<figure>
<img src="./pics/step1.png" alt="Description of Image" width="700"/>
<figcaption>Step 1: Enter your question in the chat and click the 'submit' button.</figcaption>
</figure>
</div>

<div style="text-align: center;">
<figure>
<img src="./pics/step2.png" alt="Description of Image" width="700"/>
<figcaption>Step 2: The LLM generates a response based on the prompt.</figcaption>
</figure>
</div>

<div style="text-align: center;">
<figure>
<img src="./pics/step3.png" alt="Description of Image" width="700"/>
<figcaption>Step 3: The top 3 retrieved quotes from the text data are listed on the left along with their distances, which indicate how closely each quote relates to the prompt.</figcaption>
</figure>
</div>