Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support memory context for proactive notification apps #1333

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions backend/utils/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,10 +861,10 @@ def provide_advice_message(uid: str, segments: List[TranscriptSegment], context:


# **************************************************
# ************* MENTOR PLUGIN **************
# ************* PROACTIVE NOTIFICATION PLUGIN **************
# **************************************************

def get_metoring_message(uid: str, plugin_prompt: str, params: [str]) -> str:
def get_proactive_message(uid: str, plugin_prompt: str, params: [str], context: str) -> str:
user_name, facts_str = get_prompt_facts(uid)

prompt = plugin_prompt
Expand All @@ -875,6 +875,10 @@ def get_metoring_message(uid: str, plugin_prompt: str, params: [str]) -> str:
if param == "user_facts":
prompt = prompt.replace("{{user_facts}}", facts_str)
continue
if param == "user_context":
prompt = prompt.replace("{{user_context}}", context if context else "")
continue
prompt = prompt.replace(' ', '').strip()
#print(prompt)

return llm_mini.invoke(prompt).content
84 changes: 70 additions & 14 deletions backend/utils/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
from models.plugin import Plugin, UsageHistoryType
from utils.notifications import send_notification
from utils.other.endpoints import timeit
from utils.llm import get_metoring_message
from utils.llm import (
generate_embedding,
get_proactive_message
)
from database.vector_db import query_vectors_by_metadata
import database.memories as memories_db


def get_github_docs_content(repo="BasedHardware/omi", path="docs/docs"):
Expand Down Expand Up @@ -219,6 +224,65 @@ async def trigger_realtime_integrations(uid: str, segments: list[dict]):
_trigger_realtime_integrations(uid, token, segments)


# proactive notification
def _retrieve_contextual_memories(uid: str, user_context):
    """Look up the user's memories relevant to a plugin-supplied context.

    Embeds the context's free-text question (when present) and queries the
    vector store, narrowed by any people/topics/entities/dates filters the
    plugin provided. Returns the matching memory records from the database.
    """
    question = user_context.get('question')
    # Without a question there is nothing to embed; fall back to a zero vector
    # so the metadata filters alone drive the search (3072 = embedding size).
    query_vector = generate_embedding(question) if question else [0] * 3072
    print("query_vectors vector:", query_vector[:5])

    date_range = {}  # date-range filtering is not supported yet
    meta_filters = user_context.get('filters', {})
    matched_ids = query_vectors_by_metadata(
        uid,
        query_vector,
        dates_filter=[date_range.get("start"), date_range.get("end")],
        people=meta_filters.get("people", []),
        topics=meta_filters.get("topics", []),
        entities=meta_filters.get("entities", []),
        dates=meta_filters.get("dates", []),
    )
    return memories_db.get_memories_by_id(uid, matched_ids)


def _process_proactive_notification(uid: str, token: str, plugin: Plugin, data):
    """Generate and push a proactive notification for one plugin.

    ``data`` is the plugin's ``notification`` payload ({prompt, params, context}).
    Returns the message that was delivered, or None when nothing was sent.
    """
    if not plugin.has_capability("proactive_notification") or not data:
        print(f"Plugins {plugin.id} is not proactive_notification or data invalid", uid)
        return

    prompt_char_limit = 8000  # hard cap on plugin-supplied prompt templates
    message_min_chars = 5     # anything shorter is treated as "no message"

    prompt = data.get('prompt', '')
    if len(prompt) > prompt_char_limit:
        # Surface the oversized-prompt problem back through the notification channel.
        send_plugin_notification(token, plugin.name, plugin.id, f"Prompt too long: {len(prompt)}/{prompt_char_limit} characters. Please shorten.")
        print(f"Plugin {plugin.id}, prompt too long, length: {len(prompt)}/{prompt_char_limit}", uid)
        return None

    filter_scopes = plugin.fitler_proactive_notification_scopes(data.get('params', []))

    # Only fetch memory context when the plugin actually asked for the
    # 'user_context' scope in its params.
    context = None
    if 'user_context' in filter_scopes:
        memories = _retrieve_contextual_memories(uid, data.get('context', {}))
        context = Memory.memories_to_string(memories, True) if memories else None

    print(f'_process_proactive_notification context {context[:100] if context else "empty"}')

    # Ask the LLM to realize the prompt template into a user-facing message.
    message = get_proactive_message(uid, prompt, filter_scopes, context)
    if not message or len(message) < message_min_chars:
        print(f"Plugins {plugin.id}, message too short", uid)
        return None

    send_plugin_notification(token, plugin.name, plugin.id, message)
    return message


def _trigger_realtime_integrations(uid: str, token: str, segments: List[dict]) -> dict:
plugins: List[Plugin] = get_plugins_data_from_db(uid, include_reviews=False)
filtered_plugins = [
Expand Down Expand Up @@ -259,20 +323,12 @@ def _single(plugin: Plugin):
results[plugin.id] = message

# proactive_notification
noti = response_data.get('notification', None)
print('Plugin', plugin.id, 'response notification:', noti)
if plugin.has_capability("proactive_notification"):
noti = response_data.get('notification', None)
print('Plugin', plugin.id, 'response notification:', noti)
if noti:
prompt = noti.get('prompt', '')
if len(prompt) > 0 and len(prompt) <= 8000:
params = noti.get('params', [])
message = get_metoring_message(uid, prompt, plugin.fitler_proactive_notification_scopes(params))
if message and len(message) > 5:
send_plugin_notification(token, plugin.name, plugin.id, message)
results[plugin.id] = message
elif len(prompt) > 8000:
send_plugin_notification(token, plugin.name, plugin.id, f"Prompt too long: {len(prompt)}/8000 characters. Please shorten.")
print(f"Plugin {plugin.id} prompt too long, length: {len(prompt)}/8000")
message = _process_proactive_notification(uid, token, plugin, noti)
if message:
results[plugin.id] = message

except Exception as e:
print(f"Plugin integration error: {e}")
Expand Down
27 changes: 20 additions & 7 deletions plugins/example/basic/mentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@

from fastapi import APIRouter

from models import TranscriptSegment, MentorEndpointResponse, RealtimePluginRequest
from models import TranscriptSegment, ProactiveNotificationEndpointResponse, RealtimePluginRequest
from db import get_upsert_segment_to_transcript_plugin

router = APIRouter()

scan_segment_session = {}

# *******************************************************
# ************ Basic Mentor Plugin ************
# ************ Basic Proactive Notification Plugin ************
# *******************************************************

@router.post('/mentor', tags=['mentor', 'basic', 'realtime'], response_model=MentorEndpointResponse)
@router.post('/mentor', tags=['mentor', 'basic', 'realtime', 'proactive_notification'], response_model=ProactiveNotificationEndpointResponse, response_model_exclude_none=True)
def mentoring(data: RealtimePluginRequest):
def normalize(text):
return re.sub(r' +', ' ',re.sub(r'[,?.!]', ' ', text)).lower().strip()
Expand Down Expand Up @@ -43,6 +43,7 @@ def normalize(text):

user_name = "{{user_name}}"
user_facts = "{{user_facts}}"
user_context = "{{user_context}}"

prompt = f"""
You are an experienced mentor, that helps people achieve their goals during the meeting.
Expand All @@ -67,15 +68,27 @@ def normalize(text):
Output your response in plain text, without markdown.

If you cannot find the topic or problem of the meeting, respond 'Nah 🤷 ~'.

Conversation:
```
${transcript}
```

Context:
```
{user_context}
```
""".replace(' ', '').strip()

# 3. Respond with the format {notification: {prompt, params}}
return {'session_id': data.session_id,
'notification': {'prompt': prompt,
'params': ['user_name', 'user_facts']}}
# 3. Respond with the format {notification: {prompt, params, context}}
# - context: {question, filters: {people, topics, entities}} | None
return {
'session_id': data.session_id,
'notification': {
'prompt': prompt,
'params': ['user_name', 'user_facts', 'user_context'],
}
}

@ router.get('/setup/mentor', tags=['mentor'])
def is_setup_completed(uid: str):
Expand Down
16 changes: 13 additions & 3 deletions plugins/example/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,20 @@ class RealtimePluginRequest(BaseModel):
segments: List[TranscriptSegment]


class MentorResponse(BaseModel):
class ProactiveNotificationContextFitlersResponse(BaseModel):
    """Metadata filters a plugin can supply to narrow the memory vector search.

    NOTE(review): class name contains a typo ("Fitlers" -> "Filters");
    renaming would require updating every reference, so it is only flagged here.
    """
    # Each list narrows the vector-store query; empty lists mean "no filtering".
    people: List[str] = Field(description="A list of people. ", default=[])
    entities: List[str] = Field(description="A list of entity. ", default=[])
    topics: List[str] = Field(description="A list of topic. ", default=[])

class ProactiveNotificationContextResponse(BaseModel):
    """Guidance for retrieving the user's context from the memory vector store."""
    # Free-text query embedded and used for the similarity search.
    question: str = Field(description="A question to query the embeded vector database.", default='')
    # NOTE(review): default is None but the annotation is not Optional —
    # presumably callers treat a missing filters object as "no filters"; verify.
    filters: ProactiveNotificationContextFitlersResponse = Field(description="Filter options to query the embeded vector database. ", default=None)

class ProactiveNotificationResponse(BaseModel):
    """Payload a plugin returns to ask the system to send a proactive notification."""
    # Prompt template; placeholders like {{user_name}} are substituted server-side.
    prompt: str = Field(description="A prompt or a template with the parameters such as {{user_name}} {{user_facts}}.", default='')
    # Scope names (e.g. 'user_name', 'user_facts', 'user_context') the prompt uses.
    params: List[str] = Field(description="A list of string that match with proactive notification scopes. ", default=[])
    # NOTE(review): default=None with a non-Optional annotation — see filters above.
    context: ProactiveNotificationContextResponse = Field(description="An object to guide the system in retrieving the users context", default=None)

class ProactiveNotificationEndpointResponse(BaseModel):
    """Response shape a realtime plugin endpoint returns to trigger a proactive notification.

    Replaces the old MentorEndpointResponse: the stale removed lines
    (``class MentorEndpointResponse`` and the ``notification: MentorResponse``
    field) referenced the deleted MentorResponse model and are dropped here.
    """
    # Optional short text shown directly to the user.
    message: str = Field(description="A short message to be sent as notification to the user, if needed.", default='')
    # When present, the system generates the notification from this guidance.
    notification: ProactiveNotificationResponse = Field(description="An object to guide the system in generating the proactive notification", default=None)