Skip to content

Commit

Permalink
Add pusher reconnect (#1296)
Browse files Browse the repository at this point in the history
Backend should reconnect to Pusher if there any wrong with the
connection.

# TODOs
- [x] Add reconnects logic
- [x] With basic exponential back-off + jitter  

# Example
<img width="771" alt="Screenshot 2024-11-12 at 18 47 21"
src="https://github.com/user-attachments/assets/ab3de78d-2654-4eba-a182-513cf415afcc">
  • Loading branch information
beastoin authored Nov 13, 2024
2 parents a14fcfc + 6abd0b1 commit 0398b3b
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 241 deletions.
189 changes: 4 additions & 185 deletions backend/routers/pusher.py
Original file line number Diff line number Diff line change
@@ -1,198 +1,17 @@
import uuid
import struct
from datetime import datetime, timezone, timedelta
from enum import Enum
import asyncio
import json

import opuslib
import webrtcvad
from fastapi import APIRouter
from fastapi.websockets import WebSocketDisconnect, WebSocket
from pydub import AudioSegment
from starlette.websockets import WebSocketState

import database.memories as memories_db
from database import redis_db
from database.redis_db import get_cached_user_geolocation
from models.memory import Memory, TranscriptSegment, MemoryStatus, Structured, Geolocation
from models.message_event import MemoryEvent, MessageEvent
from utils.memories.location import get_google_maps_location
from utils.memories.process_memory import process_memory
from utils.plugins import trigger_external_integrations, trigger_realtime_integrations
from utils.stt.streaming import *
from utils.plugins import trigger_realtime_integrations
from utils.webhooks import send_audio_bytes_developer_webhook, realtime_transcript_webhook, \
get_audio_bytes_webhook_seconds

router = APIRouter()

async def _websocket_util_transcript(
websocket: WebSocket, uid: str,
):
print('_websocket_util_transcript', uid)

try:
await websocket.accept()
except RuntimeError as e:
print(e)
await websocket.close(code=1011, reason="Dirty state")
return

websocket_active = True
websocket_close_code = 1000

loop = asyncio.get_event_loop()

# task
async def receive_segments():
nonlocal websocket_active
nonlocal websocket_close_code

try:
while websocket_active:
segments = await websocket.receive_json()
# print(f"pusher received segments {len(segments)}")
asyncio.run_coroutine_threadsafe(trigger_realtime_integrations(uid, segments), loop)
asyncio.run_coroutine_threadsafe(realtime_transcript_webhook(uid, segments), loop)

except WebSocketDisconnect:
print("WebSocket disconnected")
except Exception as e:
print(f'Could not process segments: error {e}')
websocket_close_code = 1011
finally:
websocket_active = False

# heart beat
async def send_heartbeat():
nonlocal websocket_active
nonlocal websocket_close_code
try:
while websocket_active:
await asyncio.sleep(20)
if websocket.client_state == WebSocketState.CONNECTED:
await websocket.send_json({"type": "ping"})
else:
break
except WebSocketDisconnect:
print("WebSocket disconnected")
except Exception as e:
print(f'Heartbeat error: {e}')
websocket_close_code = 1011
finally:
websocket_active = False

try:
receive_task = asyncio.create_task(
receive_segments()
)
heartbeat_task = asyncio.create_task(send_heartbeat())
await asyncio.gather(receive_task, heartbeat_task)

except Exception as e:
print(f"Error during WebSocket operation: {e}")
finally:
websocket_active = False
if websocket.client_state == WebSocketState.CONNECTED:
try:
await websocket.close(code=websocket_close_code)
except Exception as e:
print(f"Error closing WebSocket: {e}")


@router.websocket("/v1/trigger/transcript/listen")
async def websocket_endpoint_transcript(
websocket: WebSocket, uid: str,
):
await _websocket_util_transcript(websocket, uid)


async def _websocket_util_audio_bytes(
websocket: WebSocket, uid: str, sample_rate: int = 8000,
):
print('_websocket_util_audio_bytes', uid)

try:
await websocket.accept()
except RuntimeError as e:
print(e)
await websocket.close(code=1011, reason="Dirty state")
return

websocket_active = True
websocket_close_code = 1000

loop = asyncio.get_event_loop()

audio_bytes_webhook_delay_seconds = get_audio_bytes_webhook_seconds(uid)

# task
async def receive_audio_bytes():
nonlocal websocket_active
nonlocal websocket_close_code

audiobuffer = bytearray()

try:
while websocket_active:
data = await websocket.receive_bytes()
# print(f"pusher received audio bytes {len(data)}")
audiobuffer.extend(data)
if audio_bytes_webhook_delay_seconds and len(
audiobuffer) > sample_rate * audio_bytes_webhook_delay_seconds * 2:
asyncio.create_task(send_audio_bytes_developer_webhook(uid, sample_rate, audiobuffer.copy()))
audiobuffer = bytearray()

except WebSocketDisconnect:
print("WebSocket disconnected")
except Exception as e:
print(f'Could not process audio: error {e}')
websocket_close_code = 1011
finally:
websocket_active = False

# heart beat
async def send_heartbeat():
nonlocal websocket_active
nonlocal websocket_close_code
try:
while websocket_active:
await asyncio.sleep(20)
if websocket.client_state == WebSocketState.CONNECTED:
await websocket.send_json({"type": "ping"})
else:
break
except WebSocketDisconnect:
print("WebSocket disconnected")
except Exception as e:
print(f'Heartbeat error: {e}')
websocket_close_code = 1011
finally:
websocket_active = False

try:
receive_task = asyncio.create_task(
receive_audio_bytes()
)
heartbeat_task = asyncio.create_task(send_heartbeat())
await asyncio.gather(receive_task, heartbeat_task)

except Exception as e:
print(f"Error during WebSocket operation: {e}")
finally:
websocket_active = False
if websocket.client_state == WebSocketState.CONNECTED:
try:
await websocket.close(code=websocket_close_code)
except Exception as e:
print(f"Error closing WebSocket: {e}")


@router.websocket("/v1/trigger/audio-bytes/listen")
async def websocket_endpoint_audio_bytes(
websocket: WebSocket, uid: str, sample_rate: int = 8000,
):
await _websocket_util_audio_bytes(websocket, uid, sample_rate)


async def _websocket_util_trigger(
websocket: WebSocket, uid: str, sample_rate: int = 8000,
):
Expand Down Expand Up @@ -237,7 +56,7 @@ async def receive_audio_bytes():
audiobuffer.extend(data[4:])
if audio_bytes_webhook_delay_seconds and len(
audiobuffer) > sample_rate * audio_bytes_webhook_delay_seconds * 2:
asyncio.create_task(send_audio_bytes_developer_webhook(uid, sample_rate, audiobuffer.copy()))
asyncio.run_coroutine_threadsafe(send_audio_bytes_developer_webhook(uid, sample_rate, audiobuffer.copy()), loop)
audiobuffer = bytearray()
continue

Expand Down
39 changes: 32 additions & 7 deletions backend/routers/transcribe_v2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import uuid
import asyncio
import struct
from datetime import datetime, timezone, timedelta
from enum import Enum
Expand All @@ -21,7 +22,7 @@
from utils.stt.streaming import *
from utils.webhooks import send_audio_bytes_developer_webhook, realtime_transcript_webhook, \
get_audio_bytes_webhook_seconds
from utils.pusher import connect_to_transcript_pusher, connect_to_audio_bytes_pusher, connect_to_trigger_pusher
from utils.pusher import connect_to_trigger_pusher

router = APIRouter()

Expand Down Expand Up @@ -297,6 +298,8 @@ def create_pusher_task_handler():
nonlocal websocket_active

pusher_ws = None
pusher_connect_lock = asyncio.Lock()
pusher_connected = False

# Transcript
transcript_ws = None
Expand All @@ -309,9 +312,10 @@ def transcript_send(segments):
async def transcript_consume():
nonlocal websocket_active
nonlocal segment_buffers
nonlocal transcript_ws
while websocket_active or len(segment_buffers) > 0:
await asyncio.sleep(1)
if transcript_ws and len(segment_buffers) > 0:
if transcript_ws:
try:
# 100|data
data = bytearray()
Expand All @@ -321,6 +325,9 @@ async def transcript_consume():
await transcript_ws.send(data)
except websockets.exceptions.ConnectionClosed as e:
print(f"Pusher transcripts Connection closed: {e}", uid)
transcript_ws = None
pusher_connected = False
await reconnect()
except Exception as e:
print(f"Pusher transcripts failed: {e}", uid)

Expand All @@ -336,9 +343,10 @@ def audio_bytes_send(audio_bytes):
async def audio_bytes_consume():
nonlocal websocket_active
nonlocal audio_buffers
nonlocal audio_bytes_ws
while websocket_active or len(audio_buffers) > 0:
await asyncio.sleep(1)
if audio_bytes_ws and len(audio_buffers) > 0:
if audio_bytes_ws:
try:
# 101|data
data = bytearray()
Expand All @@ -348,18 +356,35 @@ async def audio_bytes_consume():
await audio_bytes_ws.send(data)
except websockets.exceptions.ConnectionClosed as e:
print(f"Pusher audio_bytes Connection closed: {e}", uid)
audio_bytes_ws = None
pusher_connected = False
await reconnect()
except Exception as e:
print(f"Pusher audio_bytes failed: {e}", uid)

async def reconnect():
nonlocal pusher_connected
nonlocal pusher_connect_lock
with pusher_connect_lock:
if pusher_connected:
return
await connect()

async def connect():
nonlocal pusher_ws
nonlocal transcript_ws
nonlocal audio_bytes_ws
nonlocal audio_bytes_enabled
pusher_ws = await connect_to_trigger_pusher(uid, sample_rate)
transcript_ws = pusher_ws
if audio_bytes_enabled:
audio_bytes_ws = pusher_ws
nonlocal pusher_connected

try:
pusher_ws = await connect_to_trigger_pusher(uid, sample_rate)
pusher_connected = True
transcript_ws = pusher_ws
if audio_bytes_enabled:
audio_bytes_ws = pusher_ws
except Exception as e:
print(f"Exception in connect: {e}")

async def close(code: int = 1000):
await pusher_ws.close(code)
Expand Down
66 changes: 24 additions & 42 deletions backend/utils/pusher.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,26 @@
import uuid
import os
from datetime import datetime, timezone, timedelta
from enum import Enum

import opuslib
import webrtcvad
from fastapi import APIRouter
from fastapi.websockets import WebSocketDisconnect, WebSocket
from pydub import AudioSegment
from starlette.websockets import WebSocketState

import database.memories as memories_db
from database import redis_db
from database.redis_db import get_cached_user_geolocation
from models.memory import Memory, TranscriptSegment, MemoryStatus, Structured, Geolocation
from models.message_event import MemoryEvent, MessageEvent
from utils.memories.location import get_google_maps_location
from utils.memories.process_memory import process_memory
from utils.stt.streaming import *
from utils.webhooks import send_audio_bytes_developer_webhook, realtime_transcript_webhook, \
get_audio_bytes_webhook_seconds
import random
import asyncio
import websockets

PusherAPI = os.getenv('HOSTED_PUSHER_API_URL')

async def connect_to_transcript_pusher(uid: str):
try:
print("Connecting to Pusher transcripts trigger WebSocket...")
ws_host = PusherAPI.replace("http", "ws")
socket = await websockets.connect(f"{ws_host}/v1/trigger/transcript/listen?uid={uid}")
print("Connected to Pusher transcripts trigger WebSocket.")
return socket
except Exception as e:
print(f"Exception in connect_to_transcript_pusher: {e}")
raise
async def connect_to_trigger_pusher(uid: str, sample_rate: int = 8000, retries: int = 3):
print("connect_to_trigger_pusher")
for attempt in range(retries):
try:
return await _connect_to_trigger_pusher(uid, sample_rate)
except Exception as error:
print(f'An error occurred: {error}')
if attempt == retries - 1:
raise
backoff_delay = calculate_backoff_with_jitter(attempt)
print(f"Waiting {backoff_delay:.0f}ms before next retry...")
asyncio.sleep(backoff_delay / 1000)

async def connect_to_audio_bytes_pusher(uid: str, sample_rate: int = 8000):
try:
print("Connecting to Pusher audio bytes trigger WebSocket...")
ws_host = PusherAPI.replace("http", "ws")
socket = await websockets.connect(f"{ws_host}/v1/trigger/audio-bytes/listen?uid={uid}&sample_rate={sample_rate}")
print("Connected to Pusher audio bytes trigger WebSocket.")
return socket
except Exception as e:
print(f"Exception in connect_to_audio_bytes_pusher: {e}")
raise
raise Exception(f'Could not open socket: All retry attempts failed.')

async def connect_to_trigger_pusher(uid: str, sample_rate: int = 8000):
async def _connect_to_trigger_pusher(uid: str, sample_rate: int = 8000):
try:
print("Connecting to Pusher transcripts trigger WebSocket...", uid)
ws_host = PusherAPI.replace("http", "ws")
Expand All @@ -55,3 +30,10 @@ async def connect_to_trigger_pusher(uid: str, sample_rate: int = 8000):
except Exception as e:
print(f"Exception in connect_to_transcript_pusher: {e}", uid)
raise


# Calculate backoff with jitter
def calculate_backoff_with_jitter(attempt, base_delay=1000, max_delay=15000):
jitter = random.random() * base_delay
backoff = min(((2 ** attempt) * base_delay) + jitter, max_delay)
return backoff
Loading

0 comments on commit 0398b3b

Please sign in to comment.