fix ws VAD for codec Opus, pcm8, pcm16 (#565)
**Related issues:**
- Partially resolves #518.

**Overview:**
This PR improves Voice Activity Detection (VAD) for the WebSocket stream,
specifically targeting the Opus, PCM8, and PCM16 audio formats. It fixes
false negatives and false positives for those codecs across many languages.

**What has changed:**
- VAD Integration: The 'silero-vad' library has been reintegrated.
- Buffer Management: Uses a deque to manage an audio buffer that holds up
to 1 second of audio. This buffer is crucial for ensuring that VAD has
enough samples to make accurate decisions.
- Audio Processing: Handling is differentiated by codec, and VAD is applied
to identify active speech (see the sketch below).
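
A minimal sketch of the flow described above (illustrative only, assuming silero-vad is loaded via `torch.hub` the same way `backend/utils/stt/vad.py` does; the synthetic frame stands in for a websocket packet):

```python
# Sketch of the deque buffering + silero-vad check; the real code lives in
# backend/routers/transcribe.py and backend/utils/stt/vad.py.
from collections import deque

import torch

model, utils = torch.hub.load('snakers4/silero-vad', 'silero_vad')
(get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks) = utils

sample_rate = 16000
window_size_samples = 256 if sample_rate == 8000 else 512  # silero-vad window sizes
vad_iterator = VADIterator(model, sampling_rate=sample_rate, threshold=0.7)
audio_buffer = deque(maxlen=sample_rate)  # roughly 1 second of float samples


def is_speech_present(data, vad_iterator, window_size_samples=256):
    # Same shape as the PR's helper: scan fixed-size windows, stop on the
    # first window that triggers a speech event.
    for i in range(0, len(data), window_size_samples):
        chunk = data[i: i + window_size_samples]
        if len(chunk) < window_size_samples:
            break
        if vad_iterator(chunk):
            return True
    return False


# Feed one synthetic 16-bit PCM frame the way the websocket handler would.
frame = torch.zeros(960, dtype=torch.int16).numpy().tobytes()  # 60 ms of silence
samples = torch.frombuffer(bytearray(frame), dtype=torch.int16).float() / 32768.0
audio_buffer.extend(samples.tolist())
tensor_audio = torch.tensor(list(audio_buffer))
print(is_speech_present(tensor_audio[-window_size_samples * 4:],
                        vad_iterator, window_size_samples))  # -> False (silence)
```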

**Testing:**
- Manual testing was conducted across a variety of codecs and languages.
Language coverage is still shallow, though: only English, Chinese, and
Indonesian so far.
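
For reference, a quick offline check along those lines can be scripted with silero-vad's bundled utils (hypothetical sketch; the wav file name is an assumed local sample, not part of the repo):

```python
# Hypothetical sanity check: run silero-vad over a decoded 16 kHz recording
# and print the detected speech spans.
import torch

model, utils = torch.hub.load('snakers4/silero-vad', 'silero_vad')
(get_speech_timestamps, _, read_audio, _, _) = utils

wav = read_audio('sample_id_16k.wav', sampling_rate=16000)  # assumed test file
spans = get_speech_timestamps(wav, model, threshold=0.7, sampling_rate=16000)
print(spans)  # sample offsets, e.g. [{'start': 13312, 'end': 43520}]
```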
mdmohsin7 committed Aug 19, 2024
2 parents 7ca92b0 + 4aa433f commit bdd6c71
Showing 3 changed files with 77 additions and 56 deletions.
110 changes: 72 additions & 38 deletions backend/routers/transcribe.py
@@ -4,10 +4,13 @@
from fastapi import APIRouter
from fastapi.websockets import (WebSocketDisconnect, WebSocket)
from starlette.websockets import WebSocketState
import torch
from collections import deque
# import opuslib

from database.redis_db import get_user_speech_profile, get_user_speech_profile_duration
from utils.stt.streaming import process_audio_dg, send_initial_file
from utils.stt.vad import VADIterator, model, get_speech_state, SpeechState
from utils.stt.vad import VADIterator, model, is_speech_present

router = APIRouter()

@@ -47,6 +50,9 @@ async def _websocket_util(
transcript_socket2 = None
websocket_active = True
duration = 0
is_speech_active = False
speech_timeout = 2.0  # Good for now; could later be adjusted dynamically per user (e.g. grow with continued speech up to a cap), but not needed yet
last_speech_time = 0
try:
if language == 'en' and codec == 'opus' and include_speech_profile:
speech_profile = get_user_speech_profile(uid)
@@ -68,58 +74,87 @@
await websocket.close()
return

vad_iterator = VADIterator(model, sampling_rate=sample_rate) # threshold=0.9
threshold = 0.7
vad_iterator = VADIterator(model, sampling_rate=sample_rate, threshold=threshold)
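# window sizes silero-vad expects: 256 samples at 8 kHz, 512 at 16 kHz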
window_size_samples = 256 if sample_rate == 8000 else 512
# if codec == 'opus':
# decoder = opuslib.Decoder(sample_rate, channels)

async def receive_audio(socket1, socket2):
nonlocal websocket_active
audio_buffer = bytearray()
nonlocal is_speech_active, last_speech_time, websocket_active
# nonlocal decoder

REALTIME_RESOLUTION = 0.01
sample_width = 2  # pcm8 and pcm16 both carry 16-bit samples here
byte_rate = sample_width * sample_rate * channels
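# chunk_size below is one REALTIME_RESOLUTION (10 ms) slice of the byte stream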
chunk_size = int(byte_rate * REALTIME_RESOLUTION)
audio_buffer = deque(maxlen=byte_rate * 1)  # ~1 second of audio
databuffer = bytearray(b"")
prespeech_audio = deque(maxlen=int(byte_rate * 0.5))  # audio to prepend to the data sent to Deepgram once is_speech_active becomes True

timer_start = time.time()
speech_state = SpeechState.no_speech
voice_found, not_voice = 0, 0
# path = 'scripts/vad/audio_bytes.txt'
# if os.path.exists(path):
# os.remove(path)
# audio_file = open(path, "a")
audio_cursor = 0  # cursor for the real-time pacing logic below
try:
while websocket_active:
data = await websocket.receive_bytes()
audio_buffer.extend(data)

if codec == 'pcm8':
frame_size, frames_count = 160, 16
if len(audio_buffer) < (frame_size * frames_count):
continue

latest_speech_state = get_speech_state(
audio_buffer[:window_size_samples * 10], vad_iterator, window_size_samples
)
if latest_speech_state:
speech_state = latest_speech_state

if (voice_found or not_voice) and (voice_found + not_voice) % 100 == 0:
print(uid, '\t', str(int((voice_found / (voice_found + not_voice)) * 100)) + '% \thas voice.')

if speech_state == SpeechState.no_speech:
not_voice += 1
# audio_buffer = bytearray()
# continue
else:
# audio_file.write(audio_buffer.hex() + "\n")
voice_found += 1

recv_time = time.time()
if codec == 'opus':
# decoded_opus = decoder.decode(data, frame_size=320)
# samples = torch.frombuffer(decoded_opus, dtype=torch.int16).float() / 32768.0
pass
elif codec in ['pcm8', 'pcm16']: # Both are 16 bit
writable_data = bytearray(data)
samples = torch.frombuffer(writable_data, dtype=torch.int16).float() / 32768.0
else:
raise ValueError(f"Unsupported codec: {codec}")
# FIXME: opuslib is not working, so we are not using it
if codec != 'opus':
audio_buffer.extend(samples)
if len(audio_buffer) >= window_size_samples:
tensor_audio = torch.tensor(list(audio_buffer))
# Works well already; a larger window gives wider context but slows the server
if is_speech_present(tensor_audio[-window_size_samples * 4:], vad_iterator, window_size_samples):
if not is_speech_active:
for audio in prespeech_audio:
databuffer.extend(audio.int().numpy().tobytes())
prespeech_audio.clear()
print('+Detected speech')
is_speech_active = True
last_speech_time = time.time()
elif is_speech_active:
if recv_time - last_speech_time > speech_timeout:
is_speech_active = False
# Reset only after the speech timeout:
# better to carry VAD context through one speech segment and reset fresh for the next
vad_iterator.reset_states()
prespeech_audio.extend(samples)
print('-No speech detected')
continue
else:
prespeech_audio.extend(samples)
continue

elapsed_seconds = time.time() - timer_start
if elapsed_seconds > duration or not socket2:
socket1.send(audio_buffer)
databuffer.extend(data)
if len(databuffer) >= chunk_size or codec == 'opus':
# Pacing logic: a naive sleep is inaccurate, so track an audio cursor
current_time = time.time()
elapsed_time = current_time - timer_start
if elapsed_time < audio_cursor + REALTIME_RESOLUTION:
sleep_time = (audio_cursor + REALTIME_RESOLUTION) - elapsed_time
await asyncio.sleep(sleep_time)
# Send everything buffered at once; chunk boundaries make no difference downstream
socket1.send(databuffer)
databuffer = bytearray(b"")
audio_cursor += REALTIME_RESOLUTION
if socket2:
print('Killing socket2')
socket2.finish()
socket2 = None
else:
socket2.send(audio_buffer)

audio_buffer = bytearray()

except WebSocketDisconnect:
print("WebSocket disconnected")
except Exception as e:
@@ -135,7 +170,6 @@ async def send_heartbeat():
try:
while websocket_active:
await asyncio.sleep(30)
# print('send_heartbeat')
if websocket.client_state == WebSocketState.CONNECTED:
await websocket.send_json({"type": "ping"})
else:
2 changes: 2 additions & 0 deletions backend/utils/stt/streaming.py
@@ -10,6 +10,7 @@

import database.notifications as notification_db
from utils.plugins import trigger_realtime_integrations
import numpy as np

headers = {
"Authorization": f"Token {os.getenv('DEEPGRAM_API_KEY')}",
@@ -87,6 +88,7 @@ async def process_audio_dg(
def on_message(self, result, **kwargs):
# print(f"Received message from Deepgram") # Log when message is received
sentence = result.channel.alternatives[0].transcript
# print(sentence)
if len(sentence) == 0:
return
# print(sentence)
21 changes: 3 additions & 18 deletions backend/utils/stt/vad.py
@@ -12,13 +12,7 @@
(get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks) = utils


class SpeechState(str, Enum):
has_speech = 'has_speech'
no_speech = 'no_speech'


def get_speech_state(data, vad_iterator, window_size_samples=256):
has_start, has_end = False, False
def is_speech_present(data, vad_iterator, window_size_samples=256):
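# Scan the audio in fixed-size windows; report speech as soon as any window triggers.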
for i in range(0, len(data), window_size_samples):
chunk = data[i: i + window_size_samples]
if len(chunk) < window_size_samples:
@@ -29,17 +23,8 @@ def get_speech_state(data, vad_iterator, window_size_samples=256):

if speech_dict:
# print(speech_dict)
if 'start' in speech_dict:
has_start = True
elif 'end' in speech_dict:
has_end = True
# print('----')
if has_start:
return SpeechState.has_speech
elif has_end:
return SpeechState.no_speech
return None

return True
return False

@timeit
def is_audio_empty(file_path, sample_rate=8000):
