Skip to content

Commit

Permalink
adding async operations for embedding generation and thread safety fo…
Browse files Browse the repository at this point in the history
…r the blip2 model api calls
  • Loading branch information
IshmeetMehta committed Dec 18, 2024
1 parent ec4c0bb commit 523dea4
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@

DISTANCE_FUNCTION = "cosine"
NUM_LEAVES_VALUE = int(os.environ.get("NUM_LEAVES_VALUE"))
# max_workers_value = int(os.environ.get("MAX_WORKERS_VALUE"))
max_workers_value = 32

embedding_columns = {
"text": "text_embeddings",
Expand Down Expand Up @@ -74,6 +76,7 @@
catalog_db,
catalog_table,
processed_data_path,
max_workers_value,
)
)#closing ayncio.run here
logger.info("ETL job has been completed successfully ...")
Expand Down
19 changes: 12 additions & 7 deletions use-cases/rag-pipeline/alloy-db-setup/src/create_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def create_database(database, new_database):
logger.info("Connector closed")


async def create_and_populate_table(database, table_name, processed_data_path):
async def create_and_populate_table(database, table_name, processed_data_path, max_workers_value):
"""Creates and populates a table in PostgreSQL using pandas and sqlalchemy."""

try:
Expand All @@ -120,7 +120,7 @@ async def create_and_populate_table(database, table_name, processed_data_path):

# 2. Transform
logger.info(f"Starting embedding generation...")
with ThreadPoolExecutor() as executor:
with ThreadPoolExecutor(max_workers=max_workers_value) as executor:
loop = asyncio.get_event_loop()

# Create all embeddings tasks concurrently
Expand All @@ -138,13 +138,18 @@ async def create_and_populate_table(database, table_name, processed_data_path):
]

# Gather results concurrently
df["multimodal_embeddings"], df["text_embeddings"], df["image_embeddings"] = await asyncio.gather(
asyncio.gather(*multimodal_tasks),
asyncio.gather(*text_tasks),
asyncio.gather(*image_tasks),

multimodal_results, text_results, image_results = await asyncio.gather(
asyncio.gather(*multimodal_tasks),
asyncio.gather(*text_tasks),
asyncio.gather(*image_tasks),
)

df["multimodal_embeddings"] = multimodal_results
df["text_embeddings"] = text_results
df["image_embeddings"] = image_results

logger.info(f"Embedding generation task is now complete")
logger.info(f"Embedding generation task is now complete")

# 3. Load (this part remains synchronous for now)
#TODO: Check if alloyDb allows async operations
Expand Down

0 comments on commit 523dea4

Please sign in to comment.