Skip to content

Commit

Permalink
adding async operations for embedding generation and thread safety fo…
Browse files Browse the repository at this point in the history
…r the blip2 model api calls
  • Loading branch information
IshmeetMehta committed Dec 18, 2024
1 parent e2a6450 commit ec4c0bb
Showing 1 changed file with 18 additions and 24 deletions.
42 changes: 18 additions & 24 deletions use-cases/rag-pipeline/alloy-db-setup/src/get_emb.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@
"Invalid LOG_LEVEL value: '%s'. Using default log level.", new_log_level
)

# Create a lock for thread safety
lock = Lock()

def get_image_embeddings(image_uri):
"""
Fetches image embeddings from an image embedding API.
Expand All @@ -65,13 +62,12 @@ def get_image_embeddings(image_uri):
or the API returns an invalid response.
"""
try:
with lock:
response = requests.post(
IMAGE_API_ENDPOINT,
json={"image_uri": image_uri},
headers={"Content-Type": "application/json"},
timeout=100,
)
response = requests.post(
IMAGE_API_ENDPOINT,
json={"image_uri": image_uri},
headers={"Content-Type": "application/json"},
timeout=100,
)

response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
image_embeddings = response.json()["image_embeds"]
Expand Down Expand Up @@ -112,13 +108,12 @@ def get_multimodal_embeddings(image_uri, desc):
or the API returns an invalid response.
"""
try:
with lock:
response = requests.post(
MULTIMODAL_API_ENDPOINT,
json={"image_uri": image_uri, "caption": desc},
headers={"Content-Type": "application/json"},
timeout=100,
)
response = requests.post(
MULTIMODAL_API_ENDPOINT,
json={"image_uri": image_uri, "caption": desc},
headers={"Content-Type": "application/json"},
timeout=100,
)

response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
return response.json()["multimodal_embeds"]
Expand Down Expand Up @@ -158,13 +153,12 @@ def get_text_embeddings(text):
or the API returns an invalid response.
"""
try:
with lock:
response = requests.post(
TEXT_API_ENDPOINT,
json={"caption": text},
headers={"Content-Type": "application/json"},
timeout=100,
)
response = requests.post(
TEXT_API_ENDPOINT,
json={"caption": text},
headers={"Content-Type": "application/json"},
timeout=100,
)

response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
return response.json()["text_embeds"]
Expand Down

0 comments on commit ec4c0bb

Please sign in to comment.