Skip to content

Commit

Permalink
memory creation trigger when websocket is connected
Browse files Browse the repository at this point in the history
  • Loading branch information
josancamon19 committed Oct 5, 2024
1 parent ad6940a commit 76a9bc5
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 54 deletions.
11 changes: 0 additions & 11 deletions backend/database/redis_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,14 +176,3 @@ def get_in_progress_memory_id(uid: str) -> str:
return ''
return memory_id.decode()


def set_in_progress_memory_id_last_segment_seconds(uid: str, seconds: str, ttl: int = 150):
r.set(f'users:{uid}:in_progress_memory_id_last_segment_seconds', seconds)
r.expire(f'users:{uid}:in_progress_memory_id_last_segment_seconds', ttl)


def get_in_progress_memory_id_last_segment_seconds(uid: str) -> float:
memory_id = r.get(f'users:{uid}:in_progress_memory_id_last_segment_seconds')
if not memory_id:
return 0
return float(memory_id.decode())
96 changes: 53 additions & 43 deletions backend/routers/transcribe_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def _get_or_create_in_progress_memory(segments: List[dict]):
memory.transcript_segments, [TranscriptSegment(**segment) for segment in segments]
)
redis_db.set_in_progress_memory_id(uid, memory.id)
redis_db.set_in_progress_memory_id_last_segment_seconds(uid, memory.transcript_segments[-1].end)
return memory

started_at = datetime.now(timezone.utc) - timedelta(seconds=segments[0]['end'] - segments[0]['start'])
Expand All @@ -138,26 +137,69 @@ def _get_or_create_in_progress_memory(segments: List[dict]):
print('_get_in_progress_memory new', memory)
memories_db.upsert_memory(uid, memory_data=memory.dict())
redis_db.set_in_progress_memory_id(uid, memory.id)
redis_db.set_in_progress_memory_id_last_segment_seconds(uid, segments[-1]['end'])
return memory

async def memory_creation_timer():
async def _send_message_event(msg: MessageEvent):
print(f"Message: type ${msg.event_type}")
try:
await websocket.send_json(msg.to_json())
return True
except WebSocketDisconnect:
print("WebSocket disconnected")
except RuntimeError as e:
print(f"Can not send message event, error: {e}")

return False

async def memory_creation_timer(delay_seconds: int):
print('memory_creation_timer', delay_seconds)
try:
await asyncio.sleep(memory_creation_timeout)
await asyncio.sleep(delay_seconds)
await _create_memory()
except asyncio.CancelledError:
pass

async def _create_memory():
print("_create_memory")
# Reset state variables
nonlocal seconds_to_trim
nonlocal seconds_to_add
seconds_to_trim = None
seconds_to_add = None

memory = retrieve_in_progress_memory(uid)
if not memory or not memory['transcript_segments']:
raise Exception('FAILED')
memory = Memory(**memory)

asyncio.create_task(_send_message_event(MemoryEvent(event_type="memory_processing_started", memory=memory)))

memories_db.update_memory_status(uid, memory.id, MemoryStatus.processing)
memory = process_memory(uid, language, memory)
memories_db.update_memory_status(uid, memory.id, MemoryStatus.completed)
messages = trigger_external_integrations(uid, memory)

asyncio.create_task(
_send_message_event(MemoryEvent(event_type="memory_created", memory=memory, messages=messages))
)

memory_creation_task = None
seconds_to_trim = None
seconds_to_add = None

existing_memory = retrieve_in_progress_memory(uid)
if existing_memory:
dt = datetime.fromisoformat(existing_memory['started_at'].isoformat())
seconds_to_add = (datetime.now(timezone.utc) - dt).total_seconds()
# TODO: validate is not more than duration? 120 seconds, and start processing
print('seconds_to_add', seconds_to_add)
# Determine previous disconnected socket seconds to add + start processing timer if a memory in progress
if existing_memory := retrieve_in_progress_memory(uid):
started_at = datetime.fromisoformat(existing_memory['started_at'].isoformat())
seconds_to_add = (datetime.now(timezone.utc) - started_at).total_seconds()

finished_at = datetime.fromisoformat(existing_memory['finished_at'].isoformat())
seconds_since_last_segment = (datetime.now(timezone.utc) - finished_at).total_seconds()
if seconds_since_last_segment >= memory_creation_timeout:
asyncio.create_task(_create_memory())
else:
memory_creation_task = asyncio.create_task(
memory_creation_timer(memory_creation_timeout - seconds_since_last_segment)
)

def stream_transcript(segments, _):
nonlocal websocket
Expand All @@ -173,7 +215,7 @@ def stream_transcript(segments, _):

if memory_creation_task is not None:
memory_creation_task.cancel()
memory_creation_task = asyncio.create_task(memory_creation_timer())
memory_creation_task = asyncio.create_task(memory_creation_timer(memory_creation_timeout))

# Segments aligning duration seconds.
if seconds_to_add:
Expand Down Expand Up @@ -339,38 +381,6 @@ async def send_heartbeat():
finally:
websocket_active = False

async def _send_message_event(msg: MessageEvent):
print(f"Message: type ${msg.event_type}")
try:
await websocket.send_json(msg.to_json())
return True
except WebSocketDisconnect:
print("WebSocket disconnected")
except RuntimeError as e:
print(f"Can not send message event, error: {e}")

return False

async def _create_memory():
print("_create_memory")
nonlocal seconds_to_trim
seconds_to_trim = None

memory = retrieve_in_progress_memory(uid)
if not memory or not memory.transcript_segments:
raise Exception('FAILED')

asyncio.create_task(_send_message_event(MemoryEvent(event_type="memory_processing_started", memory=memory)))

memories_db.update_memory_status(uid, memory.id, MemoryStatus.processing)
memory = process_memory(uid, language, memory)
memories_db.update_memory_status(uid, memory.id, MemoryStatus.completed)
messages = trigger_external_integrations(uid, memory)

asyncio.create_task(
_send_message_event(MemoryEvent(event_type="memory_created", memory=memory, messages=messages))
)

try:
receive_task = asyncio.create_task(
receive_audio(deepgram_socket, deepgram_socket2, soniox_socket, speechmatics_socket)
Expand Down

0 comments on commit 76a9bc5

Please sign in to comment.