fix ws VAD for codec Opus, pcm8, pcm16 #565
Changes from 7 commits
File: websocket transcription router (path not shown in this capture)
@@ -7,10 +7,13 @@
from fastapi.websockets import (WebSocketDisconnect, WebSocket)
from pydub import AudioSegment
from starlette.websockets import WebSocketState
import torch
from collections import deque
import opuslib

from utils.redis_utils import get_user_speech_profile, get_user_speech_profile_duration
from utils.stt.deepgram_util import process_audio_dg, send_initial_file2, transcribe_file_deepgram
from utils.stt.vad import VADIterator, model, get_speech_state, SpeechState, vad_is_empty
from utils.stt.vad import VADIterator, model, vad_is_empty, is_speech_present

router = APIRouter()
@@ -50,6 +53,9 @@ async def _websocket_util(
    transcript_socket2 = None
    websocket_active = True
    duration = 0
    is_speech_active = False
    speech_timeout = 0.7  # Good for now, but better to adjust dynamically by user behaviour (e.g. increase with active time up to some threshold); not needed yet.
    last_speech_time = 0
    try:
        if language == 'en' and codec == 'opus' and include_speech_profile:
            speech_profile = get_user_speech_profile(uid)
@@ -71,58 +77,67 @@ async def _websocket_util(
        await websocket.close()
        return

    vad_iterator = VADIterator(model, sampling_rate=sample_rate)  # threshold=0.9
    threshold = 0.6
    vad_iterator = VADIterator(model, sampling_rate=sample_rate, threshold=threshold)
    window_size_samples = 256 if sample_rate == 8000 else 512
    if codec == 'opus':
        decoder = opuslib.Decoder(sample_rate, channels)

    async def receive_audio(socket1, socket2):
        nonlocal websocket_active
        audio_buffer = bytearray()
        nonlocal is_speech_active, last_speech_time, decoder, websocket_active
        audio_buffer = deque(maxlen=sample_rate * 1)  # 1 sec
        databuffer = bytearray(b"")

        REALTIME_RESOLUTION = 0.01
        sample_width = 2  # pcm8/pcm16 here are both 16-bit
        byte_rate = sample_width * sample_rate * channels
        chunk_size = int(byte_rate * REALTIME_RESOLUTION)

        timer_start = time.time()
        speech_state = SpeechState.no_speech
        voice_found, not_voice = 0, 0
        # path = 'scripts/vad/audio_bytes.txt'
        # if os.path.exists(path):
        #     os.remove(path)
        # audio_file = open(path, "a")
        try:
            while websocket_active:
                data = await websocket.receive_bytes()
                audio_buffer.extend(data)

                if codec == 'pcm8':
                    frame_size, frames_count = 160, 16
                    if len(audio_buffer) < (frame_size * frames_count):
                        continue

                latest_speech_state = get_speech_state(
                    audio_buffer[:window_size_samples * 10], vad_iterator, window_size_samples
                )
                if latest_speech_state:
                    speech_state = latest_speech_state

                if (voice_found or not_voice) and (voice_found + not_voice) % 100 == 0:
                    print(uid, '\t', str(int((voice_found / (voice_found + not_voice)) * 100)) + '% \thas voice.')

                if speech_state == SpeechState.no_speech:
                    not_voice += 1
                    # audio_buffer = bytearray()
                    # continue
                if codec == 'opus':
                    decoded_opus = decoder.decode(data, frame_size=320)
                    samples = torch.frombuffer(decoded_opus, dtype=torch.int16).float() / 32768.0
                elif codec in ['pcm8', 'pcm16']:  # both are 16-bit
                    writable_data = bytearray(data)
                    samples = torch.frombuffer(writable_data, dtype=torch.int16).float() / 32768.0
                else:
                    raise ValueError(f"Unsupported codec: {codec}")

                audio_buffer.extend(samples)
                if len(audio_buffer) >= window_size_samples:
                    tensor_audio = torch.tensor(list(audio_buffer))
                    if is_speech_present(tensor_audio[-window_size_samples:], vad_iterator, window_size_samples):
                        if not is_speech_active:
                            print('+Detected speech')
                        is_speech_active = True
                        last_speech_time = time.time()
                    elif is_speech_active:
                        if time.time() - last_speech_time > speech_timeout:
                            is_speech_active = False
                            # Clear only happens after the speech timeout
                            audio_buffer.clear()
                            print('-NO Detected speech')
                            continue
                else:
                    # audio_file.write(audio_buffer.hex() + "\n")
                    voice_found += 1

                    continue

                elapsed_seconds = time.time() - timer_start
                if elapsed_seconds > duration or not socket2:
                    socket1.send(audio_buffer)
                    databuffer.extend(data)
                    if len(databuffer) >= chunk_size:
                        socket1.send(databuffer[:len(databuffer) - len(databuffer) % chunk_size])
                        databuffer = databuffer[len(databuffer) - len(databuffer) % chunk_size:]
                    await asyncio.sleep(REALTIME_RESOLUTION)
Reviewer: The sleep here is not perfect: it can sleep for more or less than REALTIME_RESOLUTION, which can cause significant drift over time, especially with a value as low as 10 ms (100 sleeps per second). Here is an example of how you could solve/offset the issue: https://github.com/deepgram/median-streaming-latency/blob/main/latency.py#L78-L91

Author: Yeah, on a low-spec server that would surely cost us drift. I'll fix that part, and thank you for the reference code! :)

Reviewer: If each sleep is 1 ms off every 10 ms, you drift 100 ms per second, or 6 s per minute. I have seen the drift go as high as 60 s per minute, which essentially means streaming the audio very slowly, at half real-time speed.
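A minimal sketch of the drift-free pacing the linked Deepgram example uses: anchor the pace to the stream's wall-clock start instead of sleeping a fixed interval each iteration, so per-sleep inaccuracies cannot accumulate. The names here (`send_paced`, `stream_start`, `bytes_sent`) are illustrative, not from this PR, and `socket.send` is assumed to be the same synchronous send used in the diff above.

```python
import asyncio
import time


async def send_paced(socket, data: bytes, byte_rate: int,
                     stream_start: float, bytes_sent: int) -> int:
    """Send `data`, then sleep only as long as needed to stay at real-time pace."""
    socket.send(data)
    bytes_sent += len(data)
    audio_time = bytes_sent / byte_rate    # seconds of audio sent so far
    elapsed = time.time() - stream_start   # wall-clock seconds since the stream began
    if audio_time > elapsed:
        # Sleep only the remainder; drift cannot accumulate because the target
        # is recomputed from the absolute start time on every call.
        await asyncio.sleep(audio_time - elapsed)
    return bytes_sent
```

In `receive_audio` this would replace the fixed `await asyncio.sleep(REALTIME_RESOLUTION)`, with `bytes_sent` carried across loop iterations.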
                    if socket2:
                        print('Killing socket2')
                        socket2.finish()
                        socket2 = None
                else:
                    socket2.send(audio_buffer)

                audio_buffer = bytearray()

        except WebSocketDisconnect:
            print("WebSocket disconnected")
        except Exception as e:
File: utils/stt/deepgram_util.py
@@ -11,6 +11,7 @@

import database.notifications as notification_db
from utils.plugins import trigger_realtime_integrations
import numpy as np

headers = {
    "Authorization": f"Token {os.getenv('DEEPGRAM_API_KEY')}",

@@ -103,6 +104,7 @@ async def process_audio_dg(
    def on_message(self, result, **kwargs):
        # print(f"Received message from Deepgram")  # Log when message is received
        sentence = result.channel.alternatives[0].transcript
        print(sentence)
Reviewer: This will spam the logs.
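If the transcript is still useful while debugging, one option is to route it through the standard library logger at DEBUG level rather than `print`, so production logs stay quiet. A sketch, not code from this PR:

```python
import logging

logger = logging.getLogger(__name__)


def on_message(self, result, **kwargs):
    sentence = result.channel.alternatives[0].transcript
    # Emitted only when the deployment opts into DEBUG logging,
    # instead of printing on every Deepgram message.
    logger.debug("deepgram transcript: %s", sentence)
    if len(sentence) == 0:
        return
```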
        if len(sentence) == 0:
            return
        # print(sentence)
File: utils/stt/vad.py
@@ -19,13 +19,7 @@

(get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks) = utils


class SpeechState(str, Enum):
    has_speech = 'has_speech'
    no_speech = 'no_speech'


def get_speech_state(data, vad_iterator, window_size_samples=256):
    has_start, has_end = False, False
def is_speech_present(data, vad_iterator, window_size_samples=256):
    for i in range(0, len(data), window_size_samples):
        chunk = data[i: i + window_size_samples]
        if len(chunk) < window_size_samples:

@@ -35,33 +29,51 @@ def get_speech_state(data, vad_iterator, window_size_samples=256):
        # maybe like, if `end` was last, then return end? TEST THIS

        if speech_dict:
            vad_iterator.reset_states()
Reviewer: Should reset_states be called even if speech_dict is false?
            # print(speech_dict)
            if 'start' in speech_dict:
                has_start = True
            elif 'end' in speech_dict:
                has_end = True
    # print('----')
    if has_start:
        return SpeechState.has_speech
    elif has_end:
        return SpeechState.no_speech
    return None

    # for i in range(0, len(data), window_size_samples):
    #     chunk = data[i: i + window_size_samples]
    #     if len(chunk) < window_size_samples:
    #         break
    #     speech_dict = vad_iterator(chunk, return_seconds=False)
    #     if speech_dict:
    #         print(speech_dict)
    #         # how many times this triggers?
    #         if 'start' in speech_dict:
    #             return SpeechState.has_speech
    #         elif 'end' in speech_dict:
    #             return SpeechState.no_speech
    # return None
            return True
    vad_iterator.reset_states()
    return False
def voice_in_bytes(data):
    # Convert audio bytes to a numpy array
    audio_array = np.frombuffer(data, dtype=np.int16)

    # Normalize audio to range [-1, 1]
    audio_tensor = torch.from_numpy(audio_array).float() / 32768.0

    # Ensure the audio is in the correct shape (batch_size, num_channels, num_samples)
    audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0)

    # Pass the audio tensor to the VAD model
    speech_timestamps = get_speech_timestamps(audio_tensor, model)

    # Check if there's voice in the audio
    if speech_timestamps:
        print("Voice detected in the audio.")
    else:
        print("No voice detected in the audio.")


#
#
# def speech_probabilities(file_path):
#     SAMPLING_RATE = 8000
#     vad_iterator = VADIterator(model, sampling_rate=SAMPLING_RATE)
#     wav = read_audio(file_path, sampling_rate=SAMPLING_RATE)
#     speech_probs = []
#     window_size_samples = 512 if SAMPLING_RATE == 16000 else 256
#     for i in range(0, len(wav), window_size_samples):
#         chunk = wav[i: i + window_size_samples]
#         if len(chunk) < window_size_samples:
#             break
#         speech_prob = model(chunk, SAMPLING_RATE).item()
#         speech_probs.append(speech_prob)
#     vad_iterator.reset_states()  # reset model states after each audio
#     print(speech_probs[:10])  # first 10 chunks predicts
#
#
@timeit
def is_audio_empty(file_path, sample_rate=8000):
    wav = read_audio(file_path)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why use 10ms chunks here?
20ms seems like a more standard size
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On my testing, I have not found the practical difference between them. I choose 10ms because more responsive would be good. Deepgram still buffer them until a good transcription is detected, right? Although 20ms would mean 2x less times of sending through the DG socket, I don't think that would increase the cost to the DG.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is no increase or decrease in cost of sending audio chunks but 10ms is very low and not recomended
20ms is the recommended minimum.
The server also has very high CPU usage with 10ms when multiple streams are running.
The receiver thread will be blocked by sender threads too eg if you have 100 connections all doing work 10ms apart
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice insight. Generally we should do stress test but I think going for the standard for baseline is never wrong. I'll make the required changes soon. Thanks!