Skip to content

Commit

Permalink
adding async operations for embedding generation and thread safety fo…
Browse files Browse the repository at this point in the history
…r the blip2 model api calls
  • Loading branch information
IshmeetMehta committed Dec 18, 2024
1 parent 63981d2 commit 186a0d4
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 62 deletions.
58 changes: 27 additions & 31 deletions use-cases/rag-pipeline/alloy-db-setup/src/create_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import os
import pandas as pd
import sqlalchemy
import asyncio
from concurrent.futures import ThreadPoolExecutor

from google.cloud.alloydb.connector import Connector
from pgvector.sqlalchemy import Vector
Expand All @@ -36,7 +38,6 @@
)
logger.setLevel(new_log_level)


def create_database(database, new_database):
"""Creates a new database in AlloyDB and enables necessary extensions."""
try:
Expand Down Expand Up @@ -105,42 +106,44 @@ def create_database(database, new_database):
logger.info("Connector closed")


def create_and_populate_table(database, table_name, processed_data_path):
async def create_and_populate_table(database, table_name, processed_data_path):
"""Creates and populates a table in PostgreSQL using pandas and sqlalchemy."""

try:
# 1. Extract
# 1. Extract the data
df = pd.read_csv(processed_data_path)
logger.info(f"Input df shape: {df.shape}")

# Dropping products with image_uri as NaN
# Drop the products with image_uri as NaN
df.dropna(subset=["image_uri"], inplace=True)
logger.info(f"resulting df shape: {df.shape}")

# 2. Transform
logger.info(f"Starting embedding generation...")
# 2. Transform
logger.info(f"Starting multimodal embedding generation...")
df["multimodal_embeddings"] = df.apply(
lambda row: get_emb.get_embeddings(row["image_uri"], row["Description"]),
axis=1,
)
logger.info(f"Multimodal embedding generation completed")

logger.info(f"Starting text embedding generation...")
df["text_embeddings"] = df.apply(
lambda row: get_emb.get_embeddings(None, row["Description"]), axis=1
)
logger.info(f"Text embedding generation completed")

logger.info(f"Starting image embedding generation...")
df["image_embeddings"] = df.apply(
lambda row: get_emb.get_embeddings(row["image_uri"], None), axis=1
)
logger.info(f"Image embedding generation completed")
with ThreadPoolExecutor() as executor:
loop = asyncio.get_event_loop()
tasks = [
loop.run_in_executor(executor, get_emb.get_embeddings, row["image_uri"], row["Description"])
for _, row in df.iterrows()
]
df["multimodal_embeddings"] = await asyncio.gather(*tasks)

tasks = [
loop.run_in_executor(executor, get_emb.get_embeddings, None, row["Description"])
for _, row in df.iterrows()
]
df["text_embeddings"] = await asyncio.gather(*tasks)

tasks = [
loop.run_in_executor(executor, get_emb.get_embeddings, row["image_uri"], None)
for _, row in df.iterrows()
]
df["image_embeddings"] = await asyncio.gather(*tasks)

logger.info(f"Embedding generation task is now complete")

# 3. Load
# 3. Load (this part remains synchronous for now)
# TODO: Check whether AlloyDB supports async operations
with Connector() as connector:
engine = alloydb_connect.init_connection_pool(connector, database)
with engine.begin() as connection:
Expand Down Expand Up @@ -172,13 +175,6 @@ def create_and_populate_table(database, table_name, processed_data_path):
logging.error(
f"An unexpected error occurred while creating and populating the table: {e}"
)
finally:
if connection:
connection.close()
logger.info(f"DB: {database} Connection closed")
if connector:
connector.close()
logger.info("Connector closed")


# Create a ScaNN index on the table's embedding column using cosine distance
Expand Down
61 changes: 30 additions & 31 deletions use-cases/rag-pipeline/alloy-db-setup/src/get_emb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os
import requests
import json
from threading import Lock

# Define the API Endpoints
TEXT_API_ENDPOINT = os.environ.get("TEXT_EMBEDDING_ENDPOINT")
Expand All @@ -32,10 +33,8 @@
try:
# Convert the string to a logging level constant
numeric_level = getattr(logging, new_log_level)

# Set the level for the root logger
logging.getLogger().setLevel(numeric_level)

logger.setLevel(new_log_level)
logger.info(
"Log level set to '%s' via LOG_LEVEL environment variable", new_log_level
)
Expand All @@ -48,6 +47,8 @@
"Invalid LOG_LEVEL value: '%s'. Using default log level.", new_log_level
)

# Create a lock for thread safety
lock = Lock()

def get_image_embeddings(image_uri):
"""
Expand All @@ -61,36 +62,32 @@ def get_image_embeddings(image_uri):
Raises:
requests.exceptions.HTTPError: If there is an error fetching the image embeddings
or the API returns an invalid response.
or the API returns an invalid response.
"""
try:
response = requests.post(
IMAGE_API_ENDPOINT,
json={"image_uri": image_uri},
headers={"Content-Type": "application/json"},
timeout=100,
)

# This will raise an HTTPError for bad responses (4xx or 5xx)
response.raise_for_status()
with lock:
response = requests.post(
IMAGE_API_ENDPOINT,
json={"image_uri": image_uri},
headers={"Content-Type": "application/json"},
timeout=100,
)

response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
image_embeddings = response.json()["image_embeds"]
return image_embeddings

except requests.exceptions.HTTPError as e:
# Reraise HTTPError for better error handling
logger.exception("Error fetching image embedding: %s", e)
raise

except requests.exceptions.RequestException as e:
# For other request errors, re-raise as an HTTPError
logger.exception("Invalid response from image embedding API: %s", e)
raise requests.exceptions.HTTPError(
"Error fetching image embedding", response=requests.Response()
) from e

except (ValueError, TypeError) as e:
# Handle potential JSON decoding errors
logger.exception(
"Not able to decode received json from image embedding API: %s", e
)
Expand All @@ -112,15 +109,16 @@ def get_multimodal_embeddings(image_uri, desc):
Raises:
requests.exceptions.HTTPError: If there is an error fetching the multimodal embeddings
or the API returns an invalid response.
or the API returns an invalid response.
"""
try:
response = requests.post(
MULTIMODAL_API_ENDPOINT,
json={"image_uri": image_uri, "caption": desc},
headers={"Content-Type": "application/json"},
timeout=100,
)
with lock:
response = requests.post(
MULTIMODAL_API_ENDPOINT,
json={"image_uri": image_uri, "caption": desc},
headers={"Content-Type": "application/json"},
timeout=100,
)

response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
return response.json()["multimodal_embeds"]
Expand Down Expand Up @@ -157,18 +155,18 @@ def get_text_embeddings(text):
Raises:
requests.exceptions.HTTPError: If there is an error fetching the text embeddings
or the API returns an invalid response.
or the API returns an invalid response.
"""
try:
response = requests.post(
TEXT_API_ENDPOINT,
json={"caption": text},
headers={"Content-Type": "application/json"},
timeout=100,
)
with lock:
response = requests.post(
TEXT_API_ENDPOINT,
json={"caption": text},
headers={"Content-Type": "application/json"},
timeout=100,
)

response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)

return response.json()["text_embeds"]

except requests.exceptions.HTTPError as e:
Expand Down Expand Up @@ -218,3 +216,4 @@ def get_embeddings(image_uri=None, text=None):
"Missing input. Provide a textual product description and/or image_uri to generate embeddings"
)
return None

0 comments on commit 186a0d4

Please sign in to comment.