Skip to content

Commit

Permalink
Merge branch 'main' into frxxwiaf_deepgram_obns3
Browse files Browse the repository at this point in the history
  • Loading branch information
beastoin committed Sep 23, 2024
2 parents c51616f + 4faf7cc commit 4592649
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 45 deletions.
2 changes: 2 additions & 0 deletions backend/routers/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def postprocess_memory(
TODO: post llm process here would be great, sometimes whisper x outputs without punctuation
"""

# TODO: this pipeline vs groq+pyannote diarization 3.1, probably the latter is better.

# Save file
file_path = f"_temp/{memory_id}_{file.filename}"
with open(file_path, 'wb') as f:
Expand Down
48 changes: 25 additions & 23 deletions backend/routers/transcribe.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import threading
import math
import asyncio
import time
from typing import List
import threading
import uuid
from datetime import datetime, timezone
from enum import Enum
Expand All @@ -16,7 +13,6 @@

import database.memories as memories_db
import database.processing_memories as processing_memories_db
from database.redis_db import get_user_speech_profile
from models.memory import Memory, TranscriptSegment
from models.message_event import NewMemoryCreated, MessageEvent, NewProcessingMemoryCreated
from models.processing_memory import ProcessingMemory
Expand Down Expand Up @@ -78,6 +74,8 @@ class STTService(str, Enum):
soniox = "soniox"
speechmatics = "speechmatics"

# auto = "auto"

@staticmethod
def get_model_name(value):
if value == STTService.deepgram:
Expand All @@ -96,15 +94,15 @@ async def _websocket_util(
print('websocket_endpoint', uid, language, sample_rate, codec, channels, include_speech_profile, new_memory_watch,
stt_service)

if stt_service == STTService.soniox and (
sample_rate != 16000 or codec != 'opus' or language not in soniox_valid_languages):
stt_service = STTService.deepgram
if stt_service == STTService.speechmatics and (sample_rate != 16000 or codec != 'opus'):
if stt_service == STTService.soniox and language not in soniox_valid_languages:
stt_service = STTService.deepgram # defaults to deepgram

if stt_service == STTService.speechmatics: # defaults to deepgram (no credits + 10 connections max limit)
stt_service = STTService.deepgram

# At some point try running all the models together to easily compare
# TODO: if language english, use soniox
# TODO: else deepgram, if speechmatics credits, prob this for both?

# Check: Why do we need try-catch around websocket.accept?
try:
await websocket.accept()
except RuntimeError as e:
Expand Down Expand Up @@ -203,34 +201,36 @@ def stream_audio(audio_buffer):
duration = 0
try:
file_path, duration = None, 0
# TODO: how bee does for recognizing other languages speech profile
if language == 'en' and (codec == 'opus' or codec == 'pcm16') and include_speech_profile:
file_path = get_profile_audio_if_exists(uid)
print(f'deepgram-obns3: file_path {file_path}')
duration = AudioSegment.from_wav(file_path).duration_seconds + 5 if file_path else 0

# Deepgram
# DEEPGRAM
if stt_service == STTService.deepgram:
deepgram_codec_value = 'pcm16' if codec == 'opus' else codec
deepgram_socket = await process_audio_dg(
stream_transcript, memory_stream_id, language, sample_rate, deepgram_codec_value, channels,
preseconds=duration
stream_transcript, memory_stream_id, language, sample_rate, channels, preseconds=duration
)
if duration:
deepgram_socket2 = await process_audio_dg(
stream_transcript, speech_profile_stream_id, language, sample_rate, deepgram_codec_value, channels
stream_transcript, speech_profile_stream_id, language, sample_rate, channels
)

print(f'deepgram-obns3: send_initial_file_path > deepgram_socket {deepgram_socket}')
async def deepgram_socket_send(data):
return deepgram_socket.send(data)
await send_initial_file_path(file_path, deepgram_socket_send)
# SONIOX
elif stt_service == STTService.soniox:
soniox_socket = await process_audio_soniox(
stream_transcript, speech_profile_stream_id, language, uid if include_speech_profile else None
stream_transcript, speech_profile_stream_id, sample_rate, language,
uid if include_speech_profile else None
)
# SPEECHMATICS
elif stt_service == STTService.speechmatics:
speechmatics_socket = await process_audio_speechmatics(
stream_transcript, speech_profile_stream_id, language, preseconds=duration
stream_transcript, speech_profile_stream_id, sample_rate, language, preseconds=duration
)
if duration:
await send_initial_file_path(file_path, speechmatics_socket.send)
Expand Down Expand Up @@ -372,7 +372,7 @@ async def _create_processing_memory():
id=str(uuid.uuid4()),
created_at=datetime.now(timezone.utc),
timer_start=timer_start,
timer_segment_start=timer_start+segment_start,
timer_segment_start=timer_start + segment_start,
language=language,
)

Expand Down Expand Up @@ -435,7 +435,8 @@ async def _post_process_memory(memory: Memory):
# merge
merge_file_path = f"_temp/{memory.id}_{uuid.uuid4()}_be"
nearest_timer_start = processing_memory.timer_starts[-2]
merge_wav_files(merge_file_path, [previous_file_path, file_path], [math.ceil(timer_start-nearest_timer_start), 0])
merge_wav_files(merge_file_path, [previous_file_path, file_path],
[math.ceil(timer_start - nearest_timer_start), 0])

# clean
os.remove(previous_file_path)
Expand Down Expand Up @@ -504,8 +505,8 @@ async def _create_memory():
memory = None
messages = []
if not processing_memory.memory_id:
(new_memory, new_messages, updated_processing_memory) = await create_memory_by_processing_memory(uid,
processing_memory.id)
new_memory, new_messages, updated_processing_memory = await create_memory_by_processing_memory(
uid, processing_memory.id)
if not new_memory:
print("Can not create new memory")

Expand All @@ -532,7 +533,8 @@ async def _create_memory():
[segment.dict() for segment in memory.transcript_segments])

# Update finished at
memory.finished_at = datetime.fromtimestamp(memory.started_at.timestamp() + processing_memory.transcript_segments[-1].end, timezone.utc)
memory.finished_at = datetime.fromtimestamp(
memory.started_at.timestamp() + processing_memory.transcript_segments[-1].end, timezone.utc)
memories_db.update_memory_finished_at(uid, memory.id, memory.finished_at)

# Process
Expand Down
4 changes: 2 additions & 2 deletions backend/scripts/stt/k_compare_transcripts_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ def stream_transcript_speechmatics(new_segments, _):

# streaming models
socket = await process_audio_dg(stream_transcript_deepgram, '1', 'en', 16000, 'pcm16', 1, 0)
socket_soniox = await process_audio_soniox(stream_transcript_soniox, '1', 'en', None)
socket_speechmatics = await process_audio_speechmatics(stream_transcript_speechmatics, '1', 'en', 0)
socket_soniox = await process_audio_soniox(stream_transcript_soniox, '1', 16000, 'en', None)
socket_speechmatics = await process_audio_speechmatics(stream_transcript_speechmatics, '1', 16000, 'en', 0)
print('duration', duration)
with open(file_path, "rb") as file:
while True:
Expand Down
31 changes: 11 additions & 20 deletions backend/utils/stt/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,10 @@ async def send_initial_file(data: List[List[int]], transcript_socket):


async def process_audio_dg(
stream_transcript, stream_id: int, language: str, sample_rate: int, codec: str, channels: int,
stream_transcript, stream_id: int, language: str, sample_rate: int, channels: int,
preseconds: int = 0,
):
print('process_audio_dg', language, sample_rate, codec, channels, preseconds)
print('process_audio_dg', language, sample_rate, channels, preseconds)

def on_message(self, result, **kwargs):
# print(f"Received message from Deepgram") # Log when message is received
Expand Down Expand Up @@ -143,15 +143,15 @@ def on_error(self, error, **kwargs):
print(f"Error: {error}")

print("Connecting to Deepgram") # Log before connection attempt
return connect_to_deepgram(on_message, on_error, language, sample_rate, codec, channels)
return connect_to_deepgram(on_message, on_error, language, sample_rate, channels)


def process_segments(uid: str, segments: list[dict]):
    """Forward newly transcribed segments to the user's realtime integrations.

    Args:
        uid: id of the user the transcript segments belong to.
        segments: transcript-segment dicts emitted by the STT streaming callback.

    NOTE(review): the notification token is fetched unconditionally, even if no
    integration ends up notifying — the inline TODO below tracks the same concern.
    """
    token = notification_db.get_token_only(uid)  # TODO: don't retrieve token before knowing if to notify
    trigger_realtime_integrations(uid, token, segments)


def connect_to_deepgram(on_message, on_error, language: str, sample_rate: int, codec: str, channels: int):
def connect_to_deepgram(on_message, on_error, language: str, sample_rate: int, channels: int):
# 'wss://api.deepgram.com/v1/listen?encoding=linear16&sample_rate=8000&language=$recordingsLanguage&model=nova-2-general&no_delay=true&endpointing=100&interim_results=false&smart_format=true&diarize=true'
try:
dg_connection = deepgram.listen.websocket.v("1")
Expand Down Expand Up @@ -201,7 +201,7 @@ def on_unhandled(self, unhandled, **kwargs):
multichannel=channels > 1,
model='nova-2-general',
sample_rate=sample_rate,
encoding='linear16' if codec == 'pcm8' or codec == 'pcm16' else 'opus'
encoding='linear16'
)
result = dg_connection.start(options)
print('Deepgram connection started:', result)
Expand All @@ -213,10 +213,7 @@ def on_unhandled(self, unhandled, **kwargs):
soniox_valid_languages = ['en']


# soniox_valid_languages = ['en', 'es', 'fr', 'ko', 'zh', 'it', 'pt', 'de']


async def process_audio_soniox(stream_transcript, stream_id: int, language: str, uid: str):
async def process_audio_soniox(stream_transcript, stream_id: int, sample_rate: int, language: str, uid: str):
# Fuck, soniox doesn't even support diarization in languages != english
api_key = os.getenv('SONIOX_API_KEY')
if not api_key:
Expand All @@ -228,12 +225,12 @@ async def process_audio_soniox(stream_transcript, stream_id: int, language: str,
if language not in soniox_valid_languages:
raise ValueError(f"Unsupported language '{language}'. Supported languages are: {soniox_valid_languages}")

has_speech_profile = create_user_speech_profile(uid) if uid else False # only english too
has_speech_profile = create_user_speech_profile(uid) if uid and sample_rate == 16000 else False # only english too

# Construct the initial request with all required and optional parameters
request = {
'api_key': api_key,
'sample_rate_hertz': 16000,
'sample_rate_hertz': sample_rate,
'include_nonfinal': True,
'enable_endpoint_detection': True,
'enable_streaming_speaker_diarization': True,
Expand Down Expand Up @@ -330,12 +327,9 @@ async def on_message():
CONNECTION_URL = f"wss://eu2.rt.speechmatics.com/v2"


async def process_audio_speechmatics(stream_transcript, stream_id: int, language: str, preseconds: int = 0):
# Create a transcription client
async def process_audio_speechmatics(stream_transcript, stream_id: int, sample_rate: int, language: str, preseconds: int = 0):
api_key = os.getenv('SPEECHMATICS_API_KEY')
uri = 'wss://eu2.rt.speechmatics.com/v2'
# Validate the language and construct the model name
# has_speech_profile = create_user_speech_profile(uid) # only english too

request = {
"message": "StartRecognition",
Expand All @@ -349,7 +343,7 @@ async def process_audio_speechmatics(stream_transcript, stream_id: int, language
"enable_entities": True,
"speaker_diarization_config": {"max_speakers": 4}
},
"audio_format": {"type": "raw", "encoding": "pcm_s16le", "sample_rate": 16000},
"audio_format": {"type": "raw", "encoding": "pcm_s16le", "sample_rate": sample_rate},
# "audio_events_config": {
# "types": [
# "laughter",
Expand All @@ -359,16 +353,13 @@ async def process_audio_speechmatics(stream_transcript, stream_id: int, language
# }
}
try:
# Connect to Soniox WebSocket
print("Connecting to Speechmatics WebSocket...")
socket = await websockets.connect(uri, extra_headers={"Authorization": f"Bearer {api_key}"})
print("Connected to Speechmatics WebSocket.")

# Send the initial request
await socket.send(json.dumps(request))
print(f"Sent initial request: {request}")

# Start listening for messages from Soniox
async def on_message():
try:
async for message in socket:
Expand Down Expand Up @@ -400,7 +391,7 @@ async def on_message():

is_user = True if r_speaker == '1' and preseconds > 0 else False
if r_start < preseconds:
print('Skipping word', r_start, r_content)
# print('Skipping word', r_start, r_content)
continue
# print(r_content, r_speaker, [r_start, r_end])
if not segments:
Expand Down

0 comments on commit 4592649

Please sign in to comment.