1.4.5: improve generation performance
louis030195 committed Dec 13, 2022
1 parent 160e7a5 commit f6f21d3
Showing 7 changed files with 144 additions and 95 deletions.
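The performance win comes from replacing the per-request ThreadPoolExecutor fan-out in langame/functions/services.py with a multiprocessing.Pool whose workers each poll Firestore independently (see the services.py diff below). The following is a minimal, self-contained sketch of that pattern with assumed names and a stubbed poll in place of the real Firestore calls; it illustrates the shape of the change, not the project's actual API:

from multiprocessing import Pool
import random
import time
from typing import Optional, Tuple


def _poll_backend(i: int, topic: str) -> Optional[dict]:
    # Stand-in for polling the Firestore document; succeeds ~30% of the time.
    if random.random() < 0.3:
        return {"id": i, "topic": topic, "content": f"starter about {topic}"}
    return None


def _generate(i: int, topic: str) -> Tuple[Optional[dict], Optional[dict]]:
    # Runs in its own process: submit work, then poll until processed or timeout.
    start = time.time()
    while time.time() - start < 60:
        result = _poll_backend(i, topic)
        if result is not None:
            return result, None
        time.sleep(0.1)
    return None, {"message": "timeout", "code": 500, "status": "ERROR"}


def request_starters(topic: str, limit: int = 3):
    # Fan out one worker per requested starter and collect (result, error) pairs,
    # mirroring the Pool.starmap call introduced in services.py.
    with Pool(processes=limit) as pool:
        responses = pool.starmap(_generate, [(i, topic) for i in range(limit)])
    starters, errors = zip(*responses)
    if any(errors):
        return None, next(e for e in errors if e)
    return list(starters), None


if __name__ == "__main__":
    print(request_starters("symbiosis", limit=3))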
207 changes: 126 additions & 81 deletions langame/functions/services.py
@@ -1,6 +1,6 @@
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import Pool

import time
import asyncio
from typing import List, Optional, Tuple, Any
from langame.messages import (
UNIMPLEMENTED_TOPICS_MESSAGES,
@@ -9,13 +9,112 @@
)
import pytz
from random import choice
from firebase_admin import firestore
from firebase_admin import firestore, initialize_app
from google.cloud.firestore import DocumentSnapshot, Client
from sentry_sdk import capture_exception
import logging
import datetime

utc = pytz.UTC
poll_interval = 0.1


def _generate(
i: int,
api_key_doc_id: str,
logger: logging.Logger,
topics: List[str],
fix_grammar: bool,
parallel_completions: int,
personas: Optional[List[str]],
profanity_threshold: str,
translated: bool,
) -> Tuple[Optional[dict], Optional[dict], Optional[dict]]:
try:
initialize_app()
# pylint: disable=W0703
except: pass
db: Client = firestore.client()
timeout = 60
start_time = time.time()
# format to human readable date time
logger.info(f"[{i}] Generating starter at {datetime.datetime.now(utc)}")
_, ref = db.collection("memes").add(
{
"state": "to-process",
"topics": topics,
"createdAt": firestore.SERVER_TIMESTAMP,
"disabled": True,
"tweet": False,
"shard": 0, # TODO: math.floor(math.random() * 1),
"fixGrammar": fix_grammar,
"parallelCompletions": parallel_completions,
"personas": personas if personas else [],
"profanityThreshold": profanity_threshold,
}
)

# poll until it's in state "processed" or "error", timeout after 1 minute
while True:
prompt_doc = db.collection("memes").document(ref.id).get()
data = prompt_doc.to_dict()
if data.get("state") == "processed" and data.get("content", None):
if translated and not data.get("translated", None):
continue
logger.info(f"[{i}] Generated starter in {time.time() - start_time}s")
return (
{
"id": ref.id,
**data,
},
{
"id": ref.id,
"createdAt": datetime.datetime.now(utc),
},
None,
)
if data.get("state") == "error":
logger.error(f"Failed to request starter for {api_key_doc_id}", exc_info=1)
error = data.get("error", "unknown error")
if error == "no-topics":
user_message = choice(UNIMPLEMENTED_TOPICS_MESSAGES)
elif error == "profane":
user_message = choice(PROFANITY_MESSAGES)
user_message = user_message.replace(
"[TOPICS]", f"\"{','.join(topics)}\""
)
else:
user_message = choice(FAILING_MESSAGES)
capture_exception(Exception(str(error)))
return (
None,
None,
{
"message": error,
"code": 500,
"status": "ERROR",
"user_message": "error while generating conversation starter",
},
)
if time.time() - start_time > (
# increase timeout in case of post processing stuff
timeout * 2
if fix_grammar or translated
else timeout * 1
):
capture_exception(Exception("timeout"))
return (
None,
None,
{
"message": "timeout",
"code": 500,
"status": "ERROR",
"user_message": "error while generating conversation starter",
},
)
time.sleep(poll_interval)


def request_starter_for_service(
api_key_doc: DocumentSnapshot,
@@ -27,7 +27,7 @@ def request_starter_for_service(
parallel_completions: int = 3,
profanity_threshold: str = "tolerant",
personas: Optional[List[str]] = None,
) -> Tuple[Optional[Any], Optional[Any]]:
) -> Tuple[Optional[dict], Optional[Any]]:
"""
Request a conversation starter from the API.
Args:
@@ -67,84 +166,30 @@ def request_starter_for_service(
if conversation_starters_history_docs.exists
else []
)
new_history = []
poll_interval = 0.1

def generate(i: int) -> Tuple[Optional[DocumentSnapshot], Optional[dict]]:
timeout = 60
start_time = time.time()
# format to human readable date time
logger.info(f"[{i}] Generating starter at {datetime.datetime.now(utc)}")
_, ref = db.collection("memes").add(
{
"state": "to-process",
"topics": topics,
"createdAt": firestore.SERVER_TIMESTAMP,
"disabled": True,
"tweet": False,
"shard": 0, # TODO: math.floor(math.random() * 1),
"fixGrammar": fix_grammar,
"parallelCompletions": parallel_completions,
"personas": personas if personas else [],
"profanityThreshold": profanity_threshold,
}
)
# generate in parallel for "limit"
with Pool(processes=limit) as pool:

# poll until it's in state "processed" or "error", timeout after 1 minute
while True:
prompt_doc = db.collection("memes").document(ref.id).get()
data = prompt_doc.to_dict()
if data.get("state") == "processed" and data.get("content", None):
if translated and not data.get("translated", None):
continue
new_history.append(
{
"id": ref.id,
"createdAt": datetime.datetime.now(utc),
}
responses = pool.starmap(
_generate,
[
(
i,
api_key_doc.id,
logger,
topics,
fix_grammar,
parallel_completions,
personas,
profanity_threshold,
translated,
)
logger.info(f"[{i}] Generated starter in {time.time() - start_time}s")
return prompt_doc, None
if data.get("state") == "error":
logger.error(
f"Failed to request starter for {api_key_doc.id}", exc_info=1
)
error = data.get("error", "unknown error")
if error == "no-topics":
user_message = choice(UNIMPLEMENTED_TOPICS_MESSAGES)
elif error == "profane":
user_message = choice(PROFANITY_MESSAGES)
user_message = user_message.replace(
"[TOPICS]", f"\"{','.join(topics)}\""
)
else:
user_message = choice(FAILING_MESSAGES)
capture_exception(Exception(str(error)))
return None, {
"message": error,
"code": 500,
"status": "ERROR",
"user_message": "error while generating conversation starter",
}
if time.time() - start_time > (
# increase timeout in case of post processing stuff
timeout * 2
if fix_grammar or translated
else timeout * 1
):
capture_exception(Exception("timeout"))
return None, {
"message": "timeout",
"code": 500,
"status": "ERROR",
"user_message": "error while generating conversation starter",
}
time.sleep(poll_interval)

# generate in parallel for "limit"
with ThreadPoolExecutor(limit) as executor:
responses = executor.map(generate, range(limit))
conversation_starters, errors = zip(*responses)
for i in range(limit)
],
)
# turn [(starter, history, error), ...] into starters, history, errors
conversation_starters, new_history, errors = zip(*responses)
print(conversation_starters, new_history, errors)
# if any are errors, return the first error
if any(errors):
return None, [e for e in errors if e][0]
@@ -158,7 +203,7 @@ def generate(i: int) -> Tuple[Optional[DocumentSnapshot], Optional[dict]]:
conversation_starters_history_list,
)
)
+ new_history
+ list(new_history)
)
org_doc.reference.update(
{
Expand All @@ -171,4 +216,4 @@ def generate(i: int) -> Tuple[Optional[DocumentSnapshot], Optional[dict]]:
)

# Return the conversation starters
return conversation_starters, None
return list(conversation_starters), None
2 changes: 1 addition & 1 deletion langame/functions/test_services.py
@@ -36,7 +36,7 @@ async def test_request_starter_for_service(self):
buckets = [1, 3, 5, 7]
for limit in buckets:
start_time = time.time()
conversation_starters, error = await request_starter_for_service(
conversation_starters, error = request_starter_for_service(
api_key_doc=api_key_doc,
org_doc=org_doc,
topics=["biology", "symbiosis", "love"],
2 changes: 1 addition & 1 deletion run/collection/bench.py
@@ -43,7 +43,7 @@ def request():


profiler = Profiler()
REPEAT = 2
REPEAT = 3

for i in range(REPEAT):
profiler.start()
18 changes: 11 additions & 7 deletions run/collection/logic.py
@@ -20,7 +20,8 @@
logger.addHandler(logging.StreamHandler())
logger.handlers[0].setFormatter(
logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%d-%b-%y %H:%M:%S"
"%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%d-%b-%y %H:%M:%S",
)
)
db: Client = firestore.client()
@@ -103,7 +104,7 @@ def base():
)


async def create_starter():
def create_starter():
"""
foo
"""
@@ -245,18 +246,21 @@ async def create_starter():
results = []
topics = set()
for conversation_starter in conversation_starters:
d = conversation_starter.to_dict()
if not conversation_starter.get("id"):
continue
results.append(
{
"id": conversation_starter.id,
"id": conversation_starter.get("id"),
# merge "content" (original english version) with "translated" (multi-language version)
"conversation_starter": {
"en": d.get("content", ""),
**(d.get("translated", {}) if translated else {}),
"en": conversation_starter.get("content", ""),
**(
conversation_starter.get("translated", {}) if translated else {}
),
},
}
)
for topic in d.get("topics", []):
for topic in conversation_starter.get("topics", []):
topics.add(topic)
# TODO: return ID and let argument to say "I want different CS than these IDs" (semantically)
return (
4 changes: 2 additions & 2 deletions run/collection/main.py
@@ -40,8 +40,8 @@ def before_request():


@app.route("/v1/conversation/starter", methods=["POST"])
async def path_create_starter():
return await create_starter()
def path_create_starter():
return create_starter()


@app.route(BASE, methods=["GET"])
4 changes: 2 additions & 2 deletions run/collection/service.prod.yaml
@@ -14,13 +14,13 @@ spec:
spec:
containerConcurrency: 5
containers:
image: gcr.io/langame-86ac4/collection:1.2.1
image: gcr.io/langame-86ac4/collection:1.2.2
ports:
- containerPort: 8080
name: http1
env:
- name: SENTRY_RELEASE
value: "1.2.1"
value: "1.2.2"
- name: ENVIRONMENT
value: "production"
resources:
2 changes: 1 addition & 1 deletion setup.py
@@ -20,7 +20,7 @@
name="langame",
packages=find_packages(),
include_package_data=True,
version="1.4.4",
version="1.4.5",
description="",
install_requires=[
"firebase_admin",
