fix ws VAD for codec Opus, pcm8, pcm16 #565

Merged · 14 commits · Aug 19, 2024
110 changes: 72 additions & 38 deletions backend/routers/transcribe.py
@@ -4,10 +4,13 @@
from fastapi import APIRouter
from fastapi.websockets import (WebSocketDisconnect, WebSocket)
from starlette.websockets import WebSocketState
import torch
from collections import deque
# import opuslib

from database.redis_db import get_user_speech_profile, get_user_speech_profile_duration
from utils.stt.streaming import process_audio_dg, send_initial_file
from utils.stt.vad import VADIterator, model, get_speech_state, SpeechState
from utils.stt.vad import VADIterator, model, is_speech_present

router = APIRouter()

@@ -47,6 +50,9 @@ async def _websocket_util(
transcript_socket2 = None
websocket_active = True
duration = 0
is_speech_active = False
speech_timeout = 2.0  # Good for now; ideally adjust dynamically based on user behaviour (e.g. increase with active time up to a threshold), but not needed yet
last_speech_time = 0
try:
if language == 'en' and codec == 'opus' and include_speech_profile:
speech_profile = get_user_speech_profile(uid)
@@ -68,58 +74,87 @@
await websocket.close()
return

vad_iterator = VADIterator(model, sampling_rate=sample_rate) # threshold=0.9
threshold = 0.7
vad_iterator = VADIterator(model, sampling_rate=sample_rate, threshold=threshold)
window_size_samples = 256 if sample_rate == 8000 else 512
# if codec == 'opus':
# decoder = opuslib.Decoder(sample_rate, channels)

async def receive_audio(socket1, socket2):
nonlocal websocket_active
audio_buffer = bytearray()
nonlocal is_speech_active, last_speech_time, websocket_active
# nonlocal decoder

REALTIME_RESOLUTION = 0.01
Contributor

Why use 10ms chunks here?

20ms seems like a more standard size

Contributor Author

In my testing, I haven't found a practical difference between them. I chose 10ms because more responsiveness would be good. Deepgram still buffers them until a good transcription is detected, right? Although 20ms would mean sending through the DG socket half as often, I don't think the extra sends increase the cost to DG.

Contributor

There is no increase or decrease in cost from how often audio chunks are sent, but 10ms is very low and not recommended.

20ms is the recommended minimum.

The server also has very high CPU usage with 10ms when multiple streams are running.

The receiver thread will also be blocked by sender threads, e.g. if you have 100 connections all doing work 10ms apart.

Contributor Author

Nice insight. Generally we should do a stress test, but going with the standard as a baseline is never wrong. I'll make the required changes soon. Thanks!
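
For concreteness, a small worked example of the chunk sizes being discussed, following the byte_rate and chunk_size computation in this diff (16-bit mono PCM figures are assumptions, not something the PR pins down):

sample_width, channels = 2, 1                    # 16-bit mono PCM, as in this diff
for sample_rate in (8000, 16000):
    byte_rate = sample_width * sample_rate * channels
    for resolution in (0.01, 0.02):              # 10 ms vs 20 ms
        chunk_size = int(byte_rate * resolution)
        sends_per_sec = int(1 / resolution)
        print(f'{sample_rate} Hz: {int(resolution * 1000)} ms -> {chunk_size} bytes, {sends_per_sec} sends/s')
# 8000 Hz:  10 ms -> 160 bytes (100 sends/s), 20 ms -> 320 bytes (50 sends/s)
# 16000 Hz: 10 ms -> 320 bytes (100 sends/s), 20 ms -> 640 bytes (50 sends/s)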

sample_width = 2 # both pcm8 and pcm16 streams here carry 16-bit samples
byte_rate = sample_width * sample_rate * channels
chunk_size = int(byte_rate * REALTIME_RESOLUTION)
audio_buffer = deque(maxlen=byte_rate * 1) # 1 secs
databuffer = bytearray(b"")
prespeech_audio = deque(maxlen=int(byte_rate * 0.5)) # Queue of audio that will be included in the data sent to DG once is_speech_active becomes True

timer_start = time.time()
speech_state = SpeechState.no_speech
voice_found, not_voice = 0, 0
# path = 'scripts/vad/audio_bytes.txt'
# if os.path.exists(path):
# os.remove(path)
# audio_file = open(path, "a")
audio_cursor = 0 # For sleep realtime logic
try:
while websocket_active:
data = await websocket.receive_bytes()
audio_buffer.extend(data)

if codec == 'pcm8':
frame_size, frames_count = 160, 16
if len(audio_buffer) < (frame_size * frames_count):
continue

latest_speech_state = get_speech_state(
audio_buffer[:window_size_samples * 10], vad_iterator, window_size_samples
)
if latest_speech_state:
speech_state = latest_speech_state

if (voice_found or not_voice) and (voice_found + not_voice) % 100 == 0:
print(uid, '\t', str(int((voice_found / (voice_found + not_voice)) * 100)) + '% \thas voice.')

if speech_state == SpeechState.no_speech:
not_voice += 1
# audio_buffer = bytearray()
# continue
else:
# audio_file.write(audio_buffer.hex() + "\n")
voice_found += 1

recv_time = time.time()
if codec == 'opus':
# decoded_opus = decoder.decode(data, frame_size=320)
# samples = torch.frombuffer(decoded_opus, dtype=torch.int16).float() / 32768.0
pass
elif codec in ['pcm8', 'pcm16']: # Both are 16 bit
writable_data = bytearray(data)
samples = torch.frombuffer(writable_data, dtype=torch.int16).float() / 32768.0
else:
raise ValueError(f"Unsupported codec: {codec}")
# FIXME: opuslib is not working, so we are not using it
if codec != 'opus':
audio_buffer.extend(samples)
if len(audio_buffer) >= window_size_samples:
tensor_audio = torch.tensor(list(audio_buffer))
# Works alright; increasing the window size gives wider context, but the server will be slower
if is_speech_present(tensor_audio[-window_size_samples * 4:], vad_iterator, window_size_samples):
if not is_speech_active:
for audio in prespeech_audio:
databuffer.extend(audio.int().numpy().tobytes())
prespeech_audio.clear()
print('+Detected speech')
is_speech_active = True
last_speech_time = time.time()
elif is_speech_active:
if recv_time - last_speech_time > speech_timeout:
is_speech_active = False
# Reset only happens after the speech timeout
# Reason: better to carry VAD context across a single speech segment, then reset before any new speech
vad_iterator.reset_states()
prespeech_audio.extend(samples)
print('-NO Detected speech')
continue
else:
prespeech_audio.extend(samples)
continue

elapsed_seconds = time.time() - timer_start
if elapsed_seconds > duration or not socket2:
socket1.send(audio_buffer)
databuffer.extend(data)
if len(databuffer) >= chunk_size or codec == 'opus':
# Sleep logic, because naive sleep is not accurate
current_time = time.time()
elapsed_time = current_time - timer_start
if elapsed_time < audio_cursor + REALTIME_RESOLUTION:
sleep_time = (audio_cursor + REALTIME_RESOLUTION) - elapsed_time
await asyncio.sleep(sleep_time)
# Just send them all, no difference
socket1.send(databuffer)
databuffer = bytearray(b"")
audio_cursor += REALTIME_RESOLUTION
if socket2:
print('Killing socket2')
socket2.finish()
socket2 = None
else:
socket2.send(audio_buffer)

audio_buffer = bytearray()

except WebSocketDisconnect:
print("WebSocket disconnected")
except Exception as e:
@@ -135,7 +170,6 @@ async def send_heartbeat():
try:
while websocket_active:
await asyncio.sleep(30)
# print('send_heartbeat')
if websocket.client_state == WebSocketState.CONNECTED:
await websocket.send_json({"type": "ping"})
else:
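Taken together, the new receive_audio loop gates what gets forwarded to Deepgram on the Silero VAD, keeps roughly half a second of pre-speech audio so the first word is not clipped, and only resets VAD state after the 2s speech timeout. A condensed sketch of that pattern, with the realtime pacing toward Deepgram omitted; the VADGate/feed names are illustrative, and 16 kHz 16-bit mono PCM is assumed:

import time
from collections import deque

import torch

from utils.stt.vad import VADIterator, model, is_speech_present


class VADGate:
    """Gate audio on VAD activity, with a short pre-speech buffer (sketch of receive_audio)."""

    def __init__(self, sample_rate=16000, channels=1, speech_timeout=2.0, threshold=0.7):
        self.byte_rate = 2 * sample_rate * channels                     # 16-bit samples assumed
        self.window = 256 if sample_rate == 8000 else 512
        self.vad_iterator = VADIterator(model, sampling_rate=sample_rate, threshold=threshold)
        self.audio_buffer = deque(maxlen=self.byte_rate)                # rolling window for the VAD (sizing mirrors the diff)
        self.prespeech_audio = deque(maxlen=int(self.byte_rate * 0.5))  # lead-in kept while the gate is closed
        self.databuffer = bytearray()
        self.speech_timeout = speech_timeout
        self.is_speech_active = False
        self.last_speech_time = 0.0

    def feed(self, data: bytes):
        """Return bytes to forward to Deepgram, or None while the VAD gate is closed."""
        samples = torch.frombuffer(bytearray(data), dtype=torch.int16).float() / 32768.0
        self.audio_buffer.extend(samples)
        tensor_audio = torch.tensor(list(self.audio_buffer))
        if is_speech_present(tensor_audio[-self.window * 4:], self.vad_iterator, self.window):
            if not self.is_speech_active:
                for s in self.prespeech_audio:                          # flush the pre-speech lead-in first
                    self.databuffer.extend(int(s * 32768.0).to_bytes(2, 'little', signed=True))
                self.prespeech_audio.clear()
            self.is_speech_active = True
            self.last_speech_time = time.time()
        elif self.is_speech_active:
            if time.time() - self.last_speech_time > self.speech_timeout:
                self.is_speech_active = False
                self.vad_iterator.reset_states()                        # reset only once the utterance has ended
                self.prespeech_audio.extend(samples)
                return None
            # brief pause inside an utterance: keep forwarding
        else:
            self.prespeech_audio.extend(samples)
            return None
        self.databuffer.extend(data)
        out, self.databuffer = bytes(self.databuffer), bytearray()
        return out
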
2 changes: 2 additions & 0 deletions backend/utils/stt/streaming.py
@@ -10,6 +10,7 @@

import database.notifications as notification_db
from utils.plugins import trigger_realtime_integrations
import numpy as np

headers = {
"Authorization": f"Token {os.getenv('DEEPGRAM_API_KEY')}",
@@ -87,6 +88,7 @@ async def process_audio_dg(
def on_message(self, result, **kwargs):
# print(f"Received message from Deepgram") # Log when message is received
sentence = result.channel.alternatives[0].transcript
# print(sentence)
if len(sentence) == 0:
return
# print(sentence)
21 changes: 3 additions & 18 deletions backend/utils/stt/vad.py
@@ -12,13 +12,7 @@
(get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks) = utils


class SpeechState(str, Enum):
has_speech = 'has_speech'
no_speech = 'no_speech'


def get_speech_state(data, vad_iterator, window_size_samples=256):
has_start, has_end = False, False
def is_speech_present(data, vad_iterator, window_size_samples=256):
for i in range(0, len(data), window_size_samples):
chunk = data[i: i + window_size_samples]
if len(chunk) < window_size_samples:
@@ -29,17 +23,8 @@ def get_speech_state(data, vad_iterator, window_size_samples=256):

if speech_dict:
# print(speech_dict)
if 'start' in speech_dict:
has_start = True
elif 'end' in speech_dict:
has_end = True
# print('----')
if has_start:
return SpeechState.has_speech
elif has_end:
return SpeechState.no_speech
return None

return True
return False

@timeit
def is_audio_empty(file_path, sample_rate=8000):
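The new is_speech_present helper simply walks the buffer in VAD-sized windows and returns True as soon as the Silero VADIterator reports anything. A minimal usage sketch, assuming the names exposed by backend/utils/stt/vad.py; the one second of silence is just a stand-in for real normalized PCM:

import torch
from utils.stt.vad import VADIterator, model, is_speech_present

sample_rate = 16000                                   # pcm8 streams would use 8000
window_size_samples = 256 if sample_rate == 8000 else 512
vad_iterator = VADIterator(model, sampling_rate=sample_rate, threshold=0.7)

# Stand-in input: one second of silence, normalized to [-1, 1] like the websocket samples
samples = torch.zeros(sample_rate)
print(is_speech_present(samples, vad_iterator, window_size_samples))  # expected: False
vad_iterator.reset_states()                           # reset between independent streams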