From e2a645026634010354ad4c28df88c0640bde0521 Mon Sep 17 00:00:00 2001
From: Ishmeet Mehta
Date: Wed, 18 Dec 2024 15:59:56 +0000
Subject: [PATCH] Add async operations for embedding generation and thread
 safety for the BLIP2 model API calls

---
 .../alloy-db-setup/src/alloy-db-setup-job.py | 12 +++++++--
 .../alloy-db-setup/src/create_catalog.py     | 26 +++++++++++--------
 2 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/use-cases/rag-pipeline/alloy-db-setup/src/alloy-db-setup-job.py b/use-cases/rag-pipeline/alloy-db-setup/src/alloy-db-setup-job.py
index 817a3752..e65afe30 100644
--- a/use-cases/rag-pipeline/alloy-db-setup/src/alloy-db-setup-job.py
+++ b/use-cases/rag-pipeline/alloy-db-setup/src/alloy-db-setup-job.py
@@ -17,6 +17,9 @@
 import logging
 import logging.config
 import os
+import asyncio
+
+# Environment variables
 
 # Master_product_catalog.csv
 PROCESSED_DATA_BUCKET = os.environ.get("PROCESSED_DATA_BUCKET")
@@ -63,14 +66,18 @@
             database_name,
             catalog_db,
         )
+        logger.info("DB product_catalog has been created successfully ...")
 
-        # ETL
+        # ETL Run
         logger.info("ETL job to create table and generate embeddings in progress ...")
-        create_catalog.create_and_populate_table(
+        asyncio.run(create_catalog.create_and_populate_table(
             catalog_db,
             catalog_table,
             processed_data_path,
         )
+        )  # closing asyncio.run here
+        logger.info("ETL job has been completed successfully ...")
+
 
         # Create Indexes for all embedding columns(text, image and multimodal)
         logger.info("Create SCaNN indexes in progress ...")
@@ -85,6 +92,7 @@
             DISTANCE_FUNCTION,
             NUM_LEAVES_VALUE,
         )
+        logger.info("SCaNN indexes have been created successfully ...")
     except Exception as e:
         logger.error(f"An unexpected error occurred during catalog onboarding: {e}")
         raise
diff --git a/use-cases/rag-pipeline/alloy-db-setup/src/create_catalog.py b/use-cases/rag-pipeline/alloy-db-setup/src/create_catalog.py
index c55a1d70..3059c2f9 100644
--- a/use-cases/rag-pipeline/alloy-db-setup/src/create_catalog.py
+++ b/use-cases/rag-pipeline/alloy-db-setup/src/create_catalog.py
@@ -122,23 +122,27 @@ async def create_and_populate_table(database, table_name, processed_data_path):
 
     logger.info(f"Starting embedding generation...")
     with ThreadPoolExecutor() as executor:
         loop = asyncio.get_event_loop()
-        tasks = [
-            loop.run_in_executor(executor, get_emb.get_embeddings, row["image_uri"], row["Description"])
+
+        # Create all embeddings tasks concurrently
+        multimodal_tasks = [
+            loop.run_in_executor(executor, get_emb.get_embeddings, row["image_uri"], row["Description"])
             for _, row in df.iterrows()
         ]
-        df["multimodal_embeddings"] = await asyncio.gather(*tasks)
-
-        tasks = [
-            loop.run_in_executor(executor, get_emb.get_embeddings, None, row["Description"])
+        text_tasks = [
+            loop.run_in_executor(executor, get_emb.get_embeddings, None, row["Description"])
             for _, row in df.iterrows()
         ]
-        df["text_embeddings"] = await asyncio.gather(*tasks)
-
-        tasks = [
-            loop.run_in_executor(executor, get_emb.get_embeddings, row["image_uri"], None)
+        image_tasks = [
+            loop.run_in_executor(executor, get_emb.get_embeddings, row["image_uri"], None)
             for _, row in df.iterrows()
         ]
-        df["image_embeddings"] = await asyncio.gather(*tasks)
+
+        # Gather results concurrently
+        df["multimodal_embeddings"], df["text_embeddings"], df["image_embeddings"] = await asyncio.gather(
+            asyncio.gather(*multimodal_tasks),
+            asyncio.gather(*text_tasks),
+            asyncio.gather(*image_tasks),
+        )
 
         logger.info(f"Embedding generation task is now complete")
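
Note: the hunks above offload the blocking get_emb.get_embeddings calls to a ThreadPoolExecutor and await the three embedding columns concurrently with nested asyncio.gather calls. The commit subject also mentions thread safety for the BLIP2 model API calls, but that change is not shown in this diff. The sketch below illustrates both patterns under stated assumptions: Blip2Client, its generate_embedding method, and the lock placement are hypothetical stand-ins for the real get_emb module, not the repository's actual API.

    # Illustrative sketch only -- not part of this patch. Blip2Client and
    # generate_embedding are hypothetical stand-ins for the real get_emb module.
    import asyncio
    import threading
    from concurrent.futures import ThreadPoolExecutor


    class Blip2Client:
        """Hypothetical wrapper around a blocking, non-thread-safe BLIP2 endpoint."""

        def __init__(self):
            # One lock serializes access to the underlying model so the blocking
            # API can be called safely from ThreadPoolExecutor worker threads.
            self._lock = threading.Lock()

        def generate_embedding(self, image_uri=None, text=None):
            with self._lock:
                # Placeholder for the real (blocking) model call.
                return [float(len(text or "")), float(bool(image_uri))]


    async def embed_rows(rows):
        """Run one blocking embedding call per row in a thread pool and gather them."""
        client = Blip2Client()
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor() as executor:
            tasks = [
                loop.run_in_executor(
                    executor, client.generate_embedding, row["image_uri"], row["Description"]
                )
                for row in rows
            ]
            # Await all embeddings concurrently, mirroring the asyncio.gather usage above.
            return await asyncio.gather(*tasks)


    if __name__ == "__main__":
        sample = [
            {"image_uri": "gs://bucket/shoe.png", "Description": "red running shoe"},
            {"image_uri": "gs://bucket/bag.png", "Description": "leather tote bag"},
        ]
        print(asyncio.run(embed_rows(sample)))

Serializing the model call behind a single lock trades throughput for safety; if the real BLIP2 client is thread-safe or one client is created per worker thread, the lock can be dropped and the executor's parallelism is preserved.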