From d49d58e2c269609cab22aea96474c45cf53f033c Mon Sep 17 00:00:00 2001
From: Joan Cabezas
Date: Mon, 11 Nov 2024 12:18:49 -0800
Subject: [PATCH] graph logic handles omi questions

---
 backend/utils/retrieval/graph.py | 168 +++++++++++++++++++------------
 backend/utils/retrieval/rag.py   |  59 +----------
 2 files changed, 105 insertions(+), 122 deletions(-)

diff --git a/backend/utils/retrieval/graph.py b/backend/utils/retrieval/graph.py
index 59f4c4a9e..b23a762c4 100644
--- a/backend/utils/retrieval/graph.py
+++ b/backend/utils/retrieval/graph.py
@@ -1,6 +1,4 @@
 import datetime
-import os
-import time
 import uuid
 from typing import List, Optional, Tuple
 
@@ -10,21 +8,29 @@
 from langgraph.graph import START, StateGraph
 from typing_extensions import TypedDict, Literal
 
-from utils.other.endpoints import timeit
-
-# os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '../../' + os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
-
-import database.chat as chat_db
 import database.memories as memories_db
 from database.redis_db import get_filter_category_items
 from database.vector_db import query_vectors_by_metadata
 from models.chat import Message
 from models.memory import Memory
 from models.plugin import Plugin
-from utils.llm import requires_context, answer_simple_message, retrieve_context_dates, qa_rag, \
-    select_structured_filters, extract_question_from_conversation, generate_embedding
+from utils.llm import (
+    answer_omi_question,
+    requires_context,
+    answer_simple_message,
+    retrieve_context_dates,
+    qa_rag,
+    retrieve_is_an_omi_question,
+    select_structured_filters,
+    extract_question_from_conversation,
+    generate_embedding,
+)
+from utils.other.endpoints import timeit
+from utils.plugins import get_github_docs_content
+
+# os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '../../' + os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
 
-model = ChatOpenAI(model='gpt-4o-mini')
+model = ChatOpenAI(model="gpt-4o-mini")
 
 
 class StructuredFilters(TypedDict):
@@ -53,82 +59,101 @@ class GraphState(TypedDict):
     answer: Optional[str]
 
 
-def determine_conversation_type(s: GraphState) -> Literal[
-    "no_context_conversation", "context_dependent_conversation"]:
-    requires = requires_context(s.get('messages', []))
+def determine_conversation_type(
+    s: GraphState,
+) -> Literal["no_context_conversation", "context_dependent_conversation", "no_context_omi_question"]:
+    is_omi_question = retrieve_is_an_omi_question(s.get("messages", []))
+    if is_omi_question:
+        return "no_context_omi_question"
+
+    requires = requires_context(s.get("messages", []))
     if requires:
-        return 'context_dependent_conversation'
-    return 'no_context_conversation'
+        return "context_dependent_conversation"
+    return "no_context_conversation"
 
 
 def no_context_conversation(state: GraphState):
-    print('no_context_conversation node')
-    return {'answer': answer_simple_message(state.get('uid'), state.get('messages'))}
+    print("no_context_conversation node")
+    return {"answer": answer_simple_message(state.get("uid"), state.get("messages"))}
+
+
+def no_context_omi_question(state: GraphState):
+    print("no_context_omi_question node")
+    context: dict = get_github_docs_content()
+    context_str = "Documentation:\n\n" + "\n\n".join([f"{k}:\n {v}" for k, v in context.items()])
+    answer = answer_omi_question(state.get("messages", []), context_str)
+    return {"answer": answer}
 
 
 def context_dependent_conversation(state: GraphState):
-    question = extract_question_from_conversation(state.get('messages', []))
-    print('context_dependent_conversation parsed question:', question)
-    return {'parsed_question': question}
+    question = extract_question_from_conversation(state.get("messages", []))
+    print("context_dependent_conversation parsed question:", question)
+    return {"parsed_question": question}
 
 
 # !! include a question extractor? node?
 
 
 def retrieve_topics_filters(state: GraphState):
-    print('retrieve_topics_filters')
+    print("retrieve_topics_filters")
     filters = {
-        'people': get_filter_category_items(state.get('uid'), 'people'),
-        'topics': get_filter_category_items(state.get('uid'), 'topics'),
-        'entities': get_filter_category_items(state.get('uid'), 'entities'),
+        "people": get_filter_category_items(state.get("uid"), "people"),
+        "topics": get_filter_category_items(state.get("uid"), "topics"),
+        "entities": get_filter_category_items(state.get("uid"), "entities"),
         # 'dates': get_filter_category_items(state.get('uid'), 'dates'),
     }
-    result = select_structured_filters(state.get('parsed_question', ''), filters)
-    return {'filters': {
-        'topics': result.get('topics', []),
-        'people': result.get('people', []),
-        'entities': result.get('entities', []),
-        # 'dates': result.get('dates', []),
-    }}
+    result = select_structured_filters(state.get("parsed_question", ""), filters)
+    return {
+        "filters": {
+            "topics": result.get("topics", []),
+            "people": result.get("people", []),
+            "entities": result.get("entities", []),
+            # 'dates': result.get('dates', []),
+        }
+    }
 
 
 def retrieve_date_filters(state: GraphState):
-    dates_range = retrieve_context_dates(state.get('messages', []))
+    dates_range = retrieve_context_dates(state.get("messages", []))
     if dates_range and len(dates_range) == 2:
-        return {'date_filters': {'start': dates_range[0], 'end': dates_range[1]}}
-    return {'date_filters': {}}
+        return {"date_filters": {"start": dates_range[0], "end": dates_range[1]}}
+    return {"date_filters": {}}
 
 
 def query_vectors(state: GraphState):
-    print('query_vectors')
-    date_filters = state.get('date_filters')
-    uid = state.get('uid')
-    vector = generate_embedding(state.get('parsed_question', '')) if state.get('parsed_question') else [0] * 3072
-    print('query_vectors vector:', vector[:5])
+    print("query_vectors")
+    date_filters = state.get("date_filters")
+    uid = state.get("uid")
+    vector = (
+        generate_embedding(state.get("parsed_question", ""))
+        if state.get("parsed_question")
+        else [0] * 3072
+    )
+    print("query_vectors vector:", vector[:5])
     memories_id = query_vectors_by_metadata(
         uid,
         vector,
-        dates_filter=[date_filters.get('start'), date_filters.get('end')],
-        people=state.get('filters', {}).get('people', []),
-        topics=state.get('filters', {}).get('topics', []),
-        entities=state.get('filters', {}).get('entities', []),
-        dates=state.get('filters', {}).get('dates', []),
+        dates_filter=[date_filters.get("start"), date_filters.get("end")],
+        people=state.get("filters", {}).get("people", []),
+        topics=state.get("filters", {}).get("topics", []),
+        entities=state.get("filters", {}).get("entities", []),
+        dates=state.get("filters", {}).get("dates", []),
     )
     memories = memories_db.get_memories_by_id(uid, memories_id)
-    return {'memories_found': memories}
+    return {"memories_found": memories}
 
 
 def qa_handler(state: GraphState):
-    uid = state.get('uid')
-    memories = state.get('memories_found', [])
+    uid = state.get("uid")
+    memories = state.get("memories_found", [])
     response: str = qa_rag(
         uid,
-        state.get('parsed_question'),
+        state.get("parsed_question"),
         Memory.memories_to_string(memories, True),
-        state.get('plugin_selected')
+        state.get("plugin_selected"),
     )
-    return {'answer': response}
+    return {"answer": response}
{"answer": response} workflow = StateGraph(GraphState) @@ -139,26 +164,27 @@ def qa_handler(state: GraphState): ) workflow.add_node("no_context_conversation", no_context_conversation) +workflow.add_node("no_context_omi_question", no_context_omi_question) workflow.add_node("context_dependent_conversation", context_dependent_conversation) workflow.add_edge("no_context_conversation", END) - +workflow.add_edge("no_context_omi_question", END) workflow.add_edge("context_dependent_conversation", "retrieve_topics_filters") workflow.add_edge("context_dependent_conversation", "retrieve_date_filters") workflow.add_node("retrieve_topics_filters", retrieve_topics_filters) workflow.add_node("retrieve_date_filters", retrieve_date_filters) -workflow.add_edge('retrieve_topics_filters', 'query_vectors') -workflow.add_edge('retrieve_date_filters', 'query_vectors') +workflow.add_edge("retrieve_topics_filters", "query_vectors") +workflow.add_edge("retrieve_date_filters", "query_vectors") -workflow.add_node('query_vectors', query_vectors) +workflow.add_node("query_vectors", query_vectors) -workflow.add_edge('query_vectors', 'qa_handler') +workflow.add_edge("query_vectors", "qa_handler") -workflow.add_node('qa_handler', qa_handler) +workflow.add_node("qa_handler", qa_handler) -workflow.add_edge('qa_handler', END) +workflow.add_edge("qa_handler", END) checkpointer = MemorySaver() graph = workflow.compile(checkpointer=checkpointer) @@ -166,26 +192,38 @@ def qa_handler(state: GraphState): @timeit def execute_graph_chat(uid: str, messages: List[Message]) -> Tuple[str, List[Memory]]: - result = graph.invoke({'uid': uid, 'messages': messages}, {"configurable": {"thread_id": str(uuid.uuid4())}}) - return result.get('answer'), result.get('memories_found', []) + result = graph.invoke( + {"uid": uid, "messages": messages}, + {"configurable": {"thread_id": str(uuid.uuid4())}}, + ) + return result.get("answer"), result.get("memories_found", []) def _pretty_print_conversation(messages: List[Message]): for msg in messages: - print(f'{msg.sender}: {msg.text}') + print(f"{msg.sender}: {msg.text}") -if __name__ == '__main__': - # uid = 'ccQJWj5mwhSY1dwjS1FPFBfKIXe2' +if __name__ == "__main__": + # graph.get_graph().draw_png("workflow.png") + uid = "ccQJWj5mwhSY1dwjS1FPFBfKIXe2" # def _send_message(text: str, sender: str = 'human'): # message = Message( # id=str(uuid.uuid4()), text=text, created_at=datetime.datetime.now(datetime.timezone.utc), sender=sender, # type='text' # ) # chat_db.add_message(uid, message.dict()) - - - graph.get_graph().draw_png('workflow.png') + messages = [ + Message( + id=str(uuid.uuid4()), + text="How can I build a plugin?", + created_at=datetime.datetime.now(datetime.timezone.utc), + sender="human", + type="text", + ) + ] + result = execute_graph_chat(uid, messages) + print("result:", print(result)) # messages = list(reversed([Message(**msg) for msg in chat_db.get_messages(uid, limit=10)])) # _pretty_print_conversation(messages) # # print(messages[-1].text) diff --git a/backend/utils/retrieval/rag.py b/backend/utils/retrieval/rag.py index e9597f973..c4d68143e 100644 --- a/backend/utils/retrieval/rag.py +++ b/backend/utils/retrieval/rag.py @@ -2,13 +2,11 @@ from collections import Counter, defaultdict from typing import List, Tuple -from database.memories import filter_memories_by_date, get_memories_by_id +from database.memories import get_memories_by_id from database.vector_db import query_vectors -from models.chat import Message from models.memory import Memory from models.transcript_segment import 
-from utils.llm import requires_context, retrieve_context_topics, retrieve_context_dates, chunk_extraction, \
-    num_tokens_from_string, retrieve_memory_context_params
+from utils.llm import chunk_extraction, num_tokens_from_string, retrieve_memory_context_params
 
 
 def retrieve_for_topic(uid: str, topic: str, start_timestamp, end_timestamp, k: int, memories_id) -> List[str]:
@@ -56,59 +54,6 @@ def get_better_memory_chunk(memory: Memory, topics: List[str], context_data: dic
     context_data[memory.id] = chunk
 
 
-def retrieve_rag_context(
-    uid: str, prev_messages: List[Message], return_context_params: bool = False
-) -> Tuple[str, List[Memory]]:
-    requires = requires_context(prev_messages)
-
-    if not requires:
-        return '', []
-
-    topics = retrieve_context_topics(prev_messages)
-    dates_range = retrieve_context_dates(prev_messages)
-    print('retrieve_rag_context', topics, dates_range)
-    if not topics and len(dates_range) != 2:
-        return '', []
-
-    if len(topics) > 5:
-        topics = topics[:5]
-
-    memories_id_to_topics = {}
-    memories = None
-    if topics:
-        memories_id_to_topics, memories = retrieve_memories_for_topics(uid, topics, dates_range)
-        id_counter = Counter(memory['id'] for memory in memories)
-        memories = sorted(memories, key=lambda x: id_counter[x['id']], reverse=True)
-
-    if not memories and len(dates_range) == 2:
-        memories_id_to_topics = {}
-        memories = filter_memories_by_date(uid, dates_range[0], dates_range[1])
-
-    memories = [Memory(**memory) for memory in memories]
-    if len(memories) > 10:
-        memories = memories[:10]
-
-    # not performing as expected
-    if memories_id_to_topics:
-        context_data = {}
-        threads = []
-        for memory in memories:
-            m_topics = memories_id_to_topics.get(memory.id, [])
-            t = threading.Thread(target=get_better_memory_chunk, args=(memory, m_topics, context_data))
-            threads.append(t)
-        [t.start() for t in threads]
-        [t.join() for t in threads]
-        memories = list(filter(lambda x: x.id in context_data, memories))
-        context_str = '\n\n---------------------\n\n'.join(context_data.values()).strip()
-    else:
-        context_str = Memory.memories_to_string(memories)
-
-    if return_context_params:
-        return context_str, (memories if context_str else []), topics, dates_range
-
-    return context_str, (memories if context_str else [])
-
-
 def retrieve_rag_memory_context(uid: str, memory: Memory) -> Tuple[str, List[Memory]]:
     topics = retrieve_memory_context_params(memory)
     print('retrieve_memory_rag_context', topics)
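
A note on the wiring this patch relies on: the hunk context hides the graph's entry
edge, which presumably is a workflow.add_conditional_edges(START, determine_conversation_type, ...)
call whose router return value is used as the name of the next node. That is why the
new "no_context_omi_question" literal in the router, the add_node name, and the edge
to END must all agree exactly. Below is a minimal, self-contained sketch of that
LangGraph conditional-routing pattern, not code from this repo; the names route,
docs_question, and plain_chat are illustrative stand-ins.

from typing import Literal

from langgraph.graph import END, START, StateGraph
from typing_extensions import TypedDict


class State(TypedDict):
    question: str
    answer: str


def route(s: State) -> Literal["docs_question", "plain_chat"]:
    # Mirrors determine_conversation_type: the first matching branch wins,
    # and the returned string doubles as the name of the next node.
    return "docs_question" if "omi" in s["question"].lower() else "plain_chat"


def docs_question(s: State) -> dict:
    # Stand-in for no_context_omi_question: answer grounded in docs.
    return {"answer": f"answered from docs: {s['question']}"}


def plain_chat(s: State) -> dict:
    # Stand-in for no_context_conversation: answer directly, no retrieval.
    return {"answer": f"answered directly: {s['question']}"}


g = StateGraph(State)
g.add_node("docs_question", docs_question)
g.add_node("plain_chat", plain_chat)
g.add_conditional_edges(START, route)  # router's return value selects the node
g.add_edge("docs_question", END)
g.add_edge("plain_chat", END)

print(g.compile().invoke({"question": "How can I build an omi plugin?"})["answer"])

The patch's three-way split follows the same recipe: add the node, list its name in
the router's Literal return type, and terminate it with an edge to END.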