Skip to content

Commit

Permalink
initial VAD restoration code
Browse files Browse the repository at this point in the history
  • Loading branch information
josancamon19 committed Sep 23, 2024
1 parent 714090a commit 1032658
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 38 deletions.
6 changes: 5 additions & 1 deletion backend/models/message_event.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from pydantic import BaseModel, Field
from typing import List, Optional

from pydantic import BaseModel

from models.memory import Memory, Message


class MessageEvent(BaseModel):
event_type: str

Expand All @@ -11,6 +13,7 @@ def to_json(self):
j["type"] = self.event_type
return j


class NewMemoryCreated(MessageEvent):
processing_memory_id: Optional[str] = None
memory_id: Optional[str] = None
Expand All @@ -23,6 +26,7 @@ def to_json(self):
j["type"] = self.event_type
return j


class NewProcessingMemoryCreated(MessageEvent):
processing_memory_id: Optional[str] = None

Expand Down
58 changes: 29 additions & 29 deletions backend/routers/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from utils.memories.process_memory import process_memory
from utils.processing_memories import create_memory_by_processing_memory
from utils.stt.streaming import *
from utils.stt.vad import VADIterator, model
from utils.stt.vad import VADIterator, model, is_speech_present, SpeechState

router = APIRouter()

Expand Down Expand Up @@ -84,20 +84,15 @@ def get_model_name(value):
async def _websocket_util(
websocket: WebSocket, uid: str, language: str = 'en', sample_rate: int = 8000, codec: str = 'pcm8',
channels: int = 1, include_speech_profile: bool = True, new_memory_watch: bool = False,
stt_service: STTService = STTService.deepgram,
# stt_service: STTService = STTService.deepgram,
):
print('websocket_endpoint', uid, language, sample_rate, codec, channels, include_speech_profile, new_memory_watch,
stt_service)
print('websocket_endpoint', uid, language, sample_rate, codec, channels, include_speech_profile, new_memory_watch)

if stt_service == STTService.soniox and language not in soniox_valid_languages:
stt_service = STTService.deepgram # defaults to deepgram

if stt_service == STTService.speechmatics: # defaults to deepgram (no credits + 10 connections max limit)
if language == 'en':
stt_service = STTService.soniox
else:
stt_service = STTService.deepgram

# TODO: if language english, use soniox
# TODO: else deepgram, if speechmatics credits, prob this for both?

try:
await websocket.accept()
except RuntimeError as e:
Expand Down Expand Up @@ -241,6 +236,7 @@ async def deepgram_socket_send(data):

vad_iterator = VADIterator(model, sampling_rate=sample_rate) # threshold=0.9
window_size_samples = 256 if sample_rate == 8000 else 512
window_size_bytes = int(window_size_samples * 2 * 2.5)

decoder = opuslib.Decoder(sample_rate, channels)

Expand All @@ -251,43 +247,47 @@ async def receive_audio(dg_socket1, dg_socket2, soniox_socket, speechmatics_sock
timer_start = time.time()

# nonlocal audio_buffer
# audio_buffer = bytearray()

# speech_state = SpeechState.no_speech
# voice_found, not_voice = 0, 0
# path = 'scripts/vad/audio_bytes.txt'
# if os.path.exists(path):
# os.remove(path)
# audio_file = open(path, "a")
audio_buffer = bytearray()
speech_state = SpeechState.no_speech

try:
while websocket_active:
raw_data = await websocket.receive_bytes()
data = raw_data[:]
# audio_buffer.extend(data)

if codec == 'opus' and sample_rate == 16000:
data = decoder.decode(bytes(data), frame_size=160)

audio_buffer.extend(data)
if len(audio_buffer) < window_size_bytes:
continue

speech_state = is_speech_present(audio_buffer[:window_size_bytes], vad_iterator, window_size_samples)

# if speech_state == SpeechState.no_speech:
# audio_buffer = audio_buffer[window_size_bytes:]
# continue

if soniox_socket is not None:
await soniox_socket.send(data)
await soniox_socket.send(audio_buffer)

if speechmatics_socket1 is not None:
await speechmatics_socket1.send(data)
await speechmatics_socket1.send(audio_buffer)

if deepgram_socket is not None:
elapsed_seconds = time.time() - timer_start
if elapsed_seconds > duration or not dg_socket2:
dg_socket1.send(data)
dg_socket1.send(audio_buffer)
if dg_socket2:
print('Killing socket2')
dg_socket2.finish()
dg_socket2 = None
else:
dg_socket2.send(data)
dg_socket2.send(audio_buffer)

# stream
stream_audio(raw_data)

# audio_buffer = bytearray()
stream_audio(audio_buffer)
audio_buffer = audio_buffer[window_size_bytes:]

except WebSocketDisconnect:
print("WebSocket disconnected")
Expand Down Expand Up @@ -588,8 +588,8 @@ async def _try_flush_new_memory(time_validate: bool = True):
async def websocket_endpoint(
websocket: WebSocket, uid: str, language: str = 'en', sample_rate: int = 8000, codec: str = 'pcm8',
channels: int = 1, include_speech_profile: bool = True, new_memory_watch: bool = False,
stt_service: STTService = STTService.deepgram
# stt_service: STTService = STTService.deepgram
):
await _websocket_util(
websocket, uid, language, sample_rate, codec, channels, include_speech_profile, new_memory_watch, stt_service
websocket, uid, language, sample_rate, codec, channels, include_speech_profile, new_memory_watch, # stt_service
)
37 changes: 29 additions & 8 deletions backend/utils/stt/vad.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
from enum import Enum

import numpy as np
import requests
import torch
from fastapi import HTTPException
Expand All @@ -11,19 +13,38 @@
(get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks) = utils


class SpeechState(str, Enum):
    """Binary VAD verdict for an audio buffer; str-valued so it serializes cleanly."""
    speech_found = 'speech_found'
    no_speech = 'no_speech'


def is_speech_present(data, vad_iterator, window_size_samples=256):
    """Return whether Silero VAD detects speech anywhere in a PCM16 byte buffer.

    Args:
        data: raw little-endian int16 PCM audio bytes.
        vad_iterator: a Silero ``VADIterator`` instance; its internal state is
            reset before every return so no state leaks into the next call.
        window_size_samples: samples per VAD window (256 at 8 kHz, 512 at
            16 kHz per Silero's convention — assumed from the caller; confirm).

    Returns:
        ``SpeechState.speech_found`` as soon as any full window triggers the
        VAD, otherwise ``SpeechState.no_speech``.
    """
    # Silero expects float32 samples normalized to [-1, 1]; input is int16 PCM.
    data_int16 = np.frombuffer(data, dtype=np.int16)
    data_float32 = data_int16.astype(np.float32) / 32768.0

    # NOTE(review): one triggering window flags the whole buffer as speech; the
    # original TODO suggests 'start'/'end' events may need buffering — confirm.
    for offset in range(0, len(data_float32), window_size_samples):
        chunk = data_float32[offset: offset + window_size_samples]
        if len(chunk) < window_size_samples:
            break  # trailing partial window: the VAD needs a full frame

        speech_dict = vad_iterator(chunk, return_seconds=False)
        if speech_dict:
            vad_iterator.reset_states()
            return SpeechState.speech_found

    vad_iterator.reset_states()
    return SpeechState.no_speech


def is_audio_empty(file_path, sample_rate=8000):
Expand Down

0 comments on commit 1032658

Please sign in to comment.