fix ws VAD for codec Opus, pcm8, pcm16 #565

Merged
merged 14 commits on Aug 19, 2024
Changes from 7 commits
87 changes: 51 additions & 36 deletions backend/routers/transcribe.py
@@ -7,10 +7,13 @@
from fastapi.websockets import (WebSocketDisconnect, WebSocket)
from pydub import AudioSegment
from starlette.websockets import WebSocketState
import torch
from collections import deque
import opuslib

from utils.redis_utils import get_user_speech_profile, get_user_speech_profile_duration
from utils.stt.deepgram_util import process_audio_dg, send_initial_file2, transcribe_file_deepgram
from utils.stt.vad import VADIterator, model, get_speech_state, SpeechState, vad_is_empty
from utils.stt.vad import VADIterator, model, vad_is_empty, is_speech_present

router = APIRouter()

@@ -50,6 +53,9 @@ async def _websocket_util(
transcript_socket2 = None
websocket_active = True
duration = 0
is_speech_active = False
speech_timeout = 0.7 # Good for now; ideally adjust it dynamically based on user behaviour (e.g. increase it as active time passes, up to a threshold), but not needed yet
last_speech_time = 0
try:
if language == 'en' and codec == 'opus' and include_speech_profile:
speech_profile = get_user_speech_profile(uid)
@@ -71,58 +77,67 @@
await websocket.close()
return

vad_iterator = VADIterator(model, sampling_rate=sample_rate) # threshold=0.9
threshold = 0.6
vad_iterator = VADIterator(model, sampling_rate=sample_rate, threshold=threshold)
window_size_samples = 256 if sample_rate == 8000 else 512
if codec == 'opus':
decoder = opuslib.Decoder(sample_rate, channels)

async def receive_audio(socket1, socket2):
nonlocal websocket_active
audio_buffer = bytearray()
nonlocal is_speech_active, last_speech_time, decoder, websocket_active
audio_buffer = deque(maxlen=sample_rate * 1)  # 1 sec
databuffer = bytearray(b"")

REALTIME_RESOLUTION = 0.01
Contributor:
Why use 10ms chunks here? 20ms seems like a more standard size.

Contributor Author:
In my testing I have not found a practical difference between them. I chose 10ms because being more responsive would be good. Deepgram still buffers the audio until a good transcription is detected, right? Although 20ms would mean sending half as many messages through the DG socket, I don't think 10ms increases the cost on the DG side.

Contributor:
There is no increase or decrease in cost from the size of the audio chunks, but 10ms is very low and not recommended; 20ms is the recommended minimum. The server also has very high CPU usage with 10ms when multiple streams are running, and the receiver thread will be blocked by sender threads, e.g. if you have 100 connections all doing work 10ms apart.

Contributor Author:
Nice insight. Generally we should do a stress test, but going with the standard as a baseline is never wrong. I'll make the required changes soon. Thanks!
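For reference, a minimal sketch of the chunk-size arithmetic behind this discussion, assuming 16-bit mono PCM (the helper below is illustrative, not part of the PR):

sample_width = 2  # bytes per sample for 16-bit PCM
channels = 1

def chunk_bytes(sample_rate: int, resolution_s: float) -> int:
    # bytes of audio produced per resolution_s seconds of capture
    byte_rate = sample_width * sample_rate * channels
    return int(byte_rate * resolution_s)

# At 16 kHz: 10 ms -> 320 bytes per send, 20 ms -> 640 bytes per send,
# i.e. 20 ms halves the number of websocket messages for the same audio.
print(chunk_bytes(16000, 0.01), chunk_bytes(16000, 0.02))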

sample_width = 2 # pcm8/16 here is 16 bit
byte_rate = sample_width * sample_rate * channels
chunk_size = int(byte_rate * REALTIME_RESOLUTION)

timer_start = time.time()
speech_state = SpeechState.no_speech
voice_found, not_voice = 0, 0
# path = 'scripts/vad/audio_bytes.txt'
# if os.path.exists(path):
# os.remove(path)
# audio_file = open(path, "a")
try:
while websocket_active:
data = await websocket.receive_bytes()
audio_buffer.extend(data)

if codec == 'pcm8':
frame_size, frames_count = 160, 16
if len(audio_buffer) < (frame_size * frames_count):
continue

latest_speech_state = get_speech_state(
audio_buffer[:window_size_samples * 10], vad_iterator, window_size_samples
)
if latest_speech_state:
speech_state = latest_speech_state

if (voice_found or not_voice) and (voice_found + not_voice) % 100 == 0:
print(uid, '\t', str(int((voice_found / (voice_found + not_voice)) * 100)) + '% \thas voice.')

if speech_state == SpeechState.no_speech:
not_voice += 1
# audio_buffer = bytearray()
# continue
if codec == 'opus':
decoded_opus = decoder.decode(data, frame_size=320)
samples = torch.frombuffer(decoded_opus, dtype=torch.int16).float() / 32768.0
elif codec in ['pcm8', 'pcm16']: # Both are 16 bit
writable_data = bytearray(data)
samples = torch.frombuffer(writable_data, dtype=torch.int16).float() / 32768.0
else:
raise ValueError(f"Unsupported codec: {codec}")

audio_buffer.extend(samples)
if len(audio_buffer) >= window_size_samples:
tensor_audio = torch.tensor(list(audio_buffer))
if is_speech_present(tensor_audio[-window_size_samples:], vad_iterator, window_size_samples):
if not is_speech_active:
print('+Detected speech')
is_speech_active = True
last_speech_time = time.time()
elif is_speech_active:
if time.time() - last_speech_time > speech_timeout:
is_speech_active = False
# Clear only happens after the speech timeout
audio_buffer.clear()
print('-NO Detected speech')
continue
else:
# audio_file.write(audio_buffer.hex() + "\n")
voice_found += 1

continue

elapsed_seconds = time.time() - timer_start
if elapsed_seconds > duration or not socket2:
socket1.send(audio_buffer)
databuffer.extend(data)
if len(databuffer) >= chunk_size:
socket1.send(databuffer[:len(databuffer) - len(databuffer) % chunk_size])
databuffer = databuffer[len(databuffer) - len(databuffer) % chunk_size:]
await asyncio.sleep(REALTIME_RESOLUTION)
Contributor (@DamienDeepgram, Aug 17, 2024):
Sleep here is not perfect: it can sleep for more or less than REALTIME_RESOLUTION, which can cause significant drift over time, especially with a low value like 10ms (100 sleeps per second). Here is an example of how you could solve/offset the issue:
https://github.com/deepgram/median-streaming-latency/blob/main/latency.py#L78-L91

Contributor Author:
Yeah, on a low-spec server that would surely cost us drift. I'll fix that part, and thank you for the reference code! :)

Contributor:
If sleep is 1ms off every 10ms, you would drift 100ms per second, or 6 seconds per minute. I have seen the drift go as high as 60 seconds per minute, which essentially causes you to stream audio very slowly, at half real-time speed.
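As an illustration of the suggestion above, a minimal pacing sketch (an assumed helper, not the linked Deepgram code): sleep against absolute deadlines measured from the stream start, so per-iteration sleep jitter does not accumulate into drift.

import asyncio
import time

async def paced_send(chunks, send, resolution=0.02):
    # Deadlines are computed from the stream start rather than from "now",
    # so an oversleep on one iteration is absorbed by the next one.
    start = time.monotonic()
    for i, chunk in enumerate(chunks):
        await send(chunk)
        deadline = start + (i + 1) * resolution
        delay = deadline - time.monotonic()
        if delay > 0:
            await asyncio.sleep(delay)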

if socket2:
print('Killing socket2')
socket2.finish()
socket2 = None
else:
socket2.send(audio_buffer)

audio_buffer = bytearray()

except WebSocketDisconnect:
print("WebSocket disconnected")
except Exception as e:
2 changes: 2 additions & 0 deletions backend/utils/stt/deepgram_util.py
@@ -11,6 +11,7 @@

import database.notifications as notification_db
from utils.plugins import trigger_realtime_integrations
import numpy as np

headers = {
"Authorization": f"Token {os.getenv('DEEPGRAM_API_KEY')}",
@@ -103,6 +104,7 @@ async def process_audio_dg(
def on_message(self, result, **kwargs):
# print(f"Received message from Deepgram") # Log when message is received
sentence = result.channel.alternatives[0].transcript
print(sentence)
Contributor:
This will spam the logs.

if len(sentence) == 0:
return
# print(sentence)
74 changes: 43 additions & 31 deletions backend/utils/stt/vad.py
@@ -19,13 +19,7 @@
(get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks) = utils


class SpeechState(str, Enum):
has_speech = 'has_speech'
no_speech = 'no_speech'


def get_speech_state(data, vad_iterator, window_size_samples=256):
has_start, has_end = False, False
def is_speech_present(data, vad_iterator, window_size_samples=256):
for i in range(0, len(data), window_size_samples):
chunk = data[i: i + window_size_samples]
if len(chunk) < window_size_samples:
@@ -35,33 +29,51 @@ def get_speech_state(data, vad_iterator, window_size_samples=256):
# maybe like, if `end` was last, then return end? TEST THIS

if speech_dict:
vad_iterator.reset_states()
Contributor:
Should the reset_states be called even if speech_dict is false?
(josancamon19 marked this conversation as resolved.)

# print(speech_dict)
if 'start' in speech_dict:
has_start = True
elif 'end' in speech_dict:
has_end = True
# print('----')
if has_start:
return SpeechState.has_speech
elif has_end:
return SpeechState.no_speech
return None

# for i in range(0, len(data), window_size_samples):
# chunk = data[i: i + window_size_samples]
# if len(chunk) < window_size_samples:
# break
# speech_dict = vad_iterator(chunk, return_seconds=False)
# if speech_dict:
# print(speech_dict)
# # how many times this triggers?
# if 'start' in speech_dict:
# return SpeechState.has_speech
# elif 'end' in speech_dict:
# return SpeechState.no_speech
# return None
return True
vad_iterator.reset_states()
return False


def voice_in_bytes(data):
# Convert audio bytes to a numpy array
audio_array = np.frombuffer(data, dtype=np.int16)

# Normalize audio to range [-1, 1]
audio_tensor = torch.from_numpy(audio_array).float() / 32768.0

# Ensure the audio is in the correct shape (batch_size, num_channels, num_samples)
audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0)

# Pass the audio tensor to the VAD model
speech_timestamps = get_speech_timestamps(audio_tensor, model)

# Check if there's voice in the audio
if speech_timestamps:
print("Voice detected in the audio.")
else:
print("No voice detected in the audio.")


#
#
# def speech_probabilities(file_path):
# SAMPLING_RATE = 8000
# vad_iterator = VADIterator(model, sampling_rate=SAMPLING_RATE)
# wav = read_audio(file_path, sampling_rate=SAMPLING_RATE)
# speech_probs = []
# window_size_samples = 512 if SAMPLING_RATE == 16000 else 256
# for i in range(0, len(wav), window_size_samples):
# chunk = wav[i: i + window_size_samples]
# if len(chunk) < window_size_samples:
# break
# speech_prob = model(chunk, SAMPLING_RATE).item()
# speech_probs.append(speech_prob)
# vad_iterator.reset_states() # reset model states after each audio
# print(speech_probs[:10]) # first 10 chunks predicts
#
#
@timeit
def is_audio_empty(file_path, sample_rate=8000):
wav = read_audio(file_path)
Expand Down