Skip to content

Commit

Permalink
Use gpt-4o-mini in most places
Browse files (browse the repository at this point in the history)
  • Loading branch information
josancamon19 committed Sep 14, 2024
1 parent b417cf6 commit e80d56a
Showing 1 changed file with 17 additions and 26 deletions.
43 changes: 17 additions & 26 deletions backend/utils/llm.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -15,10 +15,9 @@
from models.trend import TrendEnum
from utils.memories.facts import get_prompt_facts

llm = ChatOpenAI(model='gpt-4o')
llm_mini = ChatOpenAI(model='gpt-4o-mini')
embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
parser = PydanticOutputParser(pydantic_object=Structured)
llm_with_parser = llm.with_structured_output(Structured)

encoding = tiktoken.encoding_for_model('gpt-4')

Expand Down Expand Up @@ -58,7 +57,7 @@ def should_discard_memory(transcript: str) -> bool:
{format_instructions}'''.replace(' ', '').strip()
])
chain = prompt | llm | parser
chain = prompt | llm_mini | parser
try:
response: DiscardMemory = chain.invoke({
'transcript': transcript.strip(),
Expand Down Expand Up @@ -87,7 +86,7 @@ def get_transcript_structure(transcript: str, started_at: datetime, language_cod
{format_instructions}'''.replace(' ', '').strip()
)])
chain = prompt | llm | parser
chain = prompt | ChatOpenAI(model='gpt-4o') | parser

response = chain.invoke({
'transcript': transcript.strip(),
Expand Down Expand Up @@ -120,18 +119,13 @@ def get_plugin_result(transcript: str, plugin: Plugin) -> str:
Make sure to be concise and clear.
'''

response = llm.invoke(prompt)
response = llm_mini.invoke(prompt)
content = response.content.replace('```json', '').replace('```', '')
if len(content) < 5:
return ''
return content


# *******************************************
# ************* POSTPROCESSING **************
# *******************************************


# **************************************
# ************* OPENGLASS **************
# **************************************
Expand All @@ -148,7 +142,7 @@ def summarize_open_glass(photos: List[MemoryPhoto]) -> Structured:
Photos Descriptions: ```{photos_str}```
'''.replace(' ', '').strip()
return llm_with_parser.invoke(prompt)
return llm_mini.with_structured_output(Structured).invoke(prompt)


# **************************************************
Expand All @@ -165,8 +159,7 @@ def summarize_experience_text(text: str) -> Structured:
Text: ```{text}```
'''.replace(' ', '').strip()
# return groq_llm_with_parser.invoke(prompt)
return llm_with_parser.invoke(prompt)
return llm_mini.with_structured_output(Structured).invoke(prompt)


def get_memory_summary(uid: str, memories: List[Memory]) -> str:
Expand All @@ -190,7 +183,7 @@ def get_memory_summary(uid: str, memories: List[Memory]) -> str:
```
""".replace(' ', '').strip()
# print(prompt)
return llm.invoke(prompt).content
return llm_mini.invoke(prompt).content


def generate_embedding(content: str) -> List[float]:
Expand Down Expand Up @@ -231,7 +224,7 @@ def initial_chat_message(uid: str, plugin: Optional[Plugin] = None) -> str:
Output your response in plain text, without markdown.
'''
prompt = prompt.replace(' ', '').strip()
return llm.invoke(prompt).content
return llm_mini.invoke(prompt).content


# *********************************************
Expand Down Expand Up @@ -259,7 +252,7 @@ def requires_context(messages: List[Message]) -> bool:
Conversation History:
{Message.get_messages_as_string(messages)}
'''
with_parser = llm.with_structured_output(RequiresContext)
with_parser = llm_mini.with_structured_output(RequiresContext)
response: RequiresContext = with_parser.invoke(prompt)
return response.value

Expand All @@ -278,7 +271,7 @@ def retrieve_context_topics(messages: List[Message]) -> List[str]:
Conversation:
{Message.get_messages_as_string(messages)}
'''.replace(' ', '').strip()
with_parser = llm.with_structured_output(TopicsContext)
with_parser = llm_mini.with_structured_output(TopicsContext)
response: TopicsContext = with_parser.invoke(prompt)
topics = list(map(lambda x: str(x.value).capitalize(), response.topics))
return topics
Expand All @@ -299,7 +292,7 @@ def retrieve_context_dates(messages: List[Message]) -> List[datetime]:
Conversation:
{Message.get_messages_as_string(messages)}
'''.replace(' ', '').strip()
with_parser = llm.with_structured_output(DatesContext)
with_parser = llm_mini.with_structured_output(DatesContext)
response: DatesContext = with_parser.invoke(prompt)
return response.dates_range

Expand All @@ -309,8 +302,6 @@ class SummaryOutput(BaseModel):


def chunk_extraction(segments: List[TranscriptSegment], topics: List[str]) -> str:
_chat = ChatOpenAI(model="gpt-4o-mini")

content = TranscriptSegment.segments_as_string(segments)
prompt = f'''
You are an experienced detective, your task is to extract the key points of the conversation related to the topics you were provided.
Expand All @@ -324,7 +315,7 @@ def chunk_extraction(segments: List[TranscriptSegment], topics: List[str]) -> st
Topics: {topics}
'''
with_parser = _chat.with_structured_output(SummaryOutput)
with_parser = llm_mini.with_structured_output(SummaryOutput)
response: SummaryOutput = with_parser.invoke(prompt)
return response.summary

Expand Down Expand Up @@ -359,7 +350,7 @@ def qa_rag(uid: str, context: str, messages: List[Message], plugin: Optional[Plu
Answer:
""".replace(' ', '').strip()
print(prompt)
return llm.invoke(prompt).content
return llm_mini.invoke(prompt).content


# **************************************************
Expand All @@ -382,7 +373,7 @@ def retrieve_memory_context_params(memory: Memory) -> List[str]:
'''.replace(' ', '').strip()

try:
with_parser = llm.with_structured_output(TopicsContext)
with_parser = llm_mini.with_structured_output(TopicsContext)
response: TopicsContext = with_parser.invoke(prompt)
return response.topics
except Exception as e:
Expand Down Expand Up @@ -415,7 +406,7 @@ def obtain_emotional_message(uid: str, memory: Memory, context: str, emotion: st
{context}
```
""".replace(' ', '').strip()
return llm.invoke(prompt).content
return llm_mini.invoke(prompt).content


# **********************************
Expand Down Expand Up @@ -464,7 +455,7 @@ def new_facts_extractor(uid: str, segments: List[TranscriptSegment]) -> List[Fac
'''.replace(' ', '').strip()

try:
with_parser = llm.with_structured_output(Facts)
with_parser = llm_mini.with_structured_output(Facts)
response: Facts = with_parser.invoke(prompt)
# for fact in response:
# fact.content = fact.content.replace(user_name, '').replace('The User', '').replace('User', '').strip()
Expand Down Expand Up @@ -511,7 +502,7 @@ def trends_extractor(memory: Memory) -> List[Item]:
{transcript}
'''.replace(' ', '').strip()
try:
with_parser = llm.with_structured_output(ExpectedOutput)
with_parser = llm_mini.with_structured_output(ExpectedOutput)
response: ExpectedOutput = with_parser.invoke(prompt)
return response.items
except Exception as e:
Expand Down

0 comments on commit e80d56a

Please sign in to comment.