Commit

graph logic handles omi questions
josancamon19 committed Nov 11, 2024
1 parent 30a114f commit d49d58e
Showing 2 changed files with 105 additions and 122 deletions.
168 changes: 103 additions & 65 deletions backend/utils/retrieval/graph.py
@@ -1,6 +1,4 @@
import datetime
import os
import time
import uuid
from typing import List, Optional, Tuple

@@ -10,21 +8,29 @@
from langgraph.graph import START, StateGraph
from typing_extensions import TypedDict, Literal

from utils.other.endpoints import timeit

# os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '../../' + os.getenv('GOOGLE_APPLICATION_CREDENTIALS')

import database.chat as chat_db
import database.memories as memories_db
from database.redis_db import get_filter_category_items
from database.vector_db import query_vectors_by_metadata
from models.chat import Message
from models.memory import Memory
from models.plugin import Plugin
from utils.llm import requires_context, answer_simple_message, retrieve_context_dates, qa_rag, \
select_structured_filters, extract_question_from_conversation, generate_embedding
from utils.llm import (
answer_omi_question,
requires_context,
answer_simple_message,
retrieve_context_dates,
qa_rag,
retrieve_is_an_omi_question,
select_structured_filters,
extract_question_from_conversation,
generate_embedding,
)
from utils.other.endpoints import timeit
from utils.plugins import get_github_docs_content

# os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '../../' + os.getenv('GOOGLE_APPLICATION_CREDENTIALS')

model = ChatOpenAI(model='gpt-4o-mini')
model = ChatOpenAI(model="gpt-4o-mini")


class StructuredFilters(TypedDict):
@@ -53,82 +59,101 @@ class GraphState(TypedDict):
answer: Optional[str]


def determine_conversation_type(s: GraphState) -> Literal[
"no_context_conversation", "context_dependent_conversation"]:
requires = requires_context(s.get('messages', []))
def determine_conversation_type(
s: GraphState,
) -> Literal["no_context_conversation", "context_dependent_conversation", "no_context_omi_question"]:
is_omi_question = retrieve_is_an_omi_question(s.get("messages", []))
if is_omi_question:
return "no_context_omi_question"

requires = requires_context(s.get("messages", []))

if requires:
return 'context_dependent_conversation'
return 'no_context_conversation'
return "context_dependent_conversation"
return "no_context_conversation"


def no_context_conversation(state: GraphState):
print('no_context_conversation node')
return {'answer': answer_simple_message(state.get('uid'), state.get('messages'))}
print("no_context_conversation node")
return {"answer": answer_simple_message(state.get("uid"), state.get("messages"))}


def no_context_omi_question(state: GraphState):
print("no_context_omi_question node")
    context: dict = get_github_docs_content()
    context_str = "Documentation:\n\n" + "\n\n".join([f"{k}:\n {v}" for k, v in context.items()])
    answer = answer_omi_question(state.get("messages", []), context_str)
    return {"answer": answer}
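
get_github_docs_content is imported from utils.plugins but not shown in this diff. A plausible sketch, assuming it pulls markdown pages from the Omi docs via the GitHub contents API; the repo name, docs path, and GITHUB_TOKEN env var are assumptions for illustration:

import os
import requests

def get_github_docs_content_sketch(repo: str = "BasedHardware/omi", path: str = "docs") -> dict:
    """Return {file_path: markdown_text} for docs in a GitHub repo (illustrative only)."""
    headers = {}
    if token := os.getenv("GITHUB_TOKEN"):  # optional auth, assumed env var name
        headers["Authorization"] = f"Bearer {token}"
    out = {}
    listing = requests.get(f"https://api.github.com/repos/{repo}/contents/{path}", headers=headers).json()
    for item in listing:
        if item["type"] == "file" and item["name"].endswith(".md"):
            # Fetch the raw markdown for each docs file.
            out[item["path"]] = requests.get(item["download_url"], headers=headers).text
        elif item["type"] == "dir":
            # Recurse into subdirectories of the docs tree.
            out.update(get_github_docs_content_sketch(repo, item["path"]))
    return out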


def context_dependent_conversation(state: GraphState):
question = extract_question_from_conversation(state.get('messages', []))
print('context_dependent_conversation parsed question:', question)
return {'parsed_question': question}
question = extract_question_from_conversation(state.get("messages", []))
print("context_dependent_conversation parsed question:", question)
return {"parsed_question": question}


# !! include a question extractor? node?


def retrieve_topics_filters(state: GraphState):
print('retrieve_topics_filters')
print("retrieve_topics_filters")
filters = {
'people': get_filter_category_items(state.get('uid'), 'people'),
'topics': get_filter_category_items(state.get('uid'), 'topics'),
'entities': get_filter_category_items(state.get('uid'), 'entities'),
"people": get_filter_category_items(state.get("uid"), "people"),
"topics": get_filter_category_items(state.get("uid"), "topics"),
"entities": get_filter_category_items(state.get("uid"), "entities"),
# 'dates': get_filter_category_items(state.get('uid'), 'dates'),
}
result = select_structured_filters(state.get('parsed_question', ''), filters)
return {'filters': {
'topics': result.get('topics', []),
'people': result.get('people', []),
'entities': result.get('entities', []),
# 'dates': result.get('dates', []),
}}
result = select_structured_filters(state.get("parsed_question", ""), filters)
return {
"filters": {
"topics": result.get("topics", []),
"people": result.get("people", []),
"entities": result.get("entities", []),
# 'dates': result.get('dates', []),
}
}
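
select_structured_filters (from utils/llm.py, not shown in this diff) is expected to pick the subset of the Redis-backed candidates that is relevant to the parsed question. A worked example with invented values:

# Hypothetical candidates pulled from Redis for this uid:
filters = {
    "people": ["Alice", "Bob"],
    "topics": ["fitness", "startup fundraising"],
    "entities": ["Omi", "San Francisco"],
}
# For parsed_question = "What did Alice say about fundraising last week?",
# select_structured_filters(parsed_question, filters) would plausibly return:
# {"topics": ["startup fundraising"], "people": ["Alice"], "entities": []}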


def retrieve_date_filters(state: GraphState):
dates_range = retrieve_context_dates(state.get('messages', []))
dates_range = retrieve_context_dates(state.get("messages", []))
if dates_range and len(dates_range) == 2:
return {'date_filters': {'start': dates_range[0], 'end': dates_range[1]}}
return {'date_filters': {}}
return {"date_filters": {"start": dates_range[0], "end": dates_range[1]}}
return {"date_filters": {}}


def query_vectors(state: GraphState):
print('query_vectors')
date_filters = state.get('date_filters')
uid = state.get('uid')
vector = generate_embedding(state.get('parsed_question', '')) if state.get('parsed_question') else [0] * 3072
print('query_vectors vector:', vector[:5])
print("query_vectors")
date_filters = state.get("date_filters")
uid = state.get("uid")
vector = (
generate_embedding(state.get("parsed_question", ""))
if state.get("parsed_question")
else [0] * 3072
)
print("query_vectors vector:", vector[:5])
memories_id = query_vectors_by_metadata(
uid,
vector,
dates_filter=[date_filters.get('start'), date_filters.get('end')],
people=state.get('filters', {}).get('people', []),
topics=state.get('filters', {}).get('topics', []),
entities=state.get('filters', {}).get('entities', []),
dates=state.get('filters', {}).get('dates', []),
dates_filter=[date_filters.get("start"), date_filters.get("end")],
people=state.get("filters", {}).get("people", []),
topics=state.get("filters", {}).get("topics", []),
entities=state.get("filters", {}).get("entities", []),
dates=state.get("filters", {}).get("dates", []),
)
memories = memories_db.get_memories_by_id(uid, memories_id)
return {'memories_found': memories}
return {"memories_found": memories}


def qa_handler(state: GraphState):
uid = state.get('uid')
memories = state.get('memories_found', [])
uid = state.get("uid")
memories = state.get("memories_found", [])
response: str = qa_rag(
uid,
state.get('parsed_question'),
state.get("parsed_question"),
Memory.memories_to_string(memories, True),
state.get('plugin_selected')
state.get("plugin_selected"),
)
return {'answer': response}
return {"answer": response}


workflow = StateGraph(GraphState)
@@ -139,53 +164,66 @@ def qa_handler(state: GraphState):
)
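
GitHub collapses the lines that wire START into the router here. Given the node names and the Literal return type of determine_conversation_type, the hidden call is presumably along these lines (a sketch, not the verbatim diff):

workflow.add_conditional_edges(
    START,
    # The router returns one of "no_context_conversation",
    # "no_context_omi_question", or "context_dependent_conversation";
    # LangGraph routes to the node with that name.
    determine_conversation_type,
)

Note also the fan-out below: context_dependent_conversation feeds both retrieve_topics_filters and retrieve_date_filters, which LangGraph runs as parallel branches before they join at query_vectors.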

workflow.add_node("no_context_conversation", no_context_conversation)
workflow.add_node("no_context_omi_question", no_context_omi_question)
workflow.add_node("context_dependent_conversation", context_dependent_conversation)

workflow.add_edge("no_context_conversation", END)

workflow.add_edge("no_context_omi_question", END)
workflow.add_edge("context_dependent_conversation", "retrieve_topics_filters")
workflow.add_edge("context_dependent_conversation", "retrieve_date_filters")

workflow.add_node("retrieve_topics_filters", retrieve_topics_filters)
workflow.add_node("retrieve_date_filters", retrieve_date_filters)

workflow.add_edge('retrieve_topics_filters', 'query_vectors')
workflow.add_edge('retrieve_date_filters', 'query_vectors')
workflow.add_edge("retrieve_topics_filters", "query_vectors")
workflow.add_edge("retrieve_date_filters", "query_vectors")

workflow.add_node('query_vectors', query_vectors)
workflow.add_node("query_vectors", query_vectors)

workflow.add_edge('query_vectors', 'qa_handler')
workflow.add_edge("query_vectors", "qa_handler")

workflow.add_node('qa_handler', qa_handler)
workflow.add_node("qa_handler", qa_handler)

workflow.add_edge('qa_handler', END)
workflow.add_edge("qa_handler", END)

checkpointer = MemorySaver()
graph = workflow.compile(checkpointer=checkpointer)


@timeit
def execute_graph_chat(uid: str, messages: List[Message]) -> Tuple[str, List[Memory]]:
result = graph.invoke({'uid': uid, 'messages': messages}, {"configurable": {"thread_id": str(uuid.uuid4())}})
return result.get('answer'), result.get('memories_found', [])
result = graph.invoke(
{"uid": uid, "messages": messages},
{"configurable": {"thread_id": str(uuid.uuid4())}},
)
return result.get("answer"), result.get("memories_found", [])


def _pretty_print_conversation(messages: List[Message]):
for msg in messages:
print(f'{msg.sender}: {msg.text}')
print(f"{msg.sender}: {msg.text}")


if __name__ == '__main__':
# uid = 'ccQJWj5mwhSY1dwjS1FPFBfKIXe2'
if __name__ == "__main__":
# graph.get_graph().draw_png("workflow.png")
uid = "ccQJWj5mwhSY1dwjS1FPFBfKIXe2"
# def _send_message(text: str, sender: str = 'human'):
# message = Message(
# id=str(uuid.uuid4()), text=text, created_at=datetime.datetime.now(datetime.timezone.utc), sender=sender,
# type='text'
# )
# chat_db.add_message(uid, message.dict())


graph.get_graph().draw_png('workflow.png')
messages = [
Message(
id=str(uuid.uuid4()),
text="How can I build a plugin?",
created_at=datetime.datetime.now(datetime.timezone.utc),
sender="human",
type="text",
)
]
result = execute_graph_chat(uid, messages)
print("result:", print(result))
# messages = list(reversed([Message(**msg) for msg in chat_db.get_messages(uid, limit=10)]))
# _pretty_print_conversation(messages)
# # print(messages[-1].text)
59 changes: 2 additions & 57 deletions backend/utils/retrieval/rag.py
@@ -2,13 +2,11 @@
from collections import Counter, defaultdict
from typing import List, Tuple

from database.memories import filter_memories_by_date, get_memories_by_id
from database.memories import get_memories_by_id
from database.vector_db import query_vectors
from models.chat import Message
from models.memory import Memory
from models.transcript_segment import TranscriptSegment
from utils.llm import requires_context, retrieve_context_topics, retrieve_context_dates, chunk_extraction, \
num_tokens_from_string, retrieve_memory_context_params
from utils.llm import chunk_extraction, num_tokens_from_string, retrieve_memory_context_params
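
The deletions below remove retrieve_rag_context, the old linear pipeline (context check, topic and date extraction, retrieval, chunking); that flow now presumably lives in the LangGraph workflow in graph.py above.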


def retrieve_for_topic(uid: str, topic: str, start_timestamp, end_timestamp, k: int, memories_id) -> List[str]:
@@ -56,59 +54,6 @@ def get_better_memory_chunk(memory: Memory, topics: List[str], context_data: dict):
context_data[memory.id] = chunk


def retrieve_rag_context(
uid: str, prev_messages: List[Message], return_context_params: bool = False
) -> Tuple[str, List[Memory]]:
requires = requires_context(prev_messages)

if not requires:
return '', []

topics = retrieve_context_topics(prev_messages)
dates_range = retrieve_context_dates(prev_messages)
print('retrieve_rag_context', topics, dates_range)
if not topics and len(dates_range) != 2:
return '', []

if len(topics) > 5:
topics = topics[:5]

memories_id_to_topics = {}
memories = None
if topics:
memories_id_to_topics, memories = retrieve_memories_for_topics(uid, topics, dates_range)
id_counter = Counter(memory['id'] for memory in memories)
memories = sorted(memories, key=lambda x: id_counter[x['id']], reverse=True)

if not memories and len(dates_range) == 2:
memories_id_to_topics = {}
memories = filter_memories_by_date(uid, dates_range[0], dates_range[1])

memories = [Memory(**memory) for memory in memories]
if len(memories) > 10:
memories = memories[:10]

# not performing as expected
if memories_id_to_topics:
context_data = {}
threads = []
for memory in memories:
m_topics = memories_id_to_topics.get(memory.id, [])
t = threading.Thread(target=get_better_memory_chunk, args=(memory, m_topics, context_data))
threads.append(t)
[t.start() for t in threads]
[t.join() for t in threads]
memories = list(filter(lambda x: x.id in context_data, memories))
context_str = '\n\n---------------------\n\n'.join(context_data.values()).strip()
else:
context_str = Memory.memories_to_string(memories)

if return_context_params:
return context_str, (memories if context_str else []), topics, dates_range

return context_str, (memories if context_str else [])


def retrieve_rag_memory_context(uid: str, memory: Memory) -> Tuple[str, List[Memory]]:
topics = retrieve_memory_context_params(memory)
print('retrieve_memory_rag_context', topics)
