diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py
index 2b7eb2fb7..7a96c76a5 100644
--- a/backend/routers/transcribe.py
+++ b/backend/routers/transcribe.py
@@ -4,10 +4,13 @@
 from fastapi import APIRouter
 from fastapi.websockets import (WebSocketDisconnect, WebSocket)
 from starlette.websockets import WebSocketState
+import torch
+from collections import deque
+# import opuslib

 from database.redis_db import get_user_speech_profile, get_user_speech_profile_duration
 from utils.stt.streaming import process_audio_dg, send_initial_file
-from utils.stt.vad import VADIterator, model, get_speech_state, SpeechState
+from utils.stt.vad import VADIterator, model, is_speech_present

 router = APIRouter()

@@ -47,6 +50,9 @@ async def _websocket_util(
     transcript_socket2 = None
     websocket_active = True
     duration = 0
+    is_speech_active = False
+    speech_timeout = 2.0  # Fixed for now; could later be adjusted per user (e.g. grow with active speech time up to a cap), but that is not needed yet
+    last_speech_time = 0
     try:
         if language == 'en' and codec == 'opus' and include_speech_profile:
             speech_profile = get_user_speech_profile(uid)
@@ -68,49 +74,80 @@ async def _websocket_util(
             await websocket.close()
             return

-    vad_iterator = VADIterator(model, sampling_rate=sample_rate)  # threshold=0.9
+    threshold = 0.7
+    vad_iterator = VADIterator(model, sampling_rate=sample_rate, threshold=threshold)
     window_size_samples = 256 if sample_rate == 8000 else 512
+    # if codec == 'opus':
+    #     decoder = opuslib.Decoder(sample_rate, channels)

     async def receive_audio(socket1, socket2):
-        nonlocal websocket_active
-        audio_buffer = bytearray()
+        nonlocal is_speech_active, last_speech_time, websocket_active
+        # nonlocal decoder
+
+        REALTIME_RESOLUTION = 0.01
+        sample_width = 2  # pcm8 and pcm16 are both 16-bit here
+        byte_rate = sample_width * sample_rate * channels
+        chunk_size = int(byte_rate * REALTIME_RESOLUTION)
+        audio_buffer = deque(maxlen=byte_rate * 1)  # 1 second
+        databuffer = bytearray(b"")
+        prespeech_audio = deque(maxlen=int(byte_rate * 0.5))  # Audio queued so it can be prepended to the data sent to Deepgram once is_speech_active becomes True
+
         timer_start = time.time()
-        speech_state = SpeechState.no_speech
-        voice_found, not_voice = 0, 0
-        # path = 'scripts/vad/audio_bytes.txt'
-        # if os.path.exists(path):
-        #     os.remove(path)
-        # audio_file = open(path, "a")
+        audio_cursor = 0  # Seconds of audio sent so far, used for real-time pacing
         try:
             while websocket_active:
                 data = await websocket.receive_bytes()
-                audio_buffer.extend(data)
-
-                if codec == 'pcm8':
-                    frame_size, frames_count = 160, 16
-                    if len(audio_buffer) < (frame_size * frames_count):
-                        continue
-
-                    latest_speech_state = get_speech_state(
-                        audio_buffer[:window_size_samples * 10], vad_iterator, window_size_samples
-                    )
-                    if latest_speech_state:
-                        speech_state = latest_speech_state
-
-                    if (voice_found or not_voice) and (voice_found + not_voice) % 100 == 0:
-                        print(uid, '\t', str(int((voice_found / (voice_found + not_voice)) * 100)) + '% \thas voice.')
-
-                    if speech_state == SpeechState.no_speech:
-                        not_voice += 1
-                        # audio_buffer = bytearray()
-                        # continue
-                    else:
-                        # audio_file.write(audio_buffer.hex() + "\n")
-                        voice_found += 1
-
+                recv_time = time.time()
+                if codec == 'opus':
+                    # decoded_opus = decoder.decode(data, frame_size=320)
+                    # samples = torch.frombuffer(decoded_opus, dtype=torch.int16).float() / 32768.0
+                    pass
+                elif codec in ['pcm8', 'pcm16']:  # Both are 16-bit
+                    writable_data = bytearray(data)
+                    samples = torch.frombuffer(writable_data, dtype=torch.int16).float() / 32768.0
+                else:
+                    raise ValueError(f"Unsupported codec: {codec}")
+                # FIXME: opuslib is not working, so opus audio is not decoded or VAD-gated yet
+                if codec != 'opus':
+                    audio_buffer.extend(samples)
+                    if len(audio_buffer) >= window_size_samples:
+                        tensor_audio = torch.tensor(list(audio_buffer))
+                        # A wider window gives the VAD more context, at the cost of a slower server
+                        if is_speech_present(tensor_audio[-window_size_samples * 4:], vad_iterator, window_size_samples):
+                            if not is_speech_active:
+                                for audio in prespeech_audio:
+                                    databuffer.extend(audio.int().numpy().tobytes())
+                                prespeech_audio.clear()
+                                print('+Detected speech')
+                            is_speech_active = True
+                            last_speech_time = time.time()
+                        elif is_speech_active:
+                            if recv_time - last_speech_time > speech_timeout:
+                                is_speech_active = False
+                                # Reset the VAD state only after the speech timeout:
+                                # keep its context for the current utterance, then start fresh for the next one
+                                vad_iterator.reset_states()
+                                prespeech_audio.extend(samples)
+                                print('-No speech detected')
+                                continue
+                        else:
+                            prespeech_audio.extend(samples)
+                            continue
+
                 elapsed_seconds = time.time() - timer_start
                 if elapsed_seconds > duration or not socket2:
-                    socket1.send(audio_buffer)
+                    databuffer.extend(data)
+                    if len(databuffer) >= chunk_size or codec == 'opus':
+                        # Pace sending in near real time; a naive fixed sleep is not accurate enough
+                        current_time = time.time()
+                        elapsed_time = current_time - timer_start
+                        if elapsed_time < audio_cursor + REALTIME_RESOLUTION:
+                            sleep_time = (audio_cursor + REALTIME_RESOLUTION) - elapsed_time
+                            await asyncio.sleep(sleep_time)
+                        # Send the whole buffered chunk at once
+                        socket1.send(databuffer)
+                        databuffer = bytearray(b"")
+                        audio_cursor += REALTIME_RESOLUTION
                     if socket2:
                         print('Killing socket2')
                         socket2.finish()
@@ -118,8 +155,6 @@ async def receive_audio(socket1, socket2):
                 else:
                     socket2.send(audio_buffer)

-                audio_buffer = bytearray()
-
         except WebSocketDisconnect:
             print("WebSocket disconnected")
         except Exception as e:
@@ -135,7 +170,6 @@ async def send_heartbeat():
         try:
             while websocket_active:
                 await asyncio.sleep(30)
-                # print('send_heartbeat')
                 if websocket.client_state == WebSocketState.CONNECTED:
                     await websocket.send_json({"type": "ping"})
                 else:
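Note (not part of the diff): the new `receive_audio` interleaves VAD gating, a pre-speech buffer, and real-time pacing in one loop. The sketch below pulls just the gating idea into a small self-contained class so the flow is easier to follow and test in isolation; `SpeechGate`, `detector`, and `feed` are hypothetical names, and the byte-oriented buffer is a simplification of the tensor-based one above.

```python
import time
from collections import deque


class SpeechGate:
    """Minimal sketch (hypothetical helper) of the VAD-gated forwarding used in receive_audio."""

    def __init__(self, detector, byte_rate, speech_timeout=2.0, prespeech_seconds=0.5):
        self.detector = detector              # callable(bytes) -> bool, e.g. a VAD check on the chunk
        self.speech_timeout = speech_timeout  # seconds of silence tolerated before closing the gate
        self.prespeech = deque(maxlen=int(byte_rate * prespeech_seconds))
        self.is_speech_active = False
        self.last_speech_time = 0.0

    def feed(self, chunk: bytes) -> bytes:
        """Return the bytes to forward to the transcriber; empty bytes while the gate is closed."""
        now = time.time()
        if self.detector(chunk):
            prefix = b""
            if not self.is_speech_active:
                # Flush the rolling pre-speech buffer so the first word is not clipped.
                prefix = bytes(self.prespeech)
                self.prespeech.clear()
            self.is_speech_active = True
            self.last_speech_time = now
            return prefix + chunk
        if self.is_speech_active and now - self.last_speech_time <= self.speech_timeout:
            # Short pause inside an utterance: keep forwarding until the timeout expires.
            return chunk
        # Silence: close the gate and keep buffering audio for the next utterance.
        self.is_speech_active = False
        self.prespeech.extend(chunk)
        return b""
```

This mirrors the design choice in the diff: silence does not immediately stop forwarding (the timeout carries short pauses inside an utterance), and the pre-speech buffer is what gets prepended to `databuffer` the moment speech starts.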
diff --git a/backend/utils/stt/streaming.py b/backend/utils/stt/streaming.py
index b0366b3f7..ca9661683 100644
--- a/backend/utils/stt/streaming.py
+++ b/backend/utils/stt/streaming.py
@@ -10,6 +10,7 @@
 import database.notifications as notification_db
 from utils.plugins import trigger_realtime_integrations

+import numpy as np

 headers = {
     "Authorization": f"Token {os.getenv('DEEPGRAM_API_KEY')}",
@@ -87,6 +88,7 @@ async def process_audio_dg(
     def on_message(self, result, **kwargs):
         # print(f"Received message from Deepgram")  # Log when message is received
         sentence = result.channel.alternatives[0].transcript
+        # print(sentence)
         if len(sentence) == 0:
             return
         # print(sentence)
diff --git a/backend/utils/stt/vad.py b/backend/utils/stt/vad.py
index 14961a5d7..86ff7529d 100644
--- a/backend/utils/stt/vad.py
+++ b/backend/utils/stt/vad.py
@@ -12,13 +12,7 @@
 (get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks) = utils


-class SpeechState(str, Enum):
-    has_speech = 'has_speech'
-    no_speech = 'no_speech'
-
-
-def get_speech_state(data, vad_iterator, window_size_samples=256):
-    has_start, has_end = False, False
+def is_speech_present(data, vad_iterator, window_size_samples=256):
     for i in range(0, len(data), window_size_samples):
         chunk = data[i: i + window_size_samples]
         if len(chunk) < window_size_samples:
@@ -29,17 +23,8 @@ def get_speech_state(data, vad_iterator, window_size_samples=256):
         if speech_dict:
             # print(speech_dict)
-            if 'start' in speech_dict:
-                has_start = True
-            elif 'end' in speech_dict:
-                has_end = True
-        # print('----')
-    if has_start:
-        return SpeechState.has_speech
-    elif has_end:
-        return SpeechState.no_speech
-    return None
-
+            return True
+    return False

 @timeit
 def is_audio_empty(file_path, sample_rate=8000):
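Note (not part of the diff): a hedged usage sketch for the reworked helper, mirroring how `transcribe.py` now imports it. The 16 kHz sample rate and the synthetic tone are illustrative assumptions; real speech audio is needed for a positive result, and it assumes the script runs from the `backend/` package root so the import resolves.

```python
import torch

# Same imports transcribe.py now uses; vad.py loads the Silero VAD model at import time.
from utils.stt.vad import VADIterator, model, is_speech_present

sample_rate = 16000
window_size_samples = 256 if sample_rate == 8000 else 512
vad_iterator = VADIterator(model, sampling_rate=sample_rate, threshold=0.7)

# One second of a 440 Hz tone as a stand-in signal; substitute real 16 kHz speech samples.
t = torch.linspace(0, 1, sample_rate)
audio = 0.5 * torch.sin(2 * torch.pi * 440 * t)

# is_speech_present walks the tensor in window_size_samples steps and
# returns True as soon as the VADIterator reports a speech event.
print('speech present:', is_speech_present(audio, vad_iterator, window_size_samples))
```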