From 76a9bc5e95fae28d8b8919452d2a93b9820133ee Mon Sep 17 00:00:00 2001 From: Joan Cabezas Date: Sat, 5 Oct 2024 00:51:46 -0700 Subject: [PATCH] memory creation trigger when websocket is connected --- backend/database/redis_db.py | 11 ---- backend/routers/transcribe_v2.py | 96 ++++++++++++++++++-------------- 2 files changed, 53 insertions(+), 54 deletions(-) diff --git a/backend/database/redis_db.py b/backend/database/redis_db.py index 88bdc2189..9eb5ea95c 100644 --- a/backend/database/redis_db.py +++ b/backend/database/redis_db.py @@ -176,14 +176,3 @@ def get_in_progress_memory_id(uid: str) -> str: return '' return memory_id.decode() - -def set_in_progress_memory_id_last_segment_seconds(uid: str, seconds: str, ttl: int = 150): - r.set(f'users:{uid}:in_progress_memory_id_last_segment_seconds', seconds) - r.expire(f'users:{uid}:in_progress_memory_id_last_segment_seconds', ttl) - - -def get_in_progress_memory_id_last_segment_seconds(uid: str) -> float: - memory_id = r.get(f'users:{uid}:in_progress_memory_id_last_segment_seconds') - if not memory_id: - return 0 - return float(memory_id.decode()) diff --git a/backend/routers/transcribe_v2.py b/backend/routers/transcribe_v2.py index 071e16756..9a5cab52d 100644 --- a/backend/routers/transcribe_v2.py +++ b/backend/routers/transcribe_v2.py @@ -120,7 +120,6 @@ def _get_or_create_in_progress_memory(segments: List[dict]): memory.transcript_segments, [TranscriptSegment(**segment) for segment in segments] ) redis_db.set_in_progress_memory_id(uid, memory.id) - redis_db.set_in_progress_memory_id_last_segment_seconds(uid, memory.transcript_segments[-1].end) return memory started_at = datetime.now(timezone.utc) - timedelta(seconds=segments[0]['end'] - segments[0]['start']) @@ -138,26 +137,69 @@ def _get_or_create_in_progress_memory(segments: List[dict]): print('_get_in_progress_memory new', memory) memories_db.upsert_memory(uid, memory_data=memory.dict()) redis_db.set_in_progress_memory_id(uid, memory.id) - redis_db.set_in_progress_memory_id_last_segment_seconds(uid, segments[-1]['end']) return memory - async def memory_creation_timer(): + async def _send_message_event(msg: MessageEvent): + print(f"Message: type ${msg.event_type}") + try: + await websocket.send_json(msg.to_json()) + return True + except WebSocketDisconnect: + print("WebSocket disconnected") + except RuntimeError as e: + print(f"Can not send message event, error: {e}") + + return False + + async def memory_creation_timer(delay_seconds: int): + print('memory_creation_timer', delay_seconds) try: - await asyncio.sleep(memory_creation_timeout) + await asyncio.sleep(delay_seconds) await _create_memory() except asyncio.CancelledError: pass + async def _create_memory(): + print("_create_memory") + # Reset state variables + nonlocal seconds_to_trim + nonlocal seconds_to_add + seconds_to_trim = None + seconds_to_add = None + + memory = retrieve_in_progress_memory(uid) + if not memory or not memory['transcript_segments']: + raise Exception('FAILED') + memory = Memory(**memory) + + asyncio.create_task(_send_message_event(MemoryEvent(event_type="memory_processing_started", memory=memory))) + + memories_db.update_memory_status(uid, memory.id, MemoryStatus.processing) + memory = process_memory(uid, language, memory) + memories_db.update_memory_status(uid, memory.id, MemoryStatus.completed) + messages = trigger_external_integrations(uid, memory) + + asyncio.create_task( + _send_message_event(MemoryEvent(event_type="memory_created", memory=memory, messages=messages)) + ) + memory_creation_task = None seconds_to_trim = None seconds_to_add = None - existing_memory = retrieve_in_progress_memory(uid) - if existing_memory: - dt = datetime.fromisoformat(existing_memory['started_at'].isoformat()) - seconds_to_add = (datetime.now(timezone.utc) - dt).total_seconds() - # TODO: validate is not more than duration? 120 seconds, and start processing - print('seconds_to_add', seconds_to_add) + # Determine previous disconnected socket seconds to add + start processing timer if a memory in progress + if existing_memory := retrieve_in_progress_memory(uid): + started_at = datetime.fromisoformat(existing_memory['started_at'].isoformat()) + seconds_to_add = (datetime.now(timezone.utc) - started_at).total_seconds() + + finished_at = datetime.fromisoformat(existing_memory['finished_at'].isoformat()) + seconds_since_last_segment = (datetime.now(timezone.utc) - finished_at).total_seconds() + if seconds_since_last_segment >= memory_creation_timeout: + asyncio.create_task(_create_memory()) + else: + memory_creation_task = asyncio.create_task( + memory_creation_timer(memory_creation_timeout - seconds_since_last_segment) + ) def stream_transcript(segments, _): nonlocal websocket @@ -173,7 +215,7 @@ def stream_transcript(segments, _): if memory_creation_task is not None: memory_creation_task.cancel() - memory_creation_task = asyncio.create_task(memory_creation_timer()) + memory_creation_task = asyncio.create_task(memory_creation_timer(memory_creation_timeout)) # Segments aligning duration seconds. if seconds_to_add: @@ -339,38 +381,6 @@ async def send_heartbeat(): finally: websocket_active = False - async def _send_message_event(msg: MessageEvent): - print(f"Message: type ${msg.event_type}") - try: - await websocket.send_json(msg.to_json()) - return True - except WebSocketDisconnect: - print("WebSocket disconnected") - except RuntimeError as e: - print(f"Can not send message event, error: {e}") - - return False - - async def _create_memory(): - print("_create_memory") - nonlocal seconds_to_trim - seconds_to_trim = None - - memory = retrieve_in_progress_memory(uid) - if not memory or not memory.transcript_segments: - raise Exception('FAILED') - - asyncio.create_task(_send_message_event(MemoryEvent(event_type="memory_processing_started", memory=memory))) - - memories_db.update_memory_status(uid, memory.id, MemoryStatus.processing) - memory = process_memory(uid, language, memory) - memories_db.update_memory_status(uid, memory.id, MemoryStatus.completed) - messages = trigger_external_integrations(uid, memory) - - asyncio.create_task( - _send_message_event(MemoryEvent(event_type="memory_created", memory=memory, messages=messages)) - ) - try: receive_task = asyncio.create_task( receive_audio(deepgram_socket, deepgram_socket2, soniox_socket, speechmatics_socket)