From b3fd503981ec29f56207a56a538fcb72f120287e Mon Sep 17 00:00:00 2001 From: 0xzre Date: Sat, 10 Aug 2024 10:03:44 +0700 Subject: [PATCH 01/13] fix ws VAD for codec Opus, pcm8, pcm16 --- backend/main.py | 12 +++---- backend/routers/transcribe.py | 68 +++++++++++++++++++++-------------- 2 files changed, 48 insertions(+), 32 deletions(-) diff --git a/backend/main.py b/backend/main.py index 096b3c6d2..3a912ef99 100644 --- a/backend/main.py +++ b/backend/main.py @@ -10,12 +10,12 @@ from routers import chat, memories, plugins, speech_profile, transcribe, screenpipe, firmware, notifications, workflow from utils.crons.notifications import start_cron_job -if os.environ.get('SERVICE_ACCOUNT_JSON'): - service_account_info = json.loads(os.environ["SERVICE_ACCOUNT_JSON"]) - credentials = firebase_admin.credentials.Certificate(service_account_info) - firebase_admin.initialize_app(credentials) -else: - firebase_admin.initialize_app() +# if os.environ.get('SERVICE_ACCOUNT_JSON'): +# service_account_info = json.loads(os.environ["SERVICE_ACCOUNT_JSON"]) +# credentials = firebase_admin.credentials.Certificate(service_account_info) +# firebase_admin.initialize_app(credentials) +# else: +# firebase_admin.initialize_app() app = FastAPI() app.include_router(transcribe.router) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 8713cda22..062eedbc0 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -7,10 +7,13 @@ from fastapi.websockets import (WebSocketDisconnect, WebSocket) from pydub import AudioSegment from starlette.websockets import WebSocketState +import torch +import numpy as np +from collections import deque from utils.redis_utils import get_user_speech_profile, get_user_speech_profile_duration from utils.stt.deepgram_util import process_audio_dg, send_initial_file2, transcribe_file_deepgram -from utils.stt.vad import VADIterator, model, get_speech_state, SpeechState, vad_is_empty +from utils.stt.vad import VADIterator, model, get_speech_state, SpeechState, vad_is_empty, is_speech_present router = APIRouter() @@ -50,6 +53,9 @@ async def _websocket_util( transcript_socket2 = None websocket_active = True duration = 0 + is_speech_active = False + speech_timeout = 1.0 # Configurable even better if we can decade from user needs/behaviour + last_speech_time = 0 try: if language == 'en' and codec == 'opus' and include_speech_profile: speech_profile = get_user_speech_profile(uid) @@ -75,8 +81,8 @@ async def _websocket_util( window_size_samples = 256 if sample_rate == 8000 else 512 async def receive_audio(socket1, socket2): - nonlocal websocket_active - audio_buffer = bytearray() + audio_buffer = deque(maxlen=sample_rate * 3) # 3 secs + nonlocal is_speech_active, last_speech_time, websocket_active timer_start = time.time() speech_state = SpeechState.no_speech voice_found, not_voice = 0, 0 @@ -87,30 +93,40 @@ async def receive_audio(socket1, socket2): try: while websocket_active: data = await websocket.receive_bytes() - audio_buffer.extend(data) - - if codec == 'pcm8': - frame_size, frames_count = 160, 16 - if len(audio_buffer) < (frame_size * frames_count): - continue - - latest_speech_state = get_speech_state( - audio_buffer[:window_size_samples * 10], vad_iterator, window_size_samples - ) - if latest_speech_state: - speech_state = latest_speech_state - - if (voice_found or not_voice) and (voice_found + not_voice) % 100 == 0: - print(uid, '\t', str(int((voice_found / (voice_found + not_voice)) * 100)) + '% \thas voice.') - - if speech_state == 
SpeechState.no_speech: - not_voice += 1 - # audio_buffer = bytearray() - # continue + # print(len(data)) + if codec == 'opus': + audio = AudioSegment(data=data, sample_width=2, frame_rate=sample_rate, channels=channels) + samples = torch.tensor(audio.get_array_of_samples()).float() / 32768.0 + elif codec in ['pcm8', 'pcm16']: + dtype = torch.int8 if codec == 'pcm8' else torch.int16 + samples = torch.frombuffer(data, dtype=dtype).float() + samples = samples / (128.0 if codec == 'pcm8' else 32768.0) + else: + raise ValueError(f"Unsupported codec: {codec}") + + audio_buffer.extend(samples) + # print(len(audio_buffer), window_size_samples * 2) # * 2 because 16bit + # TODO: vad not working propperly. + # - PCM still has to collect samples, and while it collects them, still sends them to the socket, so it's like nothing + # - Opus always says there's no speech (but collection doesn't matter much, as it triggers like 1 per 0.2 seconds) + + # len(data) = 160, 8khz 16bit -> 2 bytes per sample, 80 samples, needs 256 samples, which is 256*2 bytes + if len(audio_buffer) >= window_size_samples * 2: + tensor_audio = torch.tensor(list(audio_buffer)) + if is_speech_present(tensor_audio, vad_iterator, window_size_samples): + print('+Detected speech') + is_speech_active = True + last_speech_time = time.time() + elif is_speech_active: + if time.time() - last_speech_time > speech_timeout: + is_speech_active = False + print('-NO Detected speech') + continue + print('+Detected speech') else: - # audio_file.write(audio_buffer.hex() + "\n") - voice_found += 1 - + print('-NO Detected speech') + continue + elapsed_seconds = time.time() - timer_start if elapsed_seconds > duration or not socket2: socket1.send(audio_buffer) From 5b609663f792ded7883a4e8bcc8c886925341ab6 Mon Sep 17 00:00:00 2001 From: 0xzre Date: Sat, 10 Aug 2024 12:42:34 +0700 Subject: [PATCH 02/13] revert main, fix buffer maxlen, add missing vad audio buffer clear --- backend/main.py | 12 ++++++------ backend/routers/transcribe.py | 8 +++----- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/backend/main.py b/backend/main.py index 3a912ef99..096b3c6d2 100644 --- a/backend/main.py +++ b/backend/main.py @@ -10,12 +10,12 @@ from routers import chat, memories, plugins, speech_profile, transcribe, screenpipe, firmware, notifications, workflow from utils.crons.notifications import start_cron_job -# if os.environ.get('SERVICE_ACCOUNT_JSON'): -# service_account_info = json.loads(os.environ["SERVICE_ACCOUNT_JSON"]) -# credentials = firebase_admin.credentials.Certificate(service_account_info) -# firebase_admin.initialize_app(credentials) -# else: -# firebase_admin.initialize_app() +if os.environ.get('SERVICE_ACCOUNT_JSON'): + service_account_info = json.loads(os.environ["SERVICE_ACCOUNT_JSON"]) + credentials = firebase_admin.credentials.Certificate(service_account_info) + firebase_admin.initialize_app(credentials) +else: + firebase_admin.initialize_app() app = FastAPI() app.include_router(transcribe.router) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 062eedbc0..955e9b22d 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -81,7 +81,7 @@ async def _websocket_util( window_size_samples = 256 if sample_rate == 8000 else 512 async def receive_audio(socket1, socket2): - audio_buffer = deque(maxlen=sample_rate * 3) # 3 secs + audio_buffer = deque(maxlen=sample_rate * 1) # 1 secs nonlocal is_speech_active, last_speech_time, websocket_active timer_start = time.time() speech_state = 
SpeechState.no_speech @@ -106,10 +106,6 @@ async def receive_audio(socket1, socket2): audio_buffer.extend(samples) # print(len(audio_buffer), window_size_samples * 2) # * 2 because 16bit - # TODO: vad not working propperly. - # - PCM still has to collect samples, and while it collects them, still sends them to the socket, so it's like nothing - # - Opus always says there's no speech (but collection doesn't matter much, as it triggers like 1 per 0.2 seconds) - # len(data) = 160, 8khz 16bit -> 2 bytes per sample, 80 samples, needs 256 samples, which is 256*2 bytes if len(audio_buffer) >= window_size_samples * 2: tensor_audio = torch.tensor(list(audio_buffer)) @@ -120,6 +116,8 @@ async def receive_audio(socket1, socket2): elif is_speech_active: if time.time() - last_speech_time > speech_timeout: is_speech_active = False + # Clear only happens after the speech timeout + audio_buffer.clear() print('-NO Detected speech') continue print('+Detected speech') From 11745bcc8c017d6033aa970c14eb86d5ae0bf41f Mon Sep 17 00:00:00 2001 From: 0xzre Date: Mon, 12 Aug 2024 04:52:44 +0700 Subject: [PATCH 03/13] fix pcm8 and pcm16 vad and for the dg transcription now works, missing opus still --- backend/routers/transcribe.py | 44 ++++++++++++++------ backend/utils/stt/deepgram_util.py | 9 ++++ backend/utils/stt/vad.py | 66 +++++++++++++++++++----------- 3 files changed, 82 insertions(+), 37 deletions(-) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 955e9b22d..8a05a5124 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -8,11 +8,10 @@ from pydub import AudioSegment from starlette.websockets import WebSocketState import torch -import numpy as np from collections import deque from utils.redis_utils import get_user_speech_profile, get_user_speech_profile_duration -from utils.stt.deepgram_util import process_audio_dg, send_initial_file2, transcribe_file_deepgram +from utils.stt.deepgram_util import process_audio_dg, send_initial_file2, transcribe_file_deepgram, convert_pcm8_to_pcm16 from utils.stt.vad import VADIterator, model, get_speech_state, SpeechState, vad_is_empty, is_speech_present router = APIRouter() @@ -54,7 +53,7 @@ async def _websocket_util( websocket_active = True duration = 0 is_speech_active = False - speech_timeout = 1.0 # Configurable even better if we can decade from user needs/behaviour + speech_timeout = 0.7 # Configurable even better if we can decade from user needs/behaviour last_speech_time = 0 try: if language == 'en' and codec == 'opus' and include_speech_profile: @@ -77,12 +76,26 @@ async def _websocket_util( await websocket.close() return - vad_iterator = VADIterator(model, sampling_rate=sample_rate) # threshold=0.9 + threshold = 0.6 # Currently most fitting threshold + vad_iterator = VADIterator(model, sampling_rate=sample_rate, threshold=threshold) window_size_samples = 256 if sample_rate == 8000 else 512 async def receive_audio(socket1, socket2): - audio_buffer = deque(maxlen=sample_rate * 1) # 1 secs nonlocal is_speech_active, last_speech_time, websocket_active + audio_buffer = deque(maxlen=sample_rate * 1) # 1 secs + databuffer = bytearray(b"") + + REALTIME_RESOLUTION = 0.01 + if codec == 'opus': + sample_width = 2 + else: + sample_width = 1 + if sample_width: + byte_rate = sample_width * sample_rate * channels + chunk_size = int(byte_rate * REALTIME_RESOLUTION) + else: + chunk_size = 4096 # Arbitrary value + timer_start = time.time() speech_state = SpeechState.no_speech voice_found, not_voice = 0, 0 @@ -91,15 +104,16 
@@ async def receive_audio(socket1, socket2): # os.remove(path) # audio_file = open(path, "a") try: + sample_width = 1 if codec == "pcm8" else 2 while websocket_active: data = await websocket.receive_bytes() - # print(len(data)) if codec == 'opus': audio = AudioSegment(data=data, sample_width=2, frame_rate=sample_rate, channels=channels) samples = torch.tensor(audio.get_array_of_samples()).float() / 32768.0 elif codec in ['pcm8', 'pcm16']: dtype = torch.int8 if codec == 'pcm8' else torch.int16 - samples = torch.frombuffer(data, dtype=dtype).float() + writeable_data = bytearray(data) + samples = torch.frombuffer(writeable_data, dtype=dtype).float() samples = samples / (128.0 if codec == 'pcm8' else 32768.0) else: raise ValueError(f"Unsupported codec: {codec}") @@ -109,8 +123,8 @@ async def receive_audio(socket1, socket2): # len(data) = 160, 8khz 16bit -> 2 bytes per sample, 80 samples, needs 256 samples, which is 256*2 bytes if len(audio_buffer) >= window_size_samples * 2: tensor_audio = torch.tensor(list(audio_buffer)) - if is_speech_present(tensor_audio, vad_iterator, window_size_samples): - print('+Detected speech') + if is_speech_present(tensor_audio[len(tensor_audio) - window_size_samples * 2 :], vad_iterator, window_size_samples): + # print('+Detected speech') is_speech_active = True last_speech_time = time.time() elif is_speech_active: @@ -118,16 +132,20 @@ async def receive_audio(socket1, socket2): is_speech_active = False # Clear only happens after the speech timeout audio_buffer.clear() - print('-NO Detected speech') + # print('-NO Detected speech') continue - print('+Detected speech') else: - print('-NO Detected speech') continue elapsed_seconds = time.time() - timer_start if elapsed_seconds > duration or not socket2: - socket1.send(audio_buffer) + if codec == 'pcm8': # DG does not support pcm8 directly + data = convert_pcm8_to_pcm16(data) + databuffer.extend(data) + if len(databuffer) >= chunk_size: + socket1.send(databuffer[:len(databuffer) - len(databuffer) % chunk_size]) + databuffer = databuffer[len(databuffer) - len(databuffer) % chunk_size:] + await asyncio.sleep(REALTIME_RESOLUTION) if socket2: print('Killing socket2') socket2.finish() diff --git a/backend/utils/stt/deepgram_util.py b/backend/utils/stt/deepgram_util.py index 9bf63c6d1..95edc903d 100644 --- a/backend/utils/stt/deepgram_util.py +++ b/backend/utils/stt/deepgram_util.py @@ -11,6 +11,7 @@ import database.notifications as notification_db from utils.plugins import trigger_realtime_integrations +import numpy as np headers = { "Authorization": f"Token {os.getenv('DEEPGRAM_API_KEY')}", @@ -176,3 +177,11 @@ def connect_to_deepgram(on_message, on_error, language: str, sample_rate: int, c return dg_connection except Exception as e: raise Exception(f'Could not open socket: {e}') + +def convert_pcm8_to_pcm16(data): + """ + Convert 8-bit PCM to 16-bit PCM. Because Deepgram only supports 16-bit PCM. + """ + audio_as_np_int8 = np.frombuffer(data, dtype=np.uint8) + audio_as_np_int16 = (audio_as_np_int8.astype(np.int16) - 128) * 256 + return audio_as_np_int16.tobytes() \ No newline at end of file diff --git a/backend/utils/stt/vad.py b/backend/utils/stt/vad.py index 54dc910c0..c9b4072ac 100644 --- a/backend/utils/stt/vad.py +++ b/backend/utils/stt/vad.py @@ -35,33 +35,51 @@ def get_speech_state(data, vad_iterator, window_size_samples=256): # maybe like, if `end` was last, then return end? 
TEST THIS if speech_dict: + vad_iterator.reset_states() # print(speech_dict) - if 'start' in speech_dict: - has_start = True - elif 'end' in speech_dict: - has_end = True - # print('----') - if has_start: - return SpeechState.has_speech - elif has_end: - return SpeechState.no_speech - return None - - # for i in range(0, len(data), window_size_samples): - # chunk = data[i: i + window_size_samples] - # if len(chunk) < window_size_samples: - # break - # speech_dict = vad_iterator(chunk, return_seconds=False) - # if speech_dict: - # print(speech_dict) - # # how many times this triggers? - # if 'start' in speech_dict: - # return SpeechState.has_speech - # elif 'end' in speech_dict: - # return SpeechState.no_speech - # return None + return True + vad_iterator.reset_states() + return False +def voice_in_bytes(data): + # Convert audio bytes to a numpy array + audio_array = np.frombuffer(data, dtype=np.int16) + + # Normalize audio to range [-1, 1] + audio_tensor = torch.from_numpy(audio_array).float() / 32768.0 + + # Ensure the audio is in the correct shape (batch_size, num_channels, num_samples) + audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0) + + # Pass the audio tensor to the VAD model + speech_timestamps = get_speech_timestamps(audio_tensor, model) + + # Check if there's voice in the audio + if speech_timestamps: + print("Voice detected in the audio.") + else: + print("No voice detected in the audio.") + + +# +# +# def speech_probabilities(file_path): +# SAMPLING_RATE = 8000 +# vad_iterator = VADIterator(model, sampling_rate=SAMPLING_RATE) +# wav = read_audio(file_path, sampling_rate=SAMPLING_RATE) +# speech_probs = [] +# window_size_samples = 512 if SAMPLING_RATE == 16000 else 256 +# for i in range(0, len(wav), window_size_samples): +# chunk = wav[i: i + window_size_samples] +# if len(chunk) < window_size_samples: +# break +# speech_prob = model(chunk, SAMPLING_RATE).item() +# speech_probs.append(speech_prob) +# vad_iterator.reset_states() # reset model states after each audio +# print(speech_probs[:10]) # first 10 chunks predicts +# +# @timeit def is_audio_empty(file_path, sample_rate=8000): wav = read_audio(file_path) From 2bef541fb21125ec2c0a93d12fcb67df79fda8fd Mon Sep 17 00:00:00 2001 From: 0xzre Date: Mon, 12 Aug 2024 05:58:35 +0700 Subject: [PATCH 04/13] fix opus codec decoding --- backend/routers/transcribe.py | 17 +++++++++++------ backend/utils/stt/deepgram_util.py | 1 + 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 8a05a5124..07eff1166 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -9,6 +9,7 @@ from starlette.websockets import WebSocketState import torch from collections import deque +import opuslib from utils.redis_utils import get_user_speech_profile, get_user_speech_profile_duration from utils.stt.deepgram_util import process_audio_dg, send_initial_file2, transcribe_file_deepgram, convert_pcm8_to_pcm16 @@ -79,17 +80,19 @@ async def _websocket_util( threshold = 0.6 # Currently most fitting threshold vad_iterator = VADIterator(model, sampling_rate=sample_rate, threshold=threshold) window_size_samples = 256 if sample_rate == 8000 else 512 + if codec == 'opus': + decoder = opuslib.Decoder(sample_rate, channels) async def receive_audio(socket1, socket2): - nonlocal is_speech_active, last_speech_time, websocket_active + nonlocal is_speech_active, last_speech_time, decoder, websocket_active audio_buffer = deque(maxlen=sample_rate * 1) # 1 secs databuffer = 
bytearray(b"") REALTIME_RESOLUTION = 0.01 - if codec == 'opus': - sample_width = 2 - else: + if codec == 'pcm8': sample_width = 1 + else: + sample_width = 2 if sample_width: byte_rate = sample_width * sample_rate * channels chunk_size = int(byte_rate * REALTIME_RESOLUTION) @@ -108,8 +111,10 @@ async def receive_audio(socket1, socket2): while websocket_active: data = await websocket.receive_bytes() if codec == 'opus': - audio = AudioSegment(data=data, sample_width=2, frame_rate=sample_rate, channels=channels) - samples = torch.tensor(audio.get_array_of_samples()).float() / 32768.0 + data = decoder.decode(data, frame_size=320) # 160 if want lower latency + # audio = AudioSegment(data=data, sample_width=sample_width, frame_rate=sample_rate, channels=channels, format='opus') + # samples = torch.tensor(audio.get_array_of_samples()).float() / 32768.0 + samples = torch.frombuffer(data, dtype=torch.int16).float() / 32768.0 elif codec in ['pcm8', 'pcm16']: dtype = torch.int8 if codec == 'pcm8' else torch.int16 writeable_data = bytearray(data) diff --git a/backend/utils/stt/deepgram_util.py b/backend/utils/stt/deepgram_util.py index 95edc903d..8b228c2fc 100644 --- a/backend/utils/stt/deepgram_util.py +++ b/backend/utils/stt/deepgram_util.py @@ -104,6 +104,7 @@ async def process_audio_dg( def on_message(self, result, **kwargs): # print(f"Received message from Deepgram") # Log when message is received sentence = result.channel.alternatives[0].transcript + print(sentence) if len(sentence) == 0: return # print(sentence) From c758076619a785d850091d53fdd77c83db20173f Mon Sep 17 00:00:00 2001 From: 0xzre Date: Fri, 16 Aug 2024 09:39:52 +0700 Subject: [PATCH 05/13] fix pcm8 and change threshold --- backend/routers/transcribe.py | 42 ++++++++++-------------------- backend/utils/stt/deepgram_util.py | 8 ------ 2 files changed, 14 insertions(+), 36 deletions(-) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 07eff1166..2601133d1 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -12,7 +12,7 @@ import opuslib from utils.redis_utils import get_user_speech_profile, get_user_speech_profile_duration -from utils.stt.deepgram_util import process_audio_dg, send_initial_file2, transcribe_file_deepgram, convert_pcm8_to_pcm16 +from utils.stt.deepgram_util import process_audio_dg, send_initial_file2, transcribe_file_deepgram, convert_audio_bytes_to_resampled_bytes from utils.stt.vad import VADIterator, model, get_speech_state, SpeechState, vad_is_empty, is_speech_present router = APIRouter() @@ -54,7 +54,7 @@ async def _websocket_util( websocket_active = True duration = 0 is_speech_active = False - speech_timeout = 0.7 # Configurable even better if we can decade from user needs/behaviour + speech_timeout = 0.7 last_speech_time = 0 try: if language == 'en' and codec == 'opus' and include_speech_profile: @@ -77,7 +77,7 @@ async def _websocket_util( await websocket.close() return - threshold = 0.6 # Currently most fitting threshold + threshold = 0.7 vad_iterator = VADIterator(model, sampling_rate=sample_rate, threshold=threshold) window_size_samples = 256 if sample_rate == 8000 else 512 if codec == 'opus': @@ -89,15 +89,9 @@ async def receive_audio(socket1, socket2): databuffer = bytearray(b"") REALTIME_RESOLUTION = 0.01 - if codec == 'pcm8': - sample_width = 1 - else: - sample_width = 2 - if sample_width: - byte_rate = sample_width * sample_rate * channels - chunk_size = int(byte_rate * REALTIME_RESOLUTION) - else: - chunk_size = 4096 # Arbitrary value + 
sample_width = 2 # pcm here is 16 bit + byte_rate = sample_width * sample_rate * channels + chunk_size = int(byte_rate * REALTIME_RESOLUTION) timer_start = time.time() speech_state = SpeechState.no_speech @@ -111,25 +105,19 @@ async def receive_audio(socket1, socket2): while websocket_active: data = await websocket.receive_bytes() if codec == 'opus': - data = decoder.decode(data, frame_size=320) # 160 if want lower latency - # audio = AudioSegment(data=data, sample_width=sample_width, frame_rate=sample_rate, channels=channels, format='opus') - # samples = torch.tensor(audio.get_array_of_samples()).float() / 32768.0 + decoded_opus = decoder.decode(data, frame_size=320) + samples = torch.frombuffer(decoded_opus, dtype=torch.int16).float() / 32768.0 + elif codec in ['pcm8', 'pcm16']: # Both are now 16-bit samples = torch.frombuffer(data, dtype=torch.int16).float() / 32768.0 - elif codec in ['pcm8', 'pcm16']: - dtype = torch.int8 if codec == 'pcm8' else torch.int16 - writeable_data = bytearray(data) - samples = torch.frombuffer(writeable_data, dtype=dtype).float() - samples = samples / (128.0 if codec == 'pcm8' else 32768.0) else: raise ValueError(f"Unsupported codec: {codec}") audio_buffer.extend(samples) - # print(len(audio_buffer), window_size_samples * 2) # * 2 because 16bit - # len(data) = 160, 8khz 16bit -> 2 bytes per sample, 80 samples, needs 256 samples, which is 256*2 bytes - if len(audio_buffer) >= window_size_samples * 2: + if len(audio_buffer) >= window_size_samples: tensor_audio = torch.tensor(list(audio_buffer)) - if is_speech_present(tensor_audio[len(tensor_audio) - window_size_samples * 2 :], vad_iterator, window_size_samples): - # print('+Detected speech') + if is_speech_present(tensor_audio[-window_size_samples:], vad_iterator, window_size_samples): + if not is_speech_active: + print('+Detected speech') is_speech_active = True last_speech_time = time.time() elif is_speech_active: @@ -137,15 +125,13 @@ async def receive_audio(socket1, socket2): is_speech_active = False # Clear only happens after the speech timeout audio_buffer.clear() - # print('-NO Detected speech') + print('-NO Detected speech') continue else: continue elapsed_seconds = time.time() - timer_start if elapsed_seconds > duration or not socket2: - if codec == 'pcm8': # DG does not support pcm8 directly - data = convert_pcm8_to_pcm16(data) databuffer.extend(data) if len(databuffer) >= chunk_size: socket1.send(databuffer[:len(databuffer) - len(databuffer) % chunk_size]) diff --git a/backend/utils/stt/deepgram_util.py b/backend/utils/stt/deepgram_util.py index 8b228c2fc..2bc94510d 100644 --- a/backend/utils/stt/deepgram_util.py +++ b/backend/utils/stt/deepgram_util.py @@ -178,11 +178,3 @@ def connect_to_deepgram(on_message, on_error, language: str, sample_rate: int, c return dg_connection except Exception as e: raise Exception(f'Could not open socket: {e}') - -def convert_pcm8_to_pcm16(data): - """ - Convert 8-bit PCM to 16-bit PCM. Because Deepgram only supports 16-bit PCM. 
- """ - audio_as_np_int8 = np.frombuffer(data, dtype=np.uint8) - audio_as_np_int16 = (audio_as_np_int8.astype(np.int16) - 128) * 256 - return audio_as_np_int16.tobytes() \ No newline at end of file From 91c8d5d2534f0c021fa5b44ddb6b5c91b37b5266 Mon Sep 17 00:00:00 2001 From: 0xzre Date: Fri, 16 Aug 2024 17:54:31 +0700 Subject: [PATCH 06/13] tested --- backend/routers/transcribe.py | 6 +++--- backend/utils/stt/vad.py | 5 ----- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 2601133d1..c3ae702d9 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -12,7 +12,7 @@ import opuslib from utils.redis_utils import get_user_speech_profile, get_user_speech_profile_duration -from utils.stt.deepgram_util import process_audio_dg, send_initial_file2, transcribe_file_deepgram, convert_audio_bytes_to_resampled_bytes +from utils.stt.deepgram_util import process_audio_dg, send_initial_file2, transcribe_file_deepgram from utils.stt.vad import VADIterator, model, get_speech_state, SpeechState, vad_is_empty, is_speech_present router = APIRouter() @@ -54,7 +54,7 @@ async def _websocket_util( websocket_active = True duration = 0 is_speech_active = False - speech_timeout = 0.7 + speech_timeout = 0.7 # Good for now but better dynamically adjust it by user behaviour last_speech_time = 0 try: if language == 'en' and codec == 'opus' and include_speech_profile: @@ -77,7 +77,7 @@ async def _websocket_util( await websocket.close() return - threshold = 0.7 + threshold = 0.6 vad_iterator = VADIterator(model, sampling_rate=sample_rate, threshold=threshold) window_size_samples = 256 if sample_rate == 8000 else 512 if codec == 'opus': diff --git a/backend/utils/stt/vad.py b/backend/utils/stt/vad.py index c9b4072ac..04512d32d 100644 --- a/backend/utils/stt/vad.py +++ b/backend/utils/stt/vad.py @@ -19,11 +19,6 @@ (get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks) = utils -class SpeechState(str, Enum): - has_speech = 'has_speech' - no_speech = 'no_speech' - - def get_speech_state(data, vad_iterator, window_size_samples=256): has_start, has_end = False, False for i in range(0, len(data), window_size_samples): From 70d176783b83734894ea075cbce11bfacfecf0dd Mon Sep 17 00:00:00 2001 From: 0xzre Date: Sat, 17 Aug 2024 21:10:27 +0700 Subject: [PATCH 07/13] fix some function to match working one --- backend/routers/transcribe.py | 20 ++++++-------------- backend/utils/stt/vad.py | 3 +-- 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index c3ae702d9..af0d01121 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -13,7 +13,7 @@ from utils.redis_utils import get_user_speech_profile, get_user_speech_profile_duration from utils.stt.deepgram_util import process_audio_dg, send_initial_file2, transcribe_file_deepgram -from utils.stt.vad import VADIterator, model, get_speech_state, SpeechState, vad_is_empty, is_speech_present +from utils.stt.vad import VADIterator, model, vad_is_empty, is_speech_present router = APIRouter() @@ -54,7 +54,7 @@ async def _websocket_util( websocket_active = True duration = 0 is_speech_active = False - speech_timeout = 0.7 # Good for now but better dynamically adjust it by user behaviour + speech_timeout = 0.7 # Good for now but better dynamically adjust it by user behaviour, just idea: Increase as active time passes but until certain threshold, but not needed yet. 
last_speech_time = 0 try: if language == 'en' and codec == 'opus' and include_speech_profile: @@ -89,26 +89,20 @@ async def receive_audio(socket1, socket2): databuffer = bytearray(b"") REALTIME_RESOLUTION = 0.01 - sample_width = 2 # pcm here is 16 bit + sample_width = 2 # pcm8/16 here is 16 bit byte_rate = sample_width * sample_rate * channels chunk_size = int(byte_rate * REALTIME_RESOLUTION) timer_start = time.time() - speech_state = SpeechState.no_speech - voice_found, not_voice = 0, 0 - # path = 'scripts/vad/audio_bytes.txt' - # if os.path.exists(path): - # os.remove(path) - # audio_file = open(path, "a") try: - sample_width = 1 if codec == "pcm8" else 2 while websocket_active: data = await websocket.receive_bytes() if codec == 'opus': decoded_opus = decoder.decode(data, frame_size=320) samples = torch.frombuffer(decoded_opus, dtype=torch.int16).float() / 32768.0 - elif codec in ['pcm8', 'pcm16']: # Both are now 16-bit - samples = torch.frombuffer(data, dtype=torch.int16).float() / 32768.0 + elif codec in ['pcm8', 'pcm16']: # Both are 16 bit + writable_data = bytearray(data) + samples = torch.frombuffer(writable_data, dtype=torch.int16).float() / 32768.0 else: raise ValueError(f"Unsupported codec: {codec}") @@ -144,8 +138,6 @@ async def receive_audio(socket1, socket2): else: socket2.send(audio_buffer) - audio_buffer = bytearray() - except WebSocketDisconnect: print("WebSocket disconnected") except Exception as e: diff --git a/backend/utils/stt/vad.py b/backend/utils/stt/vad.py index 04512d32d..8fdf9e56f 100644 --- a/backend/utils/stt/vad.py +++ b/backend/utils/stt/vad.py @@ -19,8 +19,7 @@ (get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks) = utils -def get_speech_state(data, vad_iterator, window_size_samples=256): - has_start, has_end = False, False +def is_speech_present(data, vad_iterator, window_size_samples=256): for i in range(0, len(data), window_size_samples): chunk = data[i: i + window_size_samples] if len(chunk) < window_size_samples: From a6b8b811617396cc90bb9b2b7a4141379c498f74 Mon Sep 17 00:00:00 2001 From: 0xzre Date: Sun, 18 Aug 2024 15:32:16 +0700 Subject: [PATCH 08/13] fix WS transcribe: reset state, data length sent to DG, sleep logic, less agressive timeout --- backend/routers/transcribe.py | 22 ++++++++++++++++------ backend/utils/stt/vad.py | 2 -- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index af0d01121..0536a398b 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -54,7 +54,7 @@ async def _websocket_util( websocket_active = True duration = 0 is_speech_active = False - speech_timeout = 0.7 # Good for now but better dynamically adjust it by user behaviour, just idea: Increase as active time passes but until certain threshold, but not needed yet. + speech_timeout = 1.0 # Good for now (who doesnt like integer) but better dynamically adjust it by user behaviour, just idea: Increase as active time passes but until certain threshold, but not needed yet. 
last_speech_time = 0 try: if language == 'en' and codec == 'opus' and include_speech_profile: @@ -94,6 +94,7 @@ async def receive_audio(socket1, socket2): chunk_size = int(byte_rate * REALTIME_RESOLUTION) timer_start = time.time() + audio_cursor = 0 # For sleep realtime logic try: while websocket_active: data = await websocket.receive_bytes() @@ -109,6 +110,7 @@ async def receive_audio(socket1, socket2): audio_buffer.extend(samples) if len(audio_buffer) >= window_size_samples: tensor_audio = torch.tensor(list(audio_buffer)) + # Good alr, but increase the window size to get wider context but server will be slower if is_speech_present(tensor_audio[-window_size_samples:], vad_iterator, window_size_samples): if not is_speech_active: print('+Detected speech') @@ -117,8 +119,9 @@ async def receive_audio(socket1, socket2): elif is_speech_active: if time.time() - last_speech_time > speech_timeout: is_speech_active = False - # Clear only happens after the speech timeout - audio_buffer.clear() + # Reset only happens after the speech timeout + # Reason : Better to carry vad context for a speech, then reset for any new speech + vad_iterator.reset_states() print('-NO Detected speech') continue else: @@ -128,9 +131,16 @@ async def receive_audio(socket1, socket2): if elapsed_seconds > duration or not socket2: databuffer.extend(data) if len(databuffer) >= chunk_size: - socket1.send(databuffer[:len(databuffer) - len(databuffer) % chunk_size]) - databuffer = databuffer[len(databuffer) - len(databuffer) % chunk_size:] - await asyncio.sleep(REALTIME_RESOLUTION) + # Sleep logic, because naive sleep is not accurate + current_time = time.time() + elapsed_time = current_time - timer_start + if elapsed_time < audio_cursor + REALTIME_RESOLUTION: + sleep_time = (audio_cursor + REALTIME_RESOLUTION) - elapsed_time + await asyncio.sleep(sleep_time) + # Just send them all, no difference + socket1.send(databuffer) + databuffer = bytearray(b"") + audio_cursor += REALTIME_RESOLUTION if socket2: print('Killing socket2') socket2.finish() diff --git a/backend/utils/stt/vad.py b/backend/utils/stt/vad.py index 8fdf9e56f..dc45c4495 100644 --- a/backend/utils/stt/vad.py +++ b/backend/utils/stt/vad.py @@ -29,10 +29,8 @@ def is_speech_present(data, vad_iterator, window_size_samples=256): # maybe like, if `end` was last, then return end? 
TEST THIS if speech_dict: - vad_iterator.reset_states() # print(speech_dict) return True - vad_iterator.reset_states() return False From 3f40a43080228cf9cd6e194c2f614c13245c81aa Mon Sep 17 00:00:00 2001 From: 0xzre Date: Sun, 18 Aug 2024 15:46:12 +0700 Subject: [PATCH 09/13] Merge remote-tracking branch 'origin' into ws-vad-fix --- .gitignore | 3 + app/android/key.properties | 4 - app/lib/backend/http/api/memories.dart | 29 +- app/lib/backend/preferences.dart | 11 +- app/lib/env/env.dart | 4 +- app/lib/pages/capture/page.dart | 43 +- app/lib/pages/capture/widgets/widgets.dart | 4 +- app/lib/pages/home/device.dart | 5 +- app/lib/pages/home/device_settings.dart | 2 +- app/lib/pages/home/page.dart | 39 +- .../onboarding/find_device/found_devices.dart | 39 +- .../pages/onboarding/find_device/page.dart | 9 +- app/lib/utils/ble/connected.dart | 2 +- app/lib/utils/ble/scan.dart | 4 +- backend/database/memories.py | 29 +- .../redis_utils.py => database/redis_db.py} | 0 backend/database/{vector.py => vector_db.py} | 0 backend/main.py | 4 +- backend/models/__init__.py | 0 backend/models/memory.py | 27 +- backend/requirements.txt | 57 +++ backend/routers/__init__.py | 0 backend/routers/chat.py | 4 +- backend/routers/memories.py | 402 +++++------------- backend/routers/notifications.py | 4 +- backend/routers/plugins.py | 4 +- backend/routers/screenpipe.py | 2 +- backend/routers/speech_profile.py | 84 +--- backend/routers/transcribe.py | 43 +- backend/routers/workflow.py | 2 +- backend/scripts/stt/h_brainstorming.py | 2 +- backend/utils/{ => @deprecated}/preprocess.py | 0 .../utils/{stt => @deprecated}/soniox_util.py | 0 .../{ => @deprecated}/speaker_profile.py | 0 backend/utils/{stt => @deprecated}/whisper.py | 0 .../utils/{stt => @deprecated}/whisper_x.py | 0 backend/utils/__init__.py | 0 backend/utils/auth.py | 26 -- backend/utils/llm.py | 261 ++++-------- backend/utils/{ => memories}/location.py | 0 .../utils/{ => memories}/process_memory.py | 74 ++-- backend/utils/{ => other}/endpoints.py | 28 +- .../utils/{crons => other}/notifications.py | 0 backend/utils/other/storage.py | 37 ++ backend/utils/plugins.py | 2 +- backend/utils/prompt.py | 60 --- backend/utils/{ => retrieval}/rag.py | 2 +- backend/utils/storage.py | 81 ---- backend/utils/stt/{fal.py => pre_recorded.py} | 38 +- .../stt/{deepgram_util.py => streaming.py} | 104 ++--- backend/utils/stt/vad.py | 80 +--- 51 files changed, 674 insertions(+), 981 deletions(-) delete mode 100644 app/android/key.properties rename backend/{utils/redis_utils.py => database/redis_db.py} (100%) rename backend/database/{vector.py => vector_db.py} (100%) create mode 100644 backend/models/__init__.py create mode 100644 backend/routers/__init__.py rename backend/utils/{ => @deprecated}/preprocess.py (100%) rename backend/utils/{stt => @deprecated}/soniox_util.py (100%) rename backend/utils/{ => @deprecated}/speaker_profile.py (100%) rename backend/utils/{stt => @deprecated}/whisper.py (100%) rename backend/utils/{stt => @deprecated}/whisper_x.py (100%) create mode 100644 backend/utils/__init__.py delete mode 100644 backend/utils/auth.py rename backend/utils/{ => memories}/location.py (100%) rename backend/utils/{ => memories}/process_memory.py (51%) rename backend/utils/{ => other}/endpoints.py (64%) rename backend/utils/{crons => other}/notifications.py (100%) create mode 100644 backend/utils/other/storage.py delete mode 100644 backend/utils/prompt.py rename backend/utils/{ => retrieval}/rag.py (96%) delete mode 100644 backend/utils/storage.py rename 
backend/utils/stt/{fal.py => pre_recorded.py} (64%) rename backend/utils/stt/{deepgram_util.py => streaming.py} (68%) diff --git a/.gitignore b/.gitignore index d1bb71343..c650ed5ee 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,7 @@ dist/ **/android/.gradle **/android/gradlew **/android/gradlew.bat +**/android/key.properties **/android/**/gradle-wrapper.jar **/android/**/local.properties **/android/**/GeneratedPluginRegistrant.java @@ -131,6 +132,8 @@ plugins/example/tmp/ backend/pretrained_models/ /backend/scripts/data/ /backend/pretrained_models/ +*/pretrained_models/ +/backend/_speech_profiles backend/google-credentials.json backend/google-credentials-dev.json diff --git a/app/android/key.properties b/app/android/key.properties deleted file mode 100644 index 54e211178..000000000 --- a/app/android/key.properties +++ /dev/null @@ -1,4 +0,0 @@ -storePassword=x5w4ypvs -keyPassword=x5w4ypvs -keyAlias=upload -storeFile=upload-keystore.jks \ No newline at end of file diff --git a/app/lib/backend/http/api/memories.dart b/app/lib/backend/http/api/memories.dart index 2e1624ca1..d6d73d663 100644 --- a/app/lib/backend/http/api/memories.dart +++ b/app/lib/backend/http/api/memories.dart @@ -9,7 +9,9 @@ import 'package:friend_private/backend/http/shared.dart'; import 'package:friend_private/backend/http/webhooks.dart'; import 'package:friend_private/backend/schema/memory.dart'; import 'package:friend_private/env/env.dart'; +import 'package:http/http.dart' as http; import 'package:instabug_flutter/instabug_flutter.dart'; +import 'package:path/path.dart'; import 'package:tuple/tuple.dart'; Future migrateMemoriesToBackend(List memories) async { @@ -45,7 +47,7 @@ Future createMemoryServer({ 'photos': photos.map((photo) => {'base64': photo.item1, 'description': photo.item2}).toList(), 'source': transcriptSegments.isNotEmpty ? 'friend' : 'openglass', 'language': language, // maybe determine auto? - 'audio_base64_url': audioFile != null ? await wavToBase64Url(audioFile.path) : null, + // 'audio_base64_url': audioFile != null ? await wavToBase64Url(audioFile.path) : null, }), ); if (response == null) return null; @@ -67,6 +69,31 @@ Future createMemoryServer({ return null; } +Future memoryPostProcessing(File file, String memoryId) async { + var request = http.MultipartRequest( + 'POST', + Uri.parse('${Env.apiBaseUrl}v1/memories/$memoryId/post-processing'), + ); + request.files.add(await http.MultipartFile.fromPath('file', file.path, filename: basename(file.path))); + request.headers.addAll({'Authorization': await getAuthHeader()}); + + try { + var streamedResponse = await request.send(); + var response = await http.Response.fromStream(streamedResponse); + + if (response.statusCode == 200) { + debugPrint('memoryPostProcessing Response body: ${jsonDecode(response.body)}'); + return ServerMemory.fromJson(jsonDecode(response.body)); + } else { + debugPrint('Failed to memoryPostProcessing. Status code: ${response.statusCode}'); + throw Exception('Failed to memoryPostProcessing. 
Status code: ${response.statusCode}'); + } + } catch (e) { + debugPrint('An error occurred memoryPostProcessing: $e'); + throw Exception('An error occurred memoryPostProcessing: $e'); + } +} + Future> getMemories({int limit = 50, int offset = 0}) async { var response = await makeApiCall( url: '${Env.apiBaseUrl}v1/memories?limit=$limit&offset=$offset', headers: {}, method: 'GET', body: ''); diff --git a/app/lib/backend/preferences.dart b/app/lib/backend/preferences.dart index a3573b369..3893ab556 100644 --- a/app/lib/backend/preferences.dart +++ b/app/lib/backend/preferences.dart @@ -26,10 +26,15 @@ class SharedPreferencesUtil { String get uid => getString('uid') ?? ''; - // TODO: store device object rather - set deviceId(String value) => saveString('deviceId', value); + set btDeviceStruct(BTDeviceStruct value) { + saveString('btDeviceStruct', jsonEncode(value.toJson())); + } - String get deviceId => getString('deviceId') ?? ''; + BTDeviceStruct get btDeviceStruct { + final String device = getString('btDeviceStruct') ?? ''; + if (device.isEmpty) return BTDeviceStruct(id: '', name: ''); + return BTDeviceStruct.fromJson(jsonDecode(device)); + } set deviceName(String value) => saveString('deviceName', value); diff --git a/app/lib/env/env.dart b/app/lib/env/env.dart index 9c5af6328..d174da5c6 100644 --- a/app/lib/env/env.dart +++ b/app/lib/env/env.dart @@ -15,9 +15,9 @@ abstract class Env { static String? get mixpanelProjectToken => _instance.mixpanelProjectToken; - // static String? get apiBaseUrl => _instance.apiBaseUrl; + static String? get apiBaseUrl => _instance.apiBaseUrl; - static String? get apiBaseUrl => 'https://camel-lucky-reliably.ngrok-free.app/'; + // static String? get apiBaseUrl => 'https://camel-lucky-reliably.ngrok-free.app/'; static String? get growthbookApiKey => _instance.growthbookApiKey; diff --git a/app/lib/pages/capture/page.dart b/app/lib/pages/capture/page.dart index 926f5e52a..83e1e23e1 100644 --- a/app/lib/pages/capture/page.dart +++ b/app/lib/pages/capture/page.dart @@ -1,6 +1,7 @@ import 'dart:async'; import 'dart:convert'; import 'dart:io'; +import 'dart:math'; import 'package:flutter/material.dart'; import 'package:flutter/scheduler.dart'; @@ -8,6 +9,7 @@ import 'package:flutter_foreground_task/flutter_foreground_task.dart'; import 'package:friend_private/backend/database/geolocation.dart'; import 'package:friend_private/backend/database/memory.dart'; import 'package:friend_private/backend/database/transcript_segment.dart'; +import 'package:friend_private/backend/http/api/memories.dart'; import 'package:friend_private/backend/http/cloud_storage.dart'; import 'package:friend_private/backend/preferences.dart'; import 'package:friend_private/backend/schema/bt_device.dart'; @@ -35,6 +37,7 @@ import 'logic/websocket_mixin.dart'; class CapturePage extends StatefulWidget { final Function addMemory; final Function addMessage; + final Function(ServerMemory) updateMemory; final BTDeviceStruct? 
device; const CapturePage({ @@ -42,6 +45,7 @@ class CapturePage extends StatefulWidget { required this.device, required this.addMemory, required this.addMessage, + required this.updateMemory, }); @override @@ -132,16 +136,24 @@ class CapturePageState extends State streamStartedAtSecond = null; secondsMissedOnReconnect = (DateTime.now().difference(firstStreamReceivedAt!).inSeconds); } - setState(() {}); + if (mounted) { + setState(() {}); + } + }, + onConnectionFailed: (err) { + if (mounted) { + setState(() {}); + } }, - onConnectionFailed: (err) => setState(() {}), onConnectionClosed: (int? closeCode, String? closeReason) { // connection was closed, either on resetState, or by backend, or by some other reason. // setState(() {}); }, onConnectionError: (err) { // connection was okay, but then failed. - setState(() {}); + if (mounted) { + setState(() {}); + } }, onMessageReceived: (List newSegments) { if (newSegments.isEmpty) return; @@ -150,7 +162,9 @@ class CapturePageState extends State // TODO: small bug -> when memory A creates, and memory B starts, memory B will clean a lot more seconds than available, // losing from the audio the first part of the recording. All other parts are fine. FlutterForegroundTask.sendDataToTask(jsonEncode({'location': true})); - audioStorage?.removeFramesRange(fromSecond: 0, toSecond: newSegments[0].start.toInt()); + var currentSeconds = (audioStorage?.frames.length ?? 0) ~/ 100; + var removeUpToSecond = newSegments[0].start.toInt(); + audioStorage?.removeFramesRange(fromSecond: 0, toSecond: min(max(currentSeconds - 5, 0), removeUpToSecond)); firstStreamReceivedAt = DateTime.now(); } streamStartedAtSecond ??= newSegments[0].start; @@ -289,16 +303,24 @@ class CapturePageState extends State language: segments.isNotEmpty ? SharedPreferencesUtil().recordingsLanguage : null, ); SharedPreferencesUtil().addFailedMemory(memory); - ScaffoldMessenger.of(context).showSnackBar(const SnackBar( - content: Text( - 'Memory creation failed. It\' stored locally and will be retried soon.', - style: TextStyle(color: Colors.white, fontSize: 14), - ), - )); + if (mounted) { + ScaffoldMessenger.of(context).showSnackBar(const SnackBar( + content: Text( + 'Memory creation failed. It\' stored locally and will be retried soon.', + style: TextStyle(color: Colors.white, fontSize: 14), + ), + )); + } + // TODO: store anyways something temporal and retry once connected again. } if (memory != null) widget.addMemory(memory); + if (memory != null && !memory.failed && file != null && segments.isNotEmpty && !memory.discarded) { + memoryPostProcessing(file, memory.id).then((postProcessed) { + widget.updateMemory(postProcessed); + }); + } SharedPreferencesUtil().transcriptSegments = []; segments = []; @@ -403,7 +425,6 @@ class CapturePageState extends State void dispose() { WidgetsBinding.instance.removeObserver(this); record.dispose(); - _bleBytesStream?.cancel(); _memoryCreationTimer?.cancel(); _internetListener.cancel(); diff --git a/app/lib/pages/capture/widgets/widgets.dart b/app/lib/pages/capture/widgets/widgets.dart index d69709fee..30f867d8a 100644 --- a/app/lib/pages/capture/widgets/widgets.dart +++ b/app/lib/pages/capture/widgets/widgets.dart @@ -30,7 +30,7 @@ getConnectionStateWidgets( if (device == null) { return [ const DeviceAnimationWidget(sizeMultiplier: 0.7), - SharedPreferencesUtil().deviceId.isEmpty + SharedPreferencesUtil().btDeviceStruct.id == '' ? 
_getNoFriendConnectedYet(context) : const ScanningUI( string1: 'Looking for Friend wearable', @@ -335,7 +335,7 @@ connectionStatusWidgets( } getPhoneMicRecordingButton(VoidCallback recordingToggled, RecordingState state) { - if (SharedPreferencesUtil().deviceId.isNotEmpty) return const SizedBox.shrink(); + if (SharedPreferencesUtil().btDeviceStruct.id.isNotEmpty) return const SizedBox.shrink(); return Visibility( visible: true, child: Padding( diff --git a/app/lib/pages/home/device.dart b/app/lib/pages/home/device.dart index 334436750..b25e3a073 100644 --- a/app/lib/pages/home/device.dart +++ b/app/lib/pages/home/device.dart @@ -20,7 +20,6 @@ class ConnectedDevice extends StatefulWidget { State createState() => _ConnectedDeviceState(); } - class _ConnectedDeviceState extends State { @override Widget build(BuildContext context) { @@ -61,7 +60,7 @@ class _ConnectedDeviceState extends State { mainAxisAlignment: MainAxisAlignment.start, children: [ Text( - '$deviceName (${widget.device?.getShortId() ?? ''})', + '$deviceName (${widget.device?.getShortId() ?? SharedPreferencesUtil().btDeviceStruct.getShortId()})', style: const TextStyle( color: Colors.white, fontSize: 16.0, @@ -154,7 +153,7 @@ class _ConnectedDeviceState extends State { await bleDisconnectDevice(widget.device!); } Navigator.of(context).pop(); - SharedPreferencesUtil().deviceId = ''; + SharedPreferencesUtil().btDeviceStruct = BTDeviceStruct(id: '', name: ''); SharedPreferencesUtil().deviceName = ''; MixpanelManager().disconnectFriendClicked(); }, diff --git a/app/lib/pages/home/device_settings.dart b/app/lib/pages/home/device_settings.dart index 4fbbb5bd5..36002e549 100644 --- a/app/lib/pages/home/device_settings.dart +++ b/app/lib/pages/home/device_settings.dart @@ -106,7 +106,7 @@ class DeviceSettings extends StatelessWidget { child: TextButton( onPressed: () { if (device != null) bleDisconnectDevice(device!); - SharedPreferencesUtil().deviceId = ''; + SharedPreferencesUtil().btDeviceStruct = BTDeviceStruct(id: '', name: ''); SharedPreferencesUtil().deviceName = ''; Navigator.of(context).pop(); Navigator.of(context).pop(); diff --git a/app/lib/pages/home/page.dart b/app/lib/pages/home/page.dart index b555740f1..aec885474 100644 --- a/app/lib/pages/home/page.dart +++ b/app/lib/pages/home/page.dart @@ -229,7 +229,7 @@ class _HomePageWrapperState extends State with WidgetsBindingOb _setupHasSpeakerProfile(); _migrationScripts(); authenticateGCP(); - if (SharedPreferencesUtil().deviceId.isNotEmpty) { + if (SharedPreferencesUtil().btDeviceStruct.id.isNotEmpty) { scanAndConnectDevice().then(_onConnected); } @@ -291,7 +291,7 @@ class _HomePageWrapperState extends State with WidgetsBindingOb _initiateBleBatteryListener(); capturePageKey.currentState?.resetState(restartBytesProcessing: true, btDevice: connectedDevice); MixpanelManager().deviceConnected(); - SharedPreferencesUtil().deviceId = _device!.id; + SharedPreferencesUtil().btDeviceStruct = _device!; SharedPreferencesUtil().deviceName = _device!.name; setState(() {}); } @@ -416,19 +416,26 @@ class _HomePageWrapperState extends State with WidgetsBindingOb textFieldFocusNode: memoriesTextFieldFocusNode, ), CapturePage( - key: capturePageKey, - device: _device, - addMemory: (ServerMemory memory) { - var memoriesCopy = List.from(memories); - memoriesCopy.insert(0, memory); - setState(() => memories = memoriesCopy); - }, - addMessage: (ServerMessage message) { - var messagesCopy = List.from(messages); - messagesCopy.insert(0, message); - setState(() => messages = messagesCopy); - }, - 
), + key: capturePageKey, + device: _device, + addMemory: (ServerMemory memory) { + var memoriesCopy = List.from(memories); + memoriesCopy.insert(0, memory); + setState(() => memories = memoriesCopy); + }, + addMessage: (ServerMessage message) { + var messagesCopy = List.from(messages); + messagesCopy.insert(0, message); + setState(() => messages = messagesCopy); + }, + updateMemory: (ServerMemory memory) { + var memoriesCopy = List.from(memories); + var index = memoriesCopy.indexWhere((m) => m.id == memory.id); + if (index != -1) { + memoriesCopy[index] = memory; + setState(() => memories = memoriesCopy); + } + }), ChatPage( key: chatPageKey, textFieldFocusNode: chatTextFieldFocusNode, @@ -612,7 +619,7 @@ class _HomePageWrapperState extends State with WidgetsBindingOb ) : TextButton( onPressed: () async { - if (SharedPreferencesUtil().deviceId.isEmpty) { + if (SharedPreferencesUtil().btDeviceStruct.id.isEmpty) { routeToPage(context, const ConnectDevicePage()); MixpanelManager().connectFriendClicked(); } else { diff --git a/app/lib/pages/onboarding/find_device/found_devices.dart b/app/lib/pages/onboarding/find_device/found_devices.dart index d5cf61499..5f23ac4fd 100644 --- a/app/lib/pages/onboarding/find_device/found_devices.dart +++ b/app/lib/pages/onboarding/find_device/found_devices.dart @@ -6,6 +6,7 @@ import 'package:friend_private/backend/preferences.dart'; import 'package:friend_private/backend/schema/bt_device.dart'; import 'package:friend_private/utils/ble/communication.dart'; import 'package:friend_private/utils/ble/connect.dart'; +import 'package:friend_private/utils/ble/connected.dart'; import 'package:gradient_borders/gradient_borders.dart'; class FoundDevices extends StatefulWidget { @@ -30,6 +31,8 @@ class _FoundDevicesState extends State { String deviceId = ''; String? _connectingToDeviceId; + Timer? connectionStateTimer; + // TODO: improve this and find_device page. 
// TODO: include speech profile, once it's well tested, in a few days, rn current version works @@ -43,7 +46,7 @@ class _FoundDevicesState extends State { _connectingToDeviceId = null; // Reset the connecting device }); await Future.delayed(const Duration(seconds: 2)); - SharedPreferencesUtil().deviceId = btDevice.id; + SharedPreferencesUtil().btDeviceStruct = btDevice; SharedPreferencesUtil().deviceName = btDevice.name; widget.goNext(); } catch (e) { @@ -69,6 +72,40 @@ class _FoundDevicesState extends State { setBatteryPercentage(device); } + @override + void initState() { + _initiateConnectionListener(); + super.initState(); + } + + _initiateConnectionListener() async { + connectionStateTimer = Timer.periodic(const Duration(seconds: 3), (timer) async { + var connectedDevice = await getConnectedDevice(); + if (connectedDevice != null) { + if (mounted) { + connectionStateTimer?.cancel(); + var battery = await retrieveBatteryLevel(connectedDevice.id); + setState(() { + deviceName = connectedDevice.name; + deviceId = connectedDevice.id; + batteryPercentage = battery; + _isConnected = true; + _isClicked = false; + _connectingToDeviceId = null; + }); + await Future.delayed(const Duration(seconds: 2)); + widget.goNext(); + } + } + }); + } + + @override + void dispose() { + connectionStateTimer?.cancel(); + super.dispose(); + } + @override Widget build(BuildContext context) { return Column( diff --git a/app/lib/pages/onboarding/find_device/page.dart b/app/lib/pages/onboarding/find_device/page.dart index 566164ee3..bc02b46d7 100644 --- a/app/lib/pages/onboarding/find_device/page.dart +++ b/app/lib/pages/onboarding/find_device/page.dart @@ -96,9 +96,12 @@ class _FindDevicesPageState extends State { List orderedDevices = foundDevicesMap.values.toList(); if (orderedDevices.isNotEmpty) { - setState(() { - deviceList = orderedDevices; - }); + if (mounted) { + setState(() { + deviceList = orderedDevices; + }); + } + _didNotMakeItTimer.cancel(); } }); diff --git a/app/lib/utils/ble/connected.dart b/app/lib/utils/ble/connected.dart index 3c199df03..26d699d06 100644 --- a/app/lib/utils/ble/connected.dart +++ b/app/lib/utils/ble/connected.dart @@ -6,7 +6,7 @@ import 'package:friend_private/backend/preferences.dart'; import 'package:friend_private/backend/schema/bt_device.dart'; Future getConnectedDevice() async { - var deviceId = SharedPreferencesUtil().deviceId; + var deviceId = SharedPreferencesUtil().btDeviceStruct.id; for (var device in FlutterBluePlus.connectedDevices) { if (device.remoteId.str == deviceId) { return BTDeviceStruct( diff --git a/app/lib/utils/ble/scan.dart b/app/lib/utils/ble/scan.dart index ec1a19c27..daf96d13c 100644 --- a/app/lib/utils/ble/scan.dart +++ b/app/lib/utils/ble/scan.dart @@ -7,7 +7,7 @@ import 'package:friend_private/utils/ble/find.dart'; Future scanAndConnectDevice({bool autoConnect = true, bool timeout = false}) async { print('scanAndConnectDevice'); - var deviceId = SharedPreferencesUtil().deviceId; + var deviceId = SharedPreferencesUtil().btDeviceStruct.id; print('scanAndConnectDevice ${deviceId}'); for (var device in FlutterBluePlus.connectedDevices) { if (device.remoteId.str == deviceId) { @@ -27,7 +27,7 @@ Future scanAndConnectDevice({bool autoConnect = true, bool time // Technically, there should be only one if (deviceId == '') { deviceId = device.id; - SharedPreferencesUtil().deviceId = device.id; + SharedPreferencesUtil().btDeviceStruct = device; SharedPreferencesUtil().deviceName = device.name; } diff --git a/backend/database/memories.py 
b/backend/database/memories.py index 4ad581dd2..675b19639 100644 --- a/backend/database/memories.py +++ b/backend/database/memories.py @@ -4,7 +4,8 @@ from google.cloud import firestore from google.cloud.firestore_v1 import FieldFilter -from models.memory import MemoryPhoto +from models.memory import MemoryPhoto, PostProcessingStatus, PostProcessingModel +from models.transcript_segment import TranscriptSegment from ._client import db @@ -107,3 +108,29 @@ def get_memory_photos(uid: str, memory_id: str): memory_ref = user_ref.collection('memories').document(memory_id) photos_ref = memory_ref.collection('photos') return [doc.to_dict() for doc in photos_ref.stream()] + + +# POST PROCESSING + +def set_postprocessing_status( + uid: str, memory_id: str, status: PostProcessingStatus, + model: PostProcessingModel = PostProcessingModel.fal_whisperx +): + user_ref = db.collection('users').document(uid) + memory_ref = user_ref.collection('memories').document(memory_id) + memory_ref.update({'postprocessing.status': status, 'postprocessing.model': model}) + + +def store_model_segments_result(uid: str, memory_id: str, model_name: str, segments: List[TranscriptSegment]): + user_ref = db.collection('users').document(uid) + memory_ref = user_ref.collection('memories').document(memory_id) + segments_ref = memory_ref.collection(model_name) + batch = db.batch() + for i, segment in enumerate(segments): + segment_id = str(uuid.uuid4()) + segment_ref = segments_ref.document(segment_id) + batch.set(segment_ref, segment.dict()) + if i >= 400: + batch.commit() + batch = db.batch() + batch.commit() diff --git a/backend/utils/redis_utils.py b/backend/database/redis_db.py similarity index 100% rename from backend/utils/redis_utils.py rename to backend/database/redis_db.py diff --git a/backend/database/vector.py b/backend/database/vector_db.py similarity index 100% rename from backend/database/vector.py rename to backend/database/vector_db.py diff --git a/backend/main.py b/backend/main.py index 096b3c6d2..37016ebda 100644 --- a/backend/main.py +++ b/backend/main.py @@ -7,8 +7,8 @@ from fastapi_utilities import repeat_at from modal import Image, App, asgi_app, Secret -from routers import chat, memories, plugins, speech_profile, transcribe, screenpipe, firmware, notifications, workflow -from utils.crons.notifications import start_cron_job +from routers import workflow, chat, firmware, screenpipe, plugins, memories, transcribe, notifications, speech_profile +from utils.other.notifications import start_cron_job if os.environ.get('SERVICE_ACCOUNT_JSON'): service_account_info = json.loads(os.environ["SERVICE_ACCOUNT_JSON"]) diff --git a/backend/models/__init__.py b/backend/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/models/memory.py b/backend/models/memory.py index de443ddb0..d6f0fed7c 100644 --- a/backend/models/memory.py +++ b/backend/models/memory.py @@ -103,6 +103,22 @@ class MemorySource(str, Enum): workflow = 'workflow' +class PostProcessingStatus(str, Enum): + in_progress = 'in_progress' + completed = 'completed' + canceled = 'canceled' + failed = 'failed' + + +class PostProcessingModel(str, Enum): + fal_whisperx = 'fal_whisperx' + + +class MemoryPostProcessing(BaseModel): + status: PostProcessingStatus + model: PostProcessingModel + + class Memory(BaseModel): id: str created_at: datetime @@ -121,6 +137,8 @@ class Memory(BaseModel): external_data: Optional[Dict] = None + postprocessing: Optional[MemoryPostProcessing] = None + discarded: bool = False deleted: bool = False @@ 
-137,8 +155,8 @@ def memories_to_string(memories: List['Memory']) -> str: result.append(memory_str.strip()) return "\n\n".join(result) - def get_transcript(self) -> str: - return TranscriptSegment.segments_as_string(self.transcript_segments, include_timestamps=True) + def get_transcript(self, include_timestamps: bool) -> str: + return TranscriptSegment.segments_as_string(self.transcript_segments, include_timestamps=include_timestamps) class CreateMemory(BaseModel): @@ -151,10 +169,9 @@ class CreateMemory(BaseModel): source: MemorySource = MemorySource.friend language: Optional[str] = None - audio_base64_url: Optional[str] = None - def get_transcript(self) -> str: - return TranscriptSegment.segments_as_string(self.transcript_segments, include_timestamps=True) + def get_transcript(self, include_timestamps: bool) -> str: + return TranscriptSegment.segments_as_string(self.transcript_segments, include_timestamps=include_timestamps) class WorkflowMemorySource(str, Enum): diff --git a/backend/requirements.txt b/backend/requirements.txt index bc8c6ecbd..09644bc87 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -4,9 +4,13 @@ aiohappyeyeballs==2.3.4 aiohttp==3.10.1 aiosignal==1.3.1 aiostream==0.5.2 +alembic==1.13.2 annotated-types==0.7.0 +antlr4-python3-runtime==4.9.3 anyio==4.4.0 +asteroid-filterbanks==0.4.0 attrs==24.1.0 +audioread==3.0.1 CacheControl==0.14.0 cachetools==5.4.0 certifi==2024.7.4 @@ -14,16 +18,24 @@ cffi==1.16.0 charset-normalizer==3.3.2 click==8.1.7 click-spinner==0.1.10 +colorlog==6.8.2 +contourpy==1.2.1 croniter==1.4.1 cryptography==43.0.0 +cycler==0.12.1 dataclasses-json==0.6.7 +decorator==5.1.1 deepgram-sdk==3.4.0 deprecation==2.1.0 distro==1.9.0 +docopt==0.6.2 +einops==0.8.0 +fal_client==0.4.1 fastapi==0.112.0 fastapi-utilities==0.2.0 filelock==3.15.4 firebase-admin==6.5.0 +fonttools==4.53.1 frozenlist==1.4.1 fsspec==2024.6.1 google-api-core==2.19.1 @@ -46,11 +58,17 @@ hpack==4.0.0 httpcore==1.0.5 httplib2==0.22.0 httpx==0.27.0 +httpx-sse==0.4.0 +huggingface-hub==0.24.5 hyperframe==6.0.1 +HyperPyYAML==1.2.2 idna==3.7 Jinja2==3.1.4 +joblib==1.4.2 jsonpatch==1.33 jsonpointer==3.0.0 +julius==0.2.7 +kiwisolver==1.4.5 langchain==0.2.12 langchain-community==0.2.11 langchain-core==0.2.28 @@ -59,9 +77,16 @@ langchain-openai==0.1.20 langchain-pinecone==0.1.3 langchain-text-splitters==0.2.2 langsmith==0.1.96 +lazy_loader==0.4 +librosa==0.10.2.post1 +lightning==2.4.0 +lightning-utilities==0.11.6 +llvmlite==0.43.0 +Mako==1.3.5 markdown-it-py==3.0.0 MarkupSafe==2.1.5 marshmallow==3.21.3 +matplotlib==3.9.2 mdurl==0.1.2 modal==0.64.7 mpmath==1.3.0 @@ -69,15 +94,28 @@ msgpack==1.0.8 multidict==6.0.5 mypy-extensions==1.0.0 networkx==3.3 +numba==0.60.0 numpy==1.26.4 +omegaconf==2.3.0 openai==1.39.0 +optuna==3.6.1 orjson==3.10.6 packaging==24.1 +pandas==2.2.2 +pillow==10.4.0 pinecone-client==5.0.1 pinecone-plugin-inference==1.0.3 pinecone-plugin-interface==0.0.7 +platformdirs==4.2.2 +pooch==1.8.2 +primePy==1.3 proto-plus==1.24.0 protobuf==4.25.4 +pyannote.audio==3.3.1 +pyannote.core==5.0.0 +pyannote.database==5.1.0 +pyannote.metrics==3.2.1 +pyannote.pipeline==3.0.1 pyasn1==0.6.0 pyasn1_modules==0.4.0 pycparser==2.22 @@ -90,6 +128,8 @@ pyparsing==3.1.2 python-dateutil==2.9.0.post0 python-dotenv==1.0.1 python-multipart==0.0.9 +pytorch-lightning==2.4.0 +pytorch-metric-learning==2.6.0 pytz==2024.1 PyYAML==6.0.1 redis==5.0.8 @@ -97,25 +137,42 @@ regex==2024.7.24 requests==2.32.3 rich==13.7.1 rsa==4.9 +ruamel.yaml==0.18.6 +ruamel.yaml.clib==0.2.8 +scikit-learn==1.5.1 
+scipy==1.14.0 +semver==3.0.2 +sentencepiece==0.2.0 shellingham==1.5.4 sigtools==4.0.1 six==1.16.0 sniffio==1.3.1 +sortedcontainers==2.4.0 +soundfile==0.12.1 +soxr==0.4.0 +speechbrain==1.0.0 SQLAlchemy==2.0.32 starlette==0.37.2 sympy==1.13.1 synchronicity==0.6.7 +tabulate==0.9.0 tenacity==8.5.0 +tensorboardX==2.6.2.2 +threadpoolctl==3.5.0 tiktoken==0.7.0 toml==0.10.2 torch==2.4.0 +torch-audiomentations==0.11.1 +torch-pitch-shift==1.2.4 torchaudio==2.4.0 +torchmetrics==1.4.1 tqdm==4.66.5 typer==0.12.3 types-certifi==2021.10.8.3 types-toml==0.10.8.20240310 typing-inspect==0.9.0 typing_extensions==4.12.2 +tzdata==2024.1 uritemplate==4.1.1 urllib3==2.2.2 uvicorn==0.30.5 diff --git a/backend/routers/__init__.py b/backend/routers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/routers/chat.py b/backend/routers/chat.py index 4cf01147e..c42056b72 100644 --- a/backend/routers/chat.py +++ b/backend/routers/chat.py @@ -6,10 +6,10 @@ import database.chat as chat_db from models.chat import Message, SendMessageRequest, MessageSender -from utils import auth +from utils.other import endpoints as auth from utils.llm import qa_rag, initial_chat_message from utils.plugins import get_plugin_by_id -from utils.rag import retrieve_rag_context +from utils.retrieval.rag import retrieve_rag_context router = APIRouter() diff --git a/backend/routers/memories.py b/backend/routers/memories.py index 1f956ba00..ca19de72e 100644 --- a/backend/routers/memories.py +++ b/backend/routers/memories.py @@ -1,155 +1,45 @@ -import hashlib -import asyncio -import random -import threading -import uuid -from typing import Union +import os -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, UploadFile +from pydub import AudioSegment import database.memories as memories_db -from database.vector import delete_vector, upsert_vectors, upsert_vector +from database.vector_db import delete_vector from models.memory import * -from models.integrations import * -from models.plugin import Plugin -from utils.plugins import get_plugins_data from models.transcript_segment import TranscriptSegment -from utils import auth -from utils.llm import generate_embedding, get_transcript_structure, get_plugin_result, summarize_open_glass, summarize_experience_text -from utils.location import get_google_maps_location +from utils.other import endpoints as auth +from utils.llm import transcript_user_speech_fix, num_tokens_from_string +from utils.memories.location import get_google_maps_location from utils.plugins import trigger_external_integrations +from utils.memories.process_memory import process_memory +from utils.other.storage import upload_postprocessing_audio, delete_postprocessing_audio +from utils.stt.pre_recorded import fal_whisperx router = APIRouter() -def _process_get_memory_structered(memory: Union[Memory, CreateMemory, WorkflowCreateMemory], language_code: str, force_process: bool) -> (Structured, bool): - # From workflow - if memory.source == MemorySource.workflow: - if memory.text_source == WorkflowMemorySource.audio: - structured = get_transcript_structure( - memory.text, memory.started_at, language_code, True) - return (structured, False) - - if memory.text_source == WorkflowMemorySource.other: - structured = summarize_experience_text(memory.text) - return (structured, False) - - # not workflow memory source support - raise HTTPException(status_code=400, detail='Invalid workflow memory source') - - # Default source - should_clean_photos = False - 
transcript = memory.get_transcript() - if memory.photos: - structured: Structured = summarize_open_glass(memory.photos) - should_clean_photos = True # Clear photos to avoid saving them in the memory - return (structured, should_clean_photos) - - structured: Structured = get_transcript_structure( - transcript, memory.started_at, language_code, force_process - ) - - return (structured, should_clean_photos) - - -def process_memory( - uid: str, language_code: str, memory: Union[Memory, CreateMemory, WorkflowCreateMemory], force_process: bool = False, retries: int = 1 -): - # make structured - structured: Structured - try: - (structured, should_clean_photos) = _process_get_memory_structered(memory, language_code, force_process) - if should_clean_photos: - memory.photos = [] - except Exception as e: - print(e) - if retries == 2: - raise HTTPException(status_code=500, detail="Error processing memory, please try again later") - return process_memory(uid, language_code, memory, force_process, retries + 1) - - discarded = structured.title == '' - - # new if - new_photos = [] - if isinstance(memory, CreateMemory): - memory = Memory( - id=str(uuid.uuid4()), - **memory.dict(), - created_at=datetime.utcnow(), - deleted=False, - structured=structured, - discarded=discarded, - ) - new_photos = memory.photos - elif isinstance(memory, WorkflowCreateMemory): - create_memory = memory - memory = Memory( - id=str(uuid.uuid4()), - **memory.dict(), - created_at=datetime.utcnow(), - deleted=False, - structured=structured, - discarded=discarded, - ) - memory.external_data = create_memory.dict() - else: - print(f"Existing memory {memory.id}") - - # store to db - memories_db.upsert_memory(uid, memory.dict()) - # photos - if new_photos: - memories_db.store_memory_photos(uid, memory.id, new_photos) - - # afterward, should be async - asyncio.run(_process_memory_afterward(uid, memory)) - +def _get_memory_by_id(uid: str, memory_id: str): + memory = memories_db.get_memory(uid, memory_id) + if memory is None or memory.get('deleted', False): + raise HTTPException(status_code=404, detail="Memory not found") return memory -def _process_get_memory_conversation_str(memory: Memory) -> str: - # Workflow - if memory.source == MemorySource.workflow: - return memory.external_data["text"] - - # Default - return memory.get_transcript() - - -async def _process_memory_afterward(uid: str, memory: Memory): - if memory.discarded: - return - - structured = memory.structured - transcript = _process_get_memory_conversation_str(memory) - - # forward to plugin - structured_str = str(structured) - vector = generate_embedding(structured_str) - upsert_vector(uid, memory, vector) - - plugins: List[Plugin] = get_plugins_data(uid, include_reviews=False) - filtered_plugins = [plugin for plugin in plugins if plugin.works_with_memories() and plugin.enabled and plugin.trigger_workflow_memories] - threads = [] - - def execute_plugin(plugin): - if result := get_plugin_result(transcript, plugin).strip(): - memory.plugins_results.append(PluginResult(plugin_id=plugin.id, content=result)) - - for plugin in filtered_plugins: - threads.append(threading.Thread(target=execute_plugin, args=(plugin,))) - - [t.start() for t in threads] - [t.join() for t in threads] - - return - - @router.post("/v1/memories", response_model=CreateMemoryResponse, tags=['memories']) def create_memory( create_memory: CreateMemory, trigger_integrations: bool, language_code: Optional[str] = None, uid: str = Depends(auth.get_current_user_uid) ): + """ + Create Memory endpoint. 
+ :param create_memory: data to create memory + :param trigger_integrations: determine if triggering the on_memory_created plugins webhooks. + :param language_code: language. + :param uid: user id. + :return: The new memory created + any messages triggered by on_memory_created integrations. + + TODO: Should receive raw segments by deepgram, instead of the beautified ones? and get beautified on read? + """ if not create_memory.transcript_segments and not create_memory.photos: raise HTTPException(status_code=400, detail="Transcript segments or photos are required") @@ -170,10 +60,99 @@ def create_memory( return CreateMemoryResponse(memory=memory, messages=messages) +@router.post("/v1/memories/{memory_id}/post-processing", response_model=Memory, tags=['memories']) +async def postprocess_memory( + memory_id: str, file: Optional[UploadFile], uid: str = Depends(auth.get_current_user_uid) +): + """ + The objective of this endpoint, is to get the best possible transcript from the audio file. + Instead of storing the initial deepgram result, doing a full post-processing with whisper-x. + This increases the quality of transcript by at least 20%. + Which also includes a better summarization. + Which helps us create better vectors for the memory. + And improves the overall experience of the user. + + TODO: Try Nvidia Nemo ASR as suggested by @jhonnycombs + https://huggingface.co/spaces/hf-audio/open_asr_leaderboard + + TODO: USE soniox here? with speech profile and stuff? + """ + memory_data = _get_memory_by_id(uid, memory_id) + memory = Memory(**memory_data) + if memory.discarded: + raise HTTPException(status_code=400, detail="Memory is discarded") + + if memory.postprocessing is not None: + raise HTTPException(status_code=400, detail="Memory can't be post-processed again") + + # TODO: can do VAD and still keep segments? ~ should do something even with VAD start end? + file_path = f"_temp/{memory_id}_{file.filename}" + with open(file_path, 'wb') as f: + f.write(file.file.read()) + + memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.in_progress) + + # TODO: try https://dev.hume.ai/reference/expression-measurement-api/batch/start-inference-job-from-local-file + + try: + # Upload to GCP + remove file locally and cloud storage + url = upload_postprocessing_audio(file_path) + duration = AudioSegment.from_wav(file_path).duration_seconds + os.remove(file_path) + segments = fal_whisperx(url, duration) + delete_postprocessing_audio(file_path) + # TODO: should consider storing non beautified segments, and beautify on read? + + # if new transcript is 90% shorter than the original, cancel post-processing, smth wrong with audio or FAL + count = len(''.join([segment.text for segment in memory.transcript_segments])) + new_count = len(''.join([segment.text for segment in segments])) + if new_count < (count * 0.9): + memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.canceled) + raise HTTPException(status_code=500, detail="Post-processed transcript is too short") + + + # Fix user speaker_id matching + if any(segment.is_user for segment in memory.transcript_segments): + # TODO: speech profile here using better solutions, and using existing audio file. Speechbrain? + prev = TranscriptSegment.segments_as_string(memory.transcript_segments, False) + transcript_tokens = num_tokens_from_string( + TranscriptSegment.segments_as_string(memory.transcript_segments, False)) + # should limit a few segments, like first and last 100? 
+ if transcript_tokens < 40000: # 40k tokens, costs about 10 usd per request + new = TranscriptSegment.segments_as_string(segments, False) + speaker_id: int = transcript_user_speech_fix(prev, new) + else: # simple way (this in theory should work for all) ~ Not super accurate most likely + speaker_id: int = [segment.speaker_id for segment in memory.transcript_segments if segment.is_user][0] + + for segment in segments: + if segment.speaker_id == speaker_id: + segment.is_user = True + + # Store previous and new segments in DB as collection. + memories_db.store_model_segments_result(uid, memory.id, 'deepgram_streaming', memory.transcript_segments) + memories_db.store_model_segments_result(uid, memory.id, 'fal_whisperx', segments) + memory.transcript_segments = segments + memories_db.upsert_memory(uid, memory.dict()) # Store transcript segments at least if smth fails later + + # Reprocess memory with improved transcription + result = process_memory(uid, memory.language, memory, force_process=True) + except Exception as e: + print(e) + memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.failed) + raise HTTPException(status_code=500, detail=str(e)) + + memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.completed) + return result + + @router.post('/v1/memories/{memory_id}/reprocess', response_model=Memory, tags=['memories']) def reprocess_memory( memory_id: str, language_code: Optional[str] = None, uid: str = Depends(auth.get_current_user_uid) ): + """ + Whenever a user wants to reprocess a memory, or wants to force process a discarded one + :return: The updated memory after reprocessing. + """ memory = memories_db.get_memory(uid, memory_id) if memory is None: raise HTTPException(status_code=404, detail="Memory not found") @@ -190,13 +169,6 @@ def get_memories(limit: int = 100, offset: int = 0, uid: str = Depends(auth.get_ return memories_db.get_memories(uid, limit, offset, include_discarded=True) -def _get_memory_by_id(uid: str, memory_id: str): - memory = memories_db.get_memory(uid, memory_id) - if memory is None or memory.get('deleted', False): - raise HTTPException(status_code=404, detail="Memory not found") - return memory - - @router.get("/v1/memories/{memory_id}", response_model=Memory, tags=['memories']) def get_memory_by_id(memory_id: str, uid: str = Depends(auth.get_current_user_uid)): return _get_memory_by_id(uid, memory_id) @@ -213,151 +185,3 @@ def delete_memory(memory_id: str, uid: str = Depends(auth.get_current_user_uid)) memories_db.delete_memory(uid, memory_id) delete_vector(memory_id) return {"status": "Ok"} - - -# ************************************************ -# ************ Migrate Local Memories ************ -# ************************************************ - - -def _get_structured(memory: dict): - category = memory['structured']['category'] - if category not in CategoryEnum.__members__: - category = 'other' - emoji = memory['structured'].get('emoji') - try: - emoji = emoji.encode('latin1').decode('utf-8') - except: - emoji = random.choice(['🧠', '🎉']) - - return Structured( - title=memory['structured']['title'], - overview=memory['structured']['overview'], - emoji=emoji, - category=CategoryEnum[category], - action_items=[ - ActionItem(description=description, completed=False) for description in - memory['structured']['actionItems'] - ], - events=[ - Event( - title=event['title'], - description=event['description'], - start=datetime.fromisoformat(event['startsAt']), - duration=event['duration'], - created=False, - ) for event 
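For reference, the new `/v1/memories/{memory_id}/post-processing` route above expects a multipart WAV upload in a field named `file` plus the usual `Authorization: Bearer <firebase-id-token>` header, and responds with the reprocessed memory. A minimal client-side sketch using `requests`; the base URL, token, and memory id are placeholders:

import requests

BASE_URL = "https://your-backend.example.com"  # placeholder deployment URL
ID_TOKEN = "<firebase-id-token>"               # placeholder; obtained from Firebase Auth on the app side
MEMORY_ID = "<memory-id>"

with open("recording.wav", "rb") as f:
    response = requests.post(
        f"{BASE_URL}/v1/memories/{MEMORY_ID}/post-processing",
        headers={"Authorization": f"Bearer {ID_TOKEN}"},
        # must be a WAV: the endpoint reads it back with AudioSegment.from_wav
        files={"file": ("recording.wav", f, "audio/wav")},
        timeout=600,  # whisper-x on a long recording can take a while
    )
response.raise_for_status()
memory = response.json()  # the reprocessed Memory; postprocessing status is tracked in Firestore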
in memory['structured']['events'] - ], - ) - - -def _get_geolocation(memory: dict): - geolocation = memory.get('geoLocation', {}) - if geolocation and geolocation.get('googlePlaceId'): - geolocation_obj = Geolocation( - google_place_id=geolocation['googlePlaceId'], - latitude=geolocation['latitude'], - longitude=geolocation['longitude'], - address=geolocation['address'], - location_type=geolocation['locationType'], - ) - else: - geolocation_obj = None - return geolocation_obj - - -def generate_uuid4_from_seed(seed): - # Use SHA-256 to hash the seed - hash_object = hashlib.sha256(seed.encode('utf-8')) - hash_digest = hash_object.hexdigest() - return uuid.UUID(hash_digest[:32]) - - -def upload_memory_vectors(uid: str, memories: List[Memory]): - if not memories: - return - vectors = [generate_embedding(str(memory.structured)) for memory in memories] - upsert_vectors(uid, vectors, memories) - - -@router.post('/v1/migration/memories', tags=['v1']) -def migrate_local_memories(memories: List[dict], uid: str = Depends(auth.get_current_user_uid)): - if not memories: - return {'status': 'ok'} - memories_vectors = [] - db_batch = memories_db.get_memories_batch_operation() - for i, memory in enumerate(memories): - if memory.get('photos'): - continue # Ignore openGlass memories for now - - structured_obj = _get_structured(memory) - # print(structured_obj) - if not memory['transcriptSegments'] and memory['transcript']: - memory['transcriptSegments'] = [{'text': memory['transcript']}] - - memory_obj = Memory( - id=str(generate_uuid4_from_seed(f'{uid}-{memory["createdAt"]}')), - uid=uid, - structured=structured_obj, - created_at=datetime.fromisoformat(memory['createdAt']), - started_at=datetime.fromisoformat(memory['startedAt']) if memory['startedAt'] else None, - finished_at=datetime.fromisoformat(memory['finishedAt']) if memory['finishedAt'] else None, - discarded=memory['discarded'], - transcript_segments=[ - TranscriptSegment( - text=segment['text'], - start=segment.get('start', 0), - end=segment.get('end', 0), - speaker=segment.get('speaker', 'SPEAKER_00'), - is_user=segment.get('is_user', False), - ) for segment in memory['transcriptSegments'] if segment.get('text', '') - ], - plugins_results=[ - PluginResult(plugin_id=result.get('pluginId'), content=result['content']) - for result in memory['pluginsResponse'] - ], - # photos=[ - # MemoryPhoto(description=photo['description'], base64=photo['base64']) for photo in memory['photos'] - # ], - geolocation=_get_geolocation(memory), - deleted=False, - ) - memories_db.add_memory_to_batch(db_batch, uid, memory_obj.dict()) - - if not memory_obj.discarded: - memories_vectors.append(memory_obj) - - if i % 10 == 0: - threading.Thread(target=upload_memory_vectors, args=(uid, memories_vectors[:])).start() - memories_vectors = [] - - if i % 20 == 0: - db_batch.commit() - db_batch = memories_db.get_memories_batch_operation() - - db_batch.commit() - threading.Thread(target=upload_memory_vectors, args=(uid, memories_vectors[:])).start() - return {} - -# Future dailySummaryNotifications(List memories) async { -# var msg = 'There were no memories today, don\'t forget to wear your Friend tomorrow 😁'; -# if (memories.isEmpty) return msg; -# if (memories.where((m) => !m.discarded).length <= 1) return msg; -# var str = SharedPreferencesUtil().givenName.isEmpty ? 'the user' : SharedPreferencesUtil().givenName; -# var prompt = ''' -# The following are a list of $str\'s memories from today, with the transcripts with its respective structuring, that $str had during his day. 
-# $str wants to get a summary of the key action items he has to take based on his today's memories. -# -# Remember $str is busy so this has to be very efficient and concise. -# Respond in at most 50 words. -# -# Output your response in plain text, without markdown. -# ``` -# ${Memory.memoriesToString(memories, includeTranscript: true)} -# ``` -# '''; -# debugPrint(prompt); -# var result = await executeGptPrompt(prompt); -# debugPrint('dailySummaryNotifications result: $result'); -# return result.replaceAll('```', '').trim(); -# } diff --git a/backend/routers/notifications.py b/backend/routers/notifications.py index 87af27045..c75789013 100644 --- a/backend/routers/notifications.py +++ b/backend/routers/notifications.py @@ -1,11 +1,9 @@ -import logging - from fastapi import APIRouter, Depends from firebase_admin import messaging import database.notifications as notification_db from models.other import SaveFcmTokenRequest -from utils import auth +from utils.other import endpoints as auth # logger = logging.getLogger('uvicorn.error') # logger.setLevel(logging.DEBUG) diff --git a/backend/routers/plugins.py b/backend/routers/plugins.py index 2b0342d29..7bfb41fe5 100644 --- a/backend/routers/plugins.py +++ b/backend/routers/plugins.py @@ -4,9 +4,9 @@ from fastapi import APIRouter, HTTPException, Depends from models.plugin import Plugin -from utils import auth +from utils.other import endpoints as auth from utils.plugins import get_plugins_data, get_plugin_by_id -from utils.redis_utils import set_plugin_review, enable_plugin, disable_plugin +from database.redis_db import set_plugin_review, enable_plugin, disable_plugin router = APIRouter() diff --git a/backend/routers/screenpipe.py b/backend/routers/screenpipe.py index 0977e6504..aa56608b0 100644 --- a/backend/routers/screenpipe.py +++ b/backend/routers/screenpipe.py @@ -21,7 +21,7 @@ def create_memory(request: Request, uid: str, data: ScreenPipeCreateMemory): if data.source == 'screen': structured = summarize_screen_pipe(data.text) elif data.source == 'audio': - structured = get_transcript_structure(data.text, datetime.utcnow(), 'en', True) + structured = get_transcript_structure(data.text, datetime.utcnow(), 'en') else: raise HTTPException(status_code=400, detail='Invalid memory source') diff --git a/backend/routers/speech_profile.py b/backend/routers/speech_profile.py index ddd7b8db7..098f2e845 100644 --- a/backend/routers/speech_profile.py +++ b/backend/routers/speech_profile.py @@ -4,89 +4,13 @@ from pydub import AudioSegment from models.other import UploadProfile -from utils import auth -from utils.redis_utils import store_user_speech_profile, store_user_speech_profile_duration, get_user_speech_profile -from utils.storage import retrieve_all_samples, upload_sample_storage, upload_profile_audio +from utils.other import endpoints as auth +from database.redis_db import store_user_speech_profile, store_user_speech_profile_duration, get_user_speech_profile +from utils.other.storage import upload_profile_audio router = APIRouter() -def _endpoint1(file, uid): - print('upload_sample') - path = f"_temp/{uid}" - os.makedirs(path, exist_ok=True) - file_path = f"{path}/{file.filename}" - with open(file_path, 'wb') as f: - f.write(file.file.read()) - uploaded_url, count = upload_sample_storage(file_path, uid) - return {"url": uploaded_url} - - -@router.post('/samples/upload') -def upload_sample(file: UploadFile, uid: str): - return _endpoint1(file, uid) - - -def _endpoint2(uid: str): - print('my_samples') - samples_dir = retrieve_all_samples(uid) - 
samples = set(os.listdir(samples_dir)) - phrases = [ - "I scream, you scream, we all scream for ice cream.", - "Pack my box with five dozen liquor jugs.", - "The five boxing wizards jump quickly and quietly.", - "Bright blue birds fly above the green grassy hills.", - "Fred's friends fried Fritos for Friday's food festival.", - "How much wood would a woodchuck chuck if a woodchuck could chuck wood?", - ] - data = [] - for phrase in phrases: - pid = phrase.replace(' ', '-').replace(',', '').replace('.', '').replace('\'', '').lower() - data.append({'id': pid, 'phrase': phrase, 'uploaded': f"{pid}.wav" in samples}) - - # for file in os.listdir(samples_dir): - # os.remove(f"{samples_dir}/{file}") - return data - - -def _has_speech_profile(uid: str): - data = _endpoint2(uid) - return sum([1 for d in data if d['uploaded']]) >= 5 - - -@router.get('/samples') -def my_samples(uid: str): - return _endpoint2(uid) - - -@router.get('/v1/speech-profile', tags=['v1']) -def has_speech_profile(uid: str): - return {'has_profile': _has_speech_profile(uid)} - - -# ******************** -# * Latest endpoints * -# ******************** -@router.post('/v1/speech-profile/samples', tags=['v1']) -def upload_sample(file: UploadFile, uid: str = Depends(auth.get_current_user_uid)): - return _endpoint1(file, uid) - - -@router.get('/v1/speech-profile/samples', tags=['v1']) -def my_samples(uid: str = Depends(auth.get_current_user_uid)): - return _endpoint2(uid) - - -@router.get('/v2/speech-profile', tags=['v1']) -def has_speech_profile(uid: str = Depends(auth.get_current_user_uid)): - return {'has_profile': _has_speech_profile(uid)} - - -# ********************** -# * Latest endpoints 2 * -# ********************** - - @router.get('/v3/speech-profile', tags=['v3']) def has_speech_profile(uid: str = Depends(auth.get_current_user_uid)): return {'has_profile': len(get_user_speech_profile(uid)) > 0} @@ -104,8 +28,6 @@ def upload_profile(data: UploadProfile, uid: str = Depends(auth.get_current_user return {'status': 'ok'} -# TODO: app improvement, if speaker 0 starts speaking, which is not the user, after killing ws 2, it will say speaker 0 is the user - @router.post('/v3/upload-audio', tags=['v3']) def upload_profile(file: UploadFile, uid: str = Depends(auth.get_current_user_uid)): os.makedirs(f'_temp/{uid}', exist_ok=True) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 0536a398b..416e53c57 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -1,40 +1,37 @@ import asyncio -import os import time -import uuid -from fastapi import APIRouter, UploadFile +from fastapi import APIRouter from fastapi.websockets import (WebSocketDisconnect, WebSocket) -from pydub import AudioSegment from starlette.websockets import WebSocketState import torch from collections import deque import opuslib -from utils.redis_utils import get_user_speech_profile, get_user_speech_profile_duration -from utils.stt.deepgram_util import process_audio_dg, send_initial_file2, transcribe_file_deepgram -from utils.stt.vad import VADIterator, model, vad_is_empty, is_speech_present +from database.redis_db import get_user_speech_profile, get_user_speech_profile_duration +from utils.stt.streaming import process_audio_dg, send_initial_file +from utils.stt.vad import VADIterator, model, is_speech_present router = APIRouter() # @router.post("/v1/transcribe", tags=['v1']) # will be used again in Friend V2 -def transcribe_auth(file: UploadFile, uid: str, language: str = 'en'): - upload_id = str(uuid.uuid4()) - 
file_path = f"_temp/{upload_id}_{file.filename}" - with open(file_path, 'wb') as f: - f.write(file.file.read()) - - aseg = AudioSegment.from_wav(file_path) - print(f'Transcribing audio {aseg.duration_seconds} secs and {aseg.frame_rate / 1000} khz') - - if vad_is_empty(file_path): # TODO: get vad segments - os.remove(file_path) - return [] - transcript = transcribe_file_deepgram(file_path, language=language) - os.remove(file_path) - return transcript +# def transcribe_auth(file: UploadFile, uid: str, language: str = 'en'): +# upload_id = str(uuid.uuid4()) +# file_path = f"_temp/{upload_id}_{file.filename}" +# with open(file_path, 'wb') as f: +# f.write(file.file.read()) +# +# aseg = AudioSegment.from_wav(file_path) +# print(f'Transcribing audio {aseg.duration_seconds} secs and {aseg.frame_rate / 1000} khz') +# +# if vad_is_empty(file_path): # TODO: get vad segments +# os.remove(file_path) +# return [] +# transcript = transcribe_file_deepgram(file_path, language=language) +# os.remove(file_path) +# return transcript # templates = Jinja2Templates(directory="templates") @@ -70,7 +67,7 @@ async def _websocket_util( preseconds=duration) if duration: transcript_socket2 = await process_audio_dg(uid, websocket, language, sample_rate, codec, channels) - await send_initial_file2(speech_profile, transcript_socket) + await send_initial_file(speech_profile, transcript_socket) except Exception as e: print(f"Initial processing error: {e}") diff --git a/backend/routers/workflow.py b/backend/routers/workflow.py index f6a3c7333..77afb35f4 100644 --- a/backend/routers/workflow.py +++ b/backend/routers/workflow.py @@ -6,7 +6,7 @@ import models.memory as memory_models import models.integrations as integration_models -from utils.location import get_google_maps_location +from utils.memories.location import get_google_maps_location from routers.memories import process_memory, trigger_external_integrations router = APIRouter() diff --git a/backend/scripts/stt/h_brainstorming.py b/backend/scripts/stt/h_brainstorming.py index 3caa181b7..95a1be9f0 100644 --- a/backend/scripts/stt/h_brainstorming.py +++ b/backend/scripts/stt/h_brainstorming.py @@ -7,7 +7,7 @@ from groq import Groq from openai import OpenAI -from utils.endpoints import timeit +from utils.other.endpoints import timeit os.environ['GROQ_API_KEY'] = '' os.environ['FAL_KEY'] = '' diff --git a/backend/utils/preprocess.py b/backend/utils/@deprecated/preprocess.py similarity index 100% rename from backend/utils/preprocess.py rename to backend/utils/@deprecated/preprocess.py diff --git a/backend/utils/stt/soniox_util.py b/backend/utils/@deprecated/soniox_util.py similarity index 100% rename from backend/utils/stt/soniox_util.py rename to backend/utils/@deprecated/soniox_util.py diff --git a/backend/utils/speaker_profile.py b/backend/utils/@deprecated/speaker_profile.py similarity index 100% rename from backend/utils/speaker_profile.py rename to backend/utils/@deprecated/speaker_profile.py diff --git a/backend/utils/stt/whisper.py b/backend/utils/@deprecated/whisper.py similarity index 100% rename from backend/utils/stt/whisper.py rename to backend/utils/@deprecated/whisper.py diff --git a/backend/utils/stt/whisper_x.py b/backend/utils/@deprecated/whisper_x.py similarity index 100% rename from backend/utils/stt/whisper_x.py rename to backend/utils/@deprecated/whisper_x.py diff --git a/backend/utils/__init__.py b/backend/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/utils/auth.py b/backend/utils/auth.py deleted file mode 
100644 index be7bf931a..000000000 --- a/backend/utils/auth.py +++ /dev/null @@ -1,26 +0,0 @@ -import os - -from fastapi import Header, HTTPException -from firebase_admin import auth -from firebase_admin.auth import InvalidIdTokenError - - -def get_current_user_uid(authorization: str = Header(None)): - if os.getenv('ADMIN_KEY') in authorization: - return authorization.split(os.getenv('ADMIN_KEY'))[1] - - if not authorization: - raise HTTPException(status_code=401, detail="Authorization header not found") - elif len(str(authorization).split(' ')) != 2: - raise HTTPException(status_code=401, detail="Invalid authorization token") - - try: - token = authorization.split(' ')[1] - decoded_token = auth.verify_id_token(token) - print('get_current_user_uid', decoded_token['uid']) - return decoded_token['uid'] - except InvalidIdTokenError as e: - if os.getenv('LOCAL_DEVELOPMENT') == 'true': - return '123' - print(e) - raise HTTPException(status_code=401, detail="Invalid authorization token") diff --git a/backend/utils/llm.py b/backend/utils/llm.py index 84d72fa7a..29eac3a5a 100644 --- a/backend/utils/llm.py +++ b/backend/utils/llm.py @@ -1,25 +1,13 @@ import json -import os from datetime import datetime from typing import List, Tuple, Optional -from langchain.agents import create_tool_calling_agent, AgentExecutor -from langchain.chains.combine_documents import create_stuff_documents_chain -from langchain.chains.history_aware_retriever import create_history_aware_retriever -from langchain.chains.retrieval import create_retrieval_chain -from langchain.output_parsers import BooleanOutputParser -from langchain_community.chat_message_histories import ChatMessageHistory -from langchain_core.chat_history import BaseChatMessageHistory -from langchain_core.messages import SystemMessage, HumanMessage, AIMessage from langchain_core.output_parsers import PydanticOutputParser -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate, PromptTemplate -from langchain_core.runnables.history import RunnableWithMessageHistory -from langchain_core.tools import create_retriever_tool +from langchain_core.prompts import ChatPromptTemplate, PromptTemplate from langchain_openai import ChatOpenAI, OpenAIEmbeddings -from langchain_pinecone import PineconeVectorStore from pydantic import BaseModel, Field -from models.chat import Message, MessageSender +from models.chat import Message from models.memory import Structured, MemoryPhoto from models.plugin import Plugin from models.transcript_segment import TranscriptSegment, ImprovedTranscript @@ -30,10 +18,6 @@ llm_with_parser = llm.with_structured_output(Structured) -# groq_llm = llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, max_retries=2) -# groq_llm_with_parser = groq_llm.with_structured_output(Structured) - - # TODO: include caching layer, redis @@ -72,43 +56,46 @@ def improve_transcript_prompt(segments: List[TranscriptSegment]) -> List[Transcr return response.result -def discard_memory(transcript: str) -> bool: +class DiscardMemory(BaseModel): + discard: bool = Field(description="If the memory should be discarded or not") + + +class SpeakerIdMatch(BaseModel): + speaker_id: int = Field(description="The speaker id assigned to the segment") + + +# ********************************************** +# ************* MEMORY PROCESSING ************** +# ********************************************** + +def should_discard_memory(transcript: str) -> bool: if len(transcript.split(' ')) > 100: return False - parser = 
BooleanOutputParser() + parser = PydanticOutputParser(pydantic_object=DiscardMemory) prompt = ChatPromptTemplate.from_messages([ ''' - You will be given a conversation transcript, and your task is to determine if the conversation is not worth storing as a memory or not. - It is not worth storing if there are no interesting topics, facts, or information, in that case, output True. + You will be given a conversation transcript, and your task is to determine if the conversation is worth storing as a memory or not. + It is not worth storing if there are no interesting topics, facts, or information, in that case, output discard = True. Transcript: ```{transcript}``` {format_instructions}'''.replace(' ', '').strip() ]) - print(prompt) chain = prompt | llm | parser try: - response = chain.invoke({ + response: DiscardMemory = chain.invoke({ 'transcript': transcript.strip(), 'format_instructions': parser.get_format_instructions(), }) - return response + return response.discard + except Exception as e: print(f'Error determining memory discard: {e}') return False -def get_transcript_structure( - transcript: str, started_at: datetime, language_code: str, force_process: bool -) -> Structured: - if len(transcript.split(' ')) > 100: - force_process = True - - force_process_str = '' - if not force_process: - force_process_str = 'It is possible that the conversation is not worth storing, there are no interesting topics, facts, or information, in that case, output an empty title, overview, and action items.' - +def get_transcript_structure(transcript: str, started_at: datetime, language_code: str) -> Structured: prompt = ChatPromptTemplate.from_messages([( 'system', '''Your task is to provide structure and clarity to the recording transcription of a conversation. @@ -126,21 +113,62 @@ def get_transcript_structure( {format_instructions}'''.replace(' ', '').strip() )]) - # if use_cheaper_model: - # chain = prompt | groq_llm | parser - # else: chain = prompt | llm | parser response = chain.invoke({ 'transcript': transcript.strip(), 'format_instructions': parser.get_format_instructions(), 'language_code': language_code, - 'force_process_str': force_process_str, + 'force_process_str': '', 'started_at': started_at.isoformat(), }) return response +def transcript_user_speech_fix(prev_transcript: str, new_transcript: str) -> int: + print(f'transcript_user_speech_fix prev_transcript: {len(prev_transcript)} new_transcript: {len(new_transcript)}') + prompt = f''' + You will be given a previous transcript and a improved transcript, previous transcript has the user voice identified, but the improved transcript does not have it. + Your task is to determine on the improved transcript, which speaker id corresponds to the user voice, based on the previous transcript. + + Previous Transcript: + {prev_transcript} + + Improved Transcript: + {new_transcript} + ''' + with_parser = llm.with_structured_output(SpeakerIdMatch) + response: SpeakerIdMatch = with_parser.invoke(prompt) + return response.speaker_id + + +def get_plugin_result(transcript: str, plugin: Plugin) -> str: + prompt = f''' + Your are an AI with the following characteristics: + Name: ${plugin.name}, + Description: ${plugin.description}, + Task: ${plugin.memory_prompt} + + Note: It is possible that the conversation you are given, has nothing to do with your task, \ + in that case, output an empty string. 
(For example, you are given a business conversation, but your task is medical analysis) + + Conversation: ```{transcript.strip()}```, + + Output your response in plain text, without markdown. + Make sure to be concise and clear. + ''' + + response = llm.invoke(prompt) + content = response.content.replace('```json', '').replace('```', '') + if len(content) < 5: + return '' + return content + + +# ************************************** +# ************* OPENGLASS ************** +# ************************************** + def summarize_open_glass(photos: List[MemoryPhoto]) -> Structured: photos_str = '' for i, photo in enumerate(photos): @@ -156,6 +184,10 @@ def summarize_open_glass(photos: List[MemoryPhoto]) -> Structured: return llm_with_parser.invoke(prompt) +# ************************************************** +# ************* EXTERNAL INTEGRATIONS ************** +# ************************************************** + def summarize_screen_pipe(description: str) -> Structured: prompt = f'''The user took a series of screenshots from his laptop, and used OCR to obtain the text from the screen. @@ -182,150 +214,10 @@ def summarize_experience_text(text: str) -> Structured: return llm_with_parser.invoke(prompt) -def get_plugin_result(transcript: str, plugin: Plugin) -> str: - prompt = f''' - Your are an AI with the following characteristics: - Name: ${plugin.name}, - Description: ${plugin.description}, - Task: ${plugin.memory_prompt} - - Note: It is possible that the conversation you are given, has nothing to do with your task, \ - in that case, output an empty string. (For example, you are given a business conversation, but your task is medical analysis) - - Conversation: ```{transcript.strip()}```, - - Output your response in plain text, without markdown. - Make sure to be concise and clear. - ''' - - response = llm.invoke(prompt) - content = response.content.replace('```json', '').replace('```', '') - if len(content) < 5: - return '' - return content - - def generate_embedding(content: str) -> List[float]: return embeddings.embed_documents([content])[0] -# ****************************************** -# ************** CHAT AGENT **************** -# ****************************************** - - -def _get_retriever(): - vectordb = PineconeVectorStore( - index_name=os.getenv('PINECONE_INDEX_NAME'), - pinecone_api_key=os.getenv('PINECONE_API_KEY'), - embedding=OpenAIEmbeddings(), - ) - # TODO: maybe try mmr later, but similarity works great, llm aided is not possible here, no metadata. - # can tweak the number of docs to retrieve - return vectordb.as_retriever(search_type="similarity", search_kwargs={"k": 10}) - - -def get_chat_history(messages: List[Message]) -> BaseChatMessageHistory: - history = ChatMessageHistory() - for message in messages: - if message.sender == MessageSender.human: - history.add_message(HumanMessage(content=message.text)) - else: - history.add_message(AIMessage(content=message.text)) - return history - - -# CHAIN -def _get_context_question(): - contextualize_q_system_prompt = """Given a chat history and the latest user question \ - which might reference context in the chat history, formulate a standalone question \ - which can be understood without the chat history. 
Do NOT answer the question, \ - just reformulate it if needed and otherwise return it as is.""" - contextualize_q_prompt = ChatPromptTemplate.from_messages( - [ - ("system", contextualize_q_system_prompt), - MessagesPlaceholder("chat_history"), - ("human", "{input}"), - ] - ) - return create_history_aware_retriever(llm, _get_retriever(), contextualize_q_prompt) - - -def chat_qa_chain(uid: str, messages: List[Message]): - qa_system_prompt = """You are an assistant for question-answering tasks. \ - Use the following pieces of retrieved context to answer the question. \ - If you don't know the answer, just say that you don't have access to that information. \ - Use three sentences maximum and keep the answer concise.\ - - {context}""" - qa_prompt = ChatPromptTemplate.from_messages( - [ - ("system", qa_system_prompt), - MessagesPlaceholder("chat_history"), - ("human", "{input}"), - ] - ) - question_answer_chain = create_stuff_documents_chain(llm, qa_prompt) - history_aware_retriever = _get_context_question() - rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain) - - def get_session(): - return get_chat_history(messages) - - conversational_rag_chain = RunnableWithMessageHistory( - rag_chain, - get_session, - input_messages_key="input", - history_messages_key="chat_history", - output_messages_key="answer", - ) - return conversational_rag_chain.stream( - {"input": "What are common ways of doing it?"}, - config={"configurable": {"session_id": uid}}, - ) - - -# ************************************************* -# ************* AGENT RETRIEVER TOOL ************** -# ************************************************* - -def _get_init_prompt(): - return ChatPromptTemplate.from_messages([ - SystemMessage(content=f''' - You are an assistant for question-answering tasks. Use the following pieces of retrieved context and the conversation history to continue the conversation. - If you don't know the answer, just say that you didn't find any related information or you that don't know. Use three sentences maximum and keep the answer concise. - If the message doesn't require context, it will be empty, so answer the question casually. 
- '''), - MessagesPlaceholder(variable_name="chat_history"), - HumanMessagePromptTemplate.from_template("{input}"), - MessagesPlaceholder(variable_name="agent_scratchpad"), - ]) - - -def _agent_with_retriever_tool(messages: List[Message]): - tool = create_retriever_tool( - _get_retriever(), - "conversations_retriever", - "Searches for relevant conversations the user has had in the past.", - ) - agent = create_tool_calling_agent(llm, [tool], _get_init_prompt()) - agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=[tool], verbose=True) - return RunnableWithMessageHistory( - agent_executor, - lambda session_id: get_chat_history(messages), - input_messages_key="input", - output_messages_key="output", - history_messages_key="chat_history", - ) - - -def ask_agent(message: str, messages: List[Message]): - agent = _agent_with_retriever_tool(messages) - output = agent.invoke({'input': HumanMessage(content=message)}, - {"configurable": {"session_id": "unused"}}) - return output['output'] - - # *************************************************** # ************* CHAT CURRENT APP LOGIC ************** # *************************************************** @@ -430,3 +322,14 @@ def initial_chat_message(plugin: Optional[Plugin] = None) -> str: ''' prompt = prompt.replace(' ', '').strip() return llm.invoke(prompt).content + + +import tiktoken + +encoding = tiktoken.encoding_for_model('gpt-4') + + +def num_tokens_from_string(string: str) -> int: + """Returns the number of tokens in a text string.""" + num_tokens = len(encoding.encode(string)) + return num_tokens diff --git a/backend/utils/location.py b/backend/utils/memories/location.py similarity index 100% rename from backend/utils/location.py rename to backend/utils/memories/location.py diff --git a/backend/utils/process_memory.py b/backend/utils/memories/process_memory.py similarity index 51% rename from backend/utils/process_memory.py rename to backend/utils/memories/process_memory.py index 6d2bdb9fe..99c23d60f 100644 --- a/backend/utils/process_memory.py +++ b/backend/utils/memories/process_memory.py @@ -1,3 +1,4 @@ +import random import threading import uuid from typing import Union @@ -5,41 +6,53 @@ from fastapi import HTTPException import database.memories as memories_db -from database.vector import upsert_vector +from database.vector_db import upsert_vector from models.memory import * from models.plugin import Plugin from utils.llm import summarize_open_glass, get_transcript_structure, generate_embedding, \ - get_plugin_result + get_plugin_result, should_discard_memory, summarize_experience_text from utils.plugins import get_plugins_data def _get_structured( - uid: str, language_code: str, memory: Union[Memory, CreateMemory], force_process: bool = False, retries: int = 1 + uid: str, language_code: str, memory: Union[Memory, CreateMemory, WorkflowCreateMemory], + force_process: bool = False, retries: int = 1 ) -> Structured: - transcript = memory.get_transcript() - # has_audio = isinstance(memory, CreateMemory) and memory.audio_base64_url - try: + if memory.source == MemorySource.workflow: + if memory.text_source == WorkflowMemorySource.audio: + structured = get_transcript_structure(memory.text, memory.started_at, language_code) + return structured, False + + if memory.text_source == WorkflowMemorySource.other: + structured = summarize_experience_text(memory.text) + return structured, False + + # not workflow memory source support + raise HTTPException(status_code=400, detail='Invalid workflow memory source') + + # from OpenGlass 
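The tiktoken-backed `num_tokens_from_string` added at the end of `utils/llm.py` is what the post-processing route uses for its 40k-token ceiling before attempting LLM speaker matching. A small sketch of that decision in isolation; the transcript strings are placeholders:

from utils.llm import num_tokens_from_string, transcript_user_speech_fix

prev = "0: hey, how are you doing?\n1: pretty good, you?"  # deepgram transcript, user already tagged (placeholder)
new = "0: hey how are you doing\n1: pretty good you"       # whisper-x transcript, no user tag yet (placeholder)

if num_tokens_from_string(prev) < 40000:
    user_speaker_id = transcript_user_speech_fix(prev, new)  # LLM picks the speaker id that matches the user
else:
    # fallback used in the endpoint: reuse the speaker_id already marked as the user's in the old segments
    user_speaker_id = 0  # placeholder value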
if memory.photos: - structured: Structured = summarize_open_glass(memory.photos) - else: - # if has_audio and not discard_memory(transcript): - # # TODO: test a 1h or 2h recording ~ should this be async ~~ also, how long does it take on frontend to upload that size? - # segments = fal_whisperx(memory.audio_base64_url) - # memory.transcript_segments = segments - - structured: Structured = get_transcript_structure( - transcript, memory.started_at, language_code, force_process - ) + return summarize_open_glass(memory.photos), False + + # from Friend + if force_process: + # reprocess endpoint + return get_transcript_structure(memory.get_transcript(False), memory.started_at, language_code), False + + discarded = should_discard_memory(memory.get_transcript(False)) + if discarded: + return Structured(emoji=random.choice(['🧠', '🎉'])), True + + return get_transcript_structure(memory.get_transcript(False), memory.started_at, language_code), False except Exception as e: print(e) if retries == 2: raise HTTPException(status_code=500, detail="Error processing memory, please try again later") return _get_structured(uid, language_code, memory, force_process, retries + 1) - return structured -def _get_memory_obj(uid: str, structured: Structured, memory: Union[Memory, CreateMemory], transcript: str): +def _get_memory_obj(uid: str, structured: Structured, memory: Union[Memory, CreateMemory, WorkflowCreateMemory]): discarded = structured.title == '' if isinstance(memory, CreateMemory): memory = Memory( @@ -48,12 +61,22 @@ def _get_memory_obj(uid: str, structured: Structured, memory: Union[Memory, Crea structured=structured, **memory.dict(), created_at=datetime.utcnow(), - transcript=transcript, discarded=discarded, deleted=False, ) if memory.photos: memories_db.store_memory_photos(uid, memory.id, memory.photos) + elif isinstance(memory, WorkflowCreateMemory): + create_memory = memory + memory = Memory( + id=str(uuid.uuid4()), + **memory.dict(), + created_at=datetime.utcnow(), + deleted=False, + structured=structured, + discarded=discarded, + ) + memory.external_data = create_memory.dict() else: memory.structured = structured memory.discarded = discarded @@ -77,17 +100,16 @@ def execute_plugin(plugin): [t.join() for t in threads] -def process_memory(uid: str, language_code: str, memory: Union[Memory, CreateMemory], force_process: bool = False): - structured: Structured = _get_structured(uid, language_code, memory, force_process) - transcript = memory.get_transcript() - memory = _get_memory_obj(uid, structured, memory, transcript) +def process_memory(uid: str, language_code: str, memory: Union[Memory, CreateMemory, WorkflowCreateMemory], + force_process: bool = False): + structured, discarded = _get_structured(uid, language_code, memory, force_process) + memory = _get_memory_obj(uid, structured, memory) - discarded = structured.title == '' if not discarded: vector = generate_embedding(str(structured)) upsert_vector(uid, memory, vector) - _trigger_plugins(uid, transcript, memory) + _trigger_plugins(uid, memory.get_transcript(False), memory) # async memories_db.upsert_memory(uid, memory.dict()) - print('Memory processed', memory.id) + print('process_memory memory.id=', memory.id) return memory diff --git a/backend/utils/endpoints.py b/backend/utils/other/endpoints.py similarity index 64% rename from backend/utils/endpoints.py rename to backend/utils/other/endpoints.py index 930988454..bfd7ab50b 100644 --- a/backend/utils/endpoints.py +++ b/backend/utils/other/endpoints.py @@ -1,7 +1,33 @@ import json +import os 
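`_trigger_plugins` above fans out one `threading.Thread` per enabled plugin and joins them all before returning. An equivalent pattern with a bounded pool, using `concurrent.futures.ThreadPoolExecutor`; this sketch assumes the same `execute_plugin` body and plugin filtering as the patch, and the worker count is arbitrary:

from concurrent.futures import ThreadPoolExecutor

from models.memory import Memory, PluginResult
from models.plugin import Plugin
from utils.llm import get_plugin_result
from utils.plugins import get_plugins_data


def _trigger_plugins_pooled(uid: str, transcript: str, memory: Memory):
    plugins = [p for p in get_plugins_data(uid, include_reviews=False)
               if p.enabled and p.works_with_memories()]

    def execute_plugin(plugin: Plugin):
        # list.append is atomic under the GIL, so concurrent appends are safe here
        if result := get_plugin_result(transcript, plugin).strip():
            memory.plugins_results.append(PluginResult(plugin_id=plugin.id, content=result))

    with ThreadPoolExecutor(max_workers=8) as pool:  # 8 workers is an arbitrary bound
        list(pool.map(execute_plugin, plugins))  # blocks until every plugin call has finished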
import time -from fastapi import Request, HTTPException +from fastapi import Header, HTTPException +from fastapi import Request +from firebase_admin import auth +from firebase_admin.auth import InvalidIdTokenError + + +def get_current_user_uid(authorization: str = Header(None)): + if os.getenv('ADMIN_KEY') in authorization: + return authorization.split(os.getenv('ADMIN_KEY'))[1] + + if not authorization: + raise HTTPException(status_code=401, detail="Authorization header not found") + elif len(str(authorization).split(' ')) != 2: + raise HTTPException(status_code=401, detail="Invalid authorization token") + + try: + token = authorization.split(' ')[1] + decoded_token = auth.verify_id_token(token) + print('get_current_user_uid', decoded_token['uid']) + return decoded_token['uid'] + except InvalidIdTokenError as e: + if os.getenv('LOCAL_DEVELOPMENT') == 'true': + return '123' + print(e) + raise HTTPException(status_code=401, detail="Invalid authorization token") + cached = {} diff --git a/backend/utils/crons/notifications.py b/backend/utils/other/notifications.py similarity index 100% rename from backend/utils/crons/notifications.py rename to backend/utils/other/notifications.py diff --git a/backend/utils/other/storage.py b/backend/utils/other/storage.py new file mode 100644 index 000000000..8a2d13faf --- /dev/null +++ b/backend/utils/other/storage.py @@ -0,0 +1,37 @@ +import json +import os + +from google.cloud import storage +from google.oauth2 import service_account + +if os.environ.get('SERVICE_ACCOUNT_JSON'): + service_account_info = json.loads(os.environ["SERVICE_ACCOUNT_JSON"]) + credentials = service_account.Credentials.from_service_account_info(service_account_info) + storage_client = storage.Client(credentials=credentials) +else: + storage_client = storage.Client() + +speech_profiles_bucket = os.getenv('BUCKET_SPEECH_PROFILES') +postprocessing_audio_bucket = os.getenv('BUCKET_POSTPROCESSING') +backups_bucket = os.getenv('BUCKET_BACKUPS') + + +def upload_profile_audio(file_path: str, uid: str): + bucket = storage_client.bucket(speech_profiles_bucket) + path = f'{uid}/speech_profile.wav' + blob = bucket.blob(path) + blob.upload_from_filename(file_path) + return f'https://storage.googleapis.com/{speech_profiles_bucket}/{path}' + + +def upload_postprocessing_audio(file_path: str): + bucket = storage_client.bucket(postprocessing_audio_bucket) + blob = bucket.blob(file_path) + blob.upload_from_filename(file_path) + return f'https://storage.googleapis.com/{postprocessing_audio_bucket}/{file_path}' + + +def delete_postprocessing_audio(file_path: str): + bucket = storage_client.bucket(postprocessing_audio_bucket) + blob = bucket.blob(file_path) + blob.delete() diff --git a/backend/utils/plugins.py b/backend/utils/plugins.py index c9491d577..1458bb0f1 100644 --- a/backend/utils/plugins.py +++ b/backend/utils/plugins.py @@ -9,7 +9,7 @@ from models.memory import Memory, MemorySource from models.plugin import Plugin from routers.notifications import send_notification -from utils.redis_utils import get_enabled_plugins, get_plugin_reviews +from database.redis_db import get_enabled_plugins, get_plugin_reviews def get_plugin_by_id(plugin_id: str) -> Optional[Plugin]: diff --git a/backend/utils/prompt.py b/backend/utils/prompt.py deleted file mode 100644 index 62bbe1742..000000000 --- a/backend/utils/prompt.py +++ /dev/null @@ -1,60 +0,0 @@ -import json - -data = [ - { - "text": "Okay. So I start speaking now. 
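One thing worth noting about the `get_current_user_uid` dependency that moved into `utils/other/endpoints.py` above: it evaluates `os.getenv('ADMIN_KEY') in authorization` before checking that the header exists, so a missing header or an unset `ADMIN_KEY` raises a `TypeError` (a 500) instead of a clean 401. A defensive variant that keeps the same behaviour otherwise (sketch only):

import os

from fastapi import Header, HTTPException
from firebase_admin import auth
from firebase_admin.auth import InvalidIdTokenError


def get_current_user_uid(authorization: str = Header(None)):
    if not authorization:
        raise HTTPException(status_code=401, detail="Authorization header not found")

    # Only take the admin shortcut when ADMIN_KEY is actually configured.
    admin_key = os.getenv('ADMIN_KEY')
    if admin_key and admin_key in authorization:
        return authorization.split(admin_key)[1]

    parts = authorization.split(' ')
    if len(parts) != 2:
        raise HTTPException(status_code=401, detail="Invalid authorization token")

    try:
        decoded_token = auth.verify_id_token(parts[1])
        print('get_current_user_uid', decoded_token['uid'])
        return decoded_token['uid']
    except InvalidIdTokenError as e:
        if os.getenv('LOCAL_DEVELOPMENT') == 'true':
            return '123'
        print(e)
        raise HTTPException(status_code=401, detail="Invalid authorization token")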
So Karen.", - "speaker": "SPEAKER_0", - "speaker_id": 0, - "is_user": True, - "start": 0.0, - "end": 7.880000000000109 - }, - { - "text": "Okay. So I guess if I continue speaking, 1 will represent strippers percent. It means that it's, like, 35 words and all that. As that decent. I don't still like different person. Think that's better than nothing. At least. Isn't it?", - "speaker": "SPEAKER_0", - "speaker_id": 0, - "is_user": True, - "start": 97.1699000000001, - "end": 167.87000000000012 - }, - { - "text": "Okay. So now when I start speaking, it should be considerably better. Alright. So, yeah, Alister speaking. Alright. So messages will change a little bit. Then it says that You are almost there or you are doing great. This is a cute NPP item. I think this is better all. I kinda like logic. It works. Interesting.", - "speaker": "SPEAKER_0", - "speaker_id": 0, - "is_user": True, - "start": 198.56999999999994, - "end": 279.21000000000004 - }, - { - "text": "Great. Quiet.", - "speaker": "SPEAKER_0", - "speaker_id": 0, - "is_user": True, - "start": 356.24, - "end": 379.9000000000001 - }, - { - "text": "It starts running. Hi.", - "speaker": "SPEAKER_0", - "speaker_id": 0, - "is_user": True, - "start": 419.6298999999999, - "end": 424.23 - } -] - - -def execute(): - cleaned = [] - for item in data: - cleaned.append({ - 'speaker_id': item['speaker_id'] + 1 if not item['is_user'] else 0, - 'text': item['text'], - # 'seconds': [round(item['start'], 2), round(item['end'], 2)] - }) - - print(json.dumps(cleaned, indent=2)) - - -if __name__ == '__main__': - execute() diff --git a/backend/utils/rag.py b/backend/utils/retrieval/rag.py similarity index 96% rename from backend/utils/rag.py rename to backend/utils/retrieval/rag.py index 8fea6ce14..39b29fc0b 100644 --- a/backend/utils/rag.py +++ b/backend/utils/retrieval/rag.py @@ -1,7 +1,7 @@ from typing import List, Tuple from database.memories import filter_memories_by_date, get_memories_by_id -from database.vector import query_vectors +from database.vector_db import query_vectors from models.chat import Message from models.memory import Memory from utils.llm import determine_requires_context diff --git a/backend/utils/storage.py b/backend/utils/storage.py deleted file mode 100644 index e2b39c232..000000000 --- a/backend/utils/storage.py +++ /dev/null @@ -1,81 +0,0 @@ -import json -import os - -from google.cloud import storage -from google.oauth2 import service_account - -if os.environ.get('SERVICE_ACCOUNT_JSON'): - service_account_info = json.loads(os.environ["SERVICE_ACCOUNT_JSON"]) - credentials = service_account.Credentials.from_service_account_info(service_account_info) - storage_client = storage.Client(credentials=credentials) -else: - storage_client = storage.Client() - -speech_profiles_bucket = os.getenv('BUCKET_SPEECH_PROFILES') -backups_bucket = os.getenv('BUCKET_BACKUPS') - - -def upload_profile_audio(file_path: str, uid: str): - bucket = storage_client.bucket(speech_profiles_bucket) - path = f'{uid}/speech_profile.wav' - blob = bucket.blob(path) - blob.upload_from_filename(file_path) - return f'https://storage.googleapis.com/{speech_profiles_bucket}/{path}' - - -# def get_speech_profile(uid: str): -# bucket = storage_client.bucket(speech_profiles_bucket) -# path = f'{uid}/speech_profile.wav' -# blob = bucket.blob(path) -# if not blob.exists(): -# return None -# -# os.makedirs('_speech_profiles/', exist_ok=True) -# profile_path = f'_speech_profiles/{uid}.wav' -# blob.download_to_filename(profile_path) -# return profile_path - - -# *********** 
-# *** OLD *** -# *********** -# soon to be deprecated - -def upload_sample_storage(file_path: str, uid: str): - print('upload_sample_storage', file_path) - bucket = storage_client.bucket(speech_profiles_bucket) - blobs = bucket.list_blobs(prefix=f'{uid}/samples/') - sample_i = len(list(blobs)) - path = f'{uid}/samples/{file_path.split("/")[-1]}' - blob = bucket.blob(path) - blob.upload_from_filename(file_path) - return f'https://storage.googleapis.com/{speech_profiles_bucket}/{path}', sample_i + 1 - - -def upload_speaker_profile(profile_path: str, uid: str): - print('upload_speaker_profile', profile_path) - bucket = storage_client.bucket(speech_profiles_bucket) - path = f'{uid}/profile.pt' - blob = bucket.blob(path) - blob.upload_from_filename(profile_path) - print('upload_speaker_profile ~ uploaded') - return f'https://storage.googleapis.com/{speech_profiles_bucket}/{path}' - - -def retrieve_all_samples(uid: str): - print('retrieve_all_samples') - # retrieve each of the _samples in the user folder, and store them in _samples/{uid} - bucket = storage_client.bucket(speech_profiles_bucket) - blobs = bucket.list_blobs(prefix=f'{uid}/samples/') - base_path = f'_samples/{uid}/' - os.makedirs(base_path, exist_ok=True) - - for i, blob in enumerate(blobs): - path = f'{base_path}{blob.name.split("/")[-1]}' - if os.path.exists(path): # when opus uploaded? should refresh the download - continue - try: - blob.download_to_filename(path) - except Exception as e: - print(f'Error downloading {blob.name}', e) - return base_path diff --git a/backend/utils/stt/fal.py b/backend/utils/stt/pre_recorded.py similarity index 64% rename from backend/utils/stt/fal.py rename to backend/utils/stt/pre_recorded.py index f891b96bf..474bd7178 100644 --- a/backend/utils/stt/fal.py +++ b/backend/utils/stt/pre_recorded.py @@ -5,6 +5,7 @@ import fal_client from models.transcript_segment import TranscriptSegment +from utils.other.endpoints import timeit def file_to_base64_url(file_path): @@ -31,13 +32,29 @@ def base64_to_file(base64_url, file_path): # Write the content to the file with open(file_path, 'wb') as file: file.write(file_content) + return file.read() -def fal_whisperx(audio_base64_url: str) -> List[TranscriptSegment]: +def upload_fal_file(mid: str, audio_base64_url: str): + print(audio_base64_url) + file_bytes = base64_to_file(audio_base64_url, f"_temp/{mid}.wav") + url = fal_client.upload(file_bytes, "audio/wav") + print('url', url) + return url + + +def delete_fal_file(url: str): + # url = fal_client.de(file_bytes, "audio/wav") + # return url + return False + + +@timeit +def fal_whisperx(audio_url: str, duration: int = None) -> List[TranscriptSegment]: handler = fal_client.submit( "fal-ai/whisper", arguments={ - "audio_url": audio_base64_url, + "audio_url": audio_url, 'task': 'transcribe', 'diarize': True, 'language': 'en', @@ -56,6 +73,7 @@ def fal_whisperx(audio_base64_url: str) -> List[TranscriptSegment]: chunk['end'] = chunk['timestamp'][1] chunk['text'] = chunk['text'].strip() chunk['is_user'] = False + chunk['speaker'] = chunk.get('speaker') or 'SPEAKER_00' # TODO: why is needed? 
del chunk['timestamp'] cleaned = [] @@ -67,9 +85,15 @@ def fal_whisperx(audio_base64_url: str) -> List[TranscriptSegment]: else: cleaned.append(chunk) - # TODO: Include pipeline post processing, so that is_user get's matched with the correct speaker - # TODO: Do punctuation correction with LLM + segments = [] + for segment in cleaned: + # print(segment['start'], segment['end'], segment['speaker'], segment['text']) + segments.append(TranscriptSegment( + text=segment['text'], + speaker=segment['speaker'], + is_user=segment['is_user'], + start=segment['start'] or 0, + end=segment['end'] or duration or segment['start'] + 1, + )) - # TODO: test other languages - # TODO: eventually do speaker embedding matching - return cleaned + return segments diff --git a/backend/utils/stt/deepgram_util.py b/backend/utils/stt/streaming.py similarity index 68% rename from backend/utils/stt/deepgram_util.py rename to backend/utils/stt/streaming.py index 2bc94510d..ad76f0b87 100644 --- a/backend/utils/stt/deepgram_util.py +++ b/backend/utils/stt/streaming.py @@ -4,7 +4,6 @@ import time from typing import List -import requests from deepgram import DeepgramClient, DeepgramClientOptions, LiveTranscriptionEvents from deepgram.clients.live.v1 import LiveOptions from starlette.websockets import WebSocket @@ -19,63 +18,51 @@ } -def transcribe_file_deepgram(file_path: str, language: str = 'en'): - print('transcribe_file_deepgram', file_path, language) - url = ('https://api.deepgram.com/v1/listen?' - 'model=nova-2-general&' - 'detect_language=false&' - f'language={language}&' - 'filler_words=false&' - 'multichannel=false&' - 'diarize=true&' - 'punctuate=true&' - 'smart_format=true') - - with open(file_path, "rb") as file: - response = requests.post(url, headers=headers, data=file) - - data = response.json() - result = data['results']['channels'][0]['alternatives'][0] - segments = [] - for word in result['words']: - if not segments: - segments.append({ - 'speaker': f"SPEAKER_{word['speaker']}", - 'start': word['start'], - 'end': word['end'], - 'text': word['word'], - 'isUser': False - }) - else: - last_segment = segments[-1] - if last_segment['speaker'] == f"SPEAKER_{word['speaker']}": - last_segment['text'] += f" {word['word']}" - last_segment['end'] = word['end'] - else: - segments.append({ - 'speaker': f"SPEAKER_{word['speaker']}", - 'start': word['start'], - 'end': word['end'], - 'text': word['word'], - 'isUser': False - }) - - return segments - - -# async def send_initial_file(file_path, transcript_socket): +# def transcribe_file_deepgram(file_path: str, language: str = 'en'): +# print('transcribe_file_deepgram', file_path, language) +# url = ('https://api.deepgram.com/v1/listen?' 
+# 'model=nova-2-general&' +# 'detect_language=false&' +# f'language={language}&' +# 'filler_words=false&' +# 'multichannel=false&' +# 'diarize=true&' +# 'punctuate=true&' +# 'smart_format=true') +# # with open(file_path, "rb") as file: -# data = file.read() -# start = time.time() -# chunk_size = 4096 # Adjust as needed -# for i in range(0, len(data), chunk_size): -# chunk = data[i:i + chunk_size] -# transcript_socket.send(chunk) -# await asyncio.sleep(0.01) # Small delay to prevent overwhelming the socket -# print('send_initial_file', time.time() - start) - - -async def send_initial_file2(data: List[List[int]], transcript_socket): +# response = requests.post(url, headers=headers, data=file) +# +# data = response.json() +# result = data['results']['channels'][0]['alternatives'][0] +# segments = [] +# for word in result['words']: +# if not segments: +# segments.append({ +# 'speaker': f"SPEAKER_{word['speaker']}", +# 'start': word['start'], +# 'end': word['end'], +# 'text': word['word'], +# 'isUser': False +# }) +# else: +# last_segment = segments[-1] +# if last_segment['speaker'] == f"SPEAKER_{word['speaker']}": +# last_segment['text'] += f" {word['word']}" +# last_segment['end'] = word['end'] +# else: +# segments.append({ +# 'speaker': f"SPEAKER_{word['speaker']}", +# 'start': word['start'], +# 'end': word['end'], +# 'text': word['word'], +# 'isUser': False +# }) +# +# return segments + + +async def send_initial_file(data: List[List[int]], transcript_socket): print('send_initial_file2') start = time.time() # Reading and sending in chunks @@ -88,9 +75,6 @@ async def send_initial_file2(data: List[List[int]], transcript_socket): print('send_initial_file', time.time() - start) -# Add this new function to handle initial file sending - - deepgram = DeepgramClient(os.getenv('DEEPGRAM_API_KEY'), DeepgramClientOptions(options={"keepalive": "true"})) diff --git a/backend/utils/stt/vad.py b/backend/utils/stt/vad.py index dc45c4495..86ff7529d 100644 --- a/backend/utils/stt/vad.py +++ b/backend/utils/stt/vad.py @@ -1,17 +1,10 @@ -# import numpy as np import os from enum import Enum import requests import torch -from utils.endpoints import timeit - -# # Instantiate pretrained voice activity detection pipeline -# vad = Pipeline.from_pretrained( -# "pyannote/voice-activity-detection", -# use_auth_token=os.getenv('HUGGINGFACE_TOKEN') -# ) +from utils.other.endpoints import timeit torch.set_num_threads(1) torch.hub.set_dir('pretrained_models') @@ -33,87 +26,18 @@ def is_speech_present(data, vad_iterator, window_size_samples=256): return True return False - -def voice_in_bytes(data): - # Convert audio bytes to a numpy array - audio_array = np.frombuffer(data, dtype=np.int16) - - # Normalize audio to range [-1, 1] - audio_tensor = torch.from_numpy(audio_array).float() / 32768.0 - - # Ensure the audio is in the correct shape (batch_size, num_channels, num_samples) - audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0) - - # Pass the audio tensor to the VAD model - speech_timestamps = get_speech_timestamps(audio_tensor, model) - - # Check if there's voice in the audio - if speech_timestamps: - print("Voice detected in the audio.") - else: - print("No voice detected in the audio.") - - -# -# -# def speech_probabilities(file_path): -# SAMPLING_RATE = 8000 -# vad_iterator = VADIterator(model, sampling_rate=SAMPLING_RATE) -# wav = read_audio(file_path, sampling_rate=SAMPLING_RATE) -# speech_probs = [] -# window_size_samples = 512 if SAMPLING_RATE == 16000 else 256 -# for i in range(0, len(wav), 
window_size_samples): -# chunk = wav[i: i + window_size_samples] -# if len(chunk) < window_size_samples: -# break -# speech_prob = model(chunk, SAMPLING_RATE).item() -# speech_probs.append(speech_prob) -# vad_iterator.reset_states() # reset model states after each audio -# print(speech_probs[:10]) # first 10 chunks predicts -# -# @timeit def is_audio_empty(file_path, sample_rate=8000): wav = read_audio(file_path) timestamps = get_speech_timestamps(wav, model, sampling_rate=sample_rate) - # prob_no_speech = len(timestamps) == 1 and timestamps[0].duration < 1 if len(timestamps) == 1: prob_not_speech = ((timestamps[0]['end'] / 1000) - (timestamps[0]['start'] / 1000)) < 1 return prob_not_speech return len(timestamps) == 0 -# -# -# def retrieve_proper_segment_points(file_path, sample_rate=8000): -# wav = read_audio(file_path) -# speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=sample_rate) -# if not speech_timestamps: -# return [None, None] -# return [speech_timestamps[0]['start'] / 1000, speech_timestamps[-1]['end'] / 1000] - -# def retrieve_proper_segment_points_pyannote(file_path): -# output = vad(file_path) -# segments = output.get_timeline().support() -# has_speech = any(segments) -# if not has_speech: -# return [None, None] -# return [segments[0].start, segments[-1].end] - - -# TODO: improve VAD management in someway, mix with pipeline -# TODO: segments[0].duration < 1 makes sense? - -# @timeit -# def is_audio_empty(file_path, sample_rate=8000): -# output = vad(file_path) -# segments = output.get_timeline().support() -# has_speech = any(segments) -# prob_no_speech = len(segments) == 1 and segments[0].duration < 1 -# print('is_audio_empty:', not has_speech or prob_no_speech) -# return not has_speech or prob_no_speech - def vad_is_empty(file_path, return_segments: bool = False): + """Uses vad_modal/vad.py deployment (Best quality)""" try: with open(file_path, 'rb') as file: files = {'file': (file_path.split('/')[-1], file, 'audio/wav')} From 25f30d2938cbbb92284e9e47e60961ce7f6dc02a Mon Sep 17 00:00:00 2001 From: 0xzre Date: Mon, 19 Aug 2024 08:55:59 +0700 Subject: [PATCH 10/13] feat prespeech(500ms), change speechtimeout & threshold & VAD window size --- backend/routers/transcribe.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 416e53c57..8c5ea2c0d 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -51,7 +51,7 @@ async def _websocket_util( websocket_active = True duration = 0 is_speech_active = False - speech_timeout = 1.0 # Good for now (who doesnt like integer) but better dynamically adjust it by user behaviour, just idea: Increase as active time passes but until certain threshold, but not needed yet. + speech_timeout = 2.0 # Good for now (who doesnt like integer) but better dynamically adjust it by user behaviour, just idea: Increase as active time passes but until certain threshold, but not needed yet. 
last_speech_time = 0 try: if language == 'en' and codec == 'opus' and include_speech_profile: @@ -74,7 +74,7 @@ async def _websocket_util( await websocket.close() return - threshold = 0.6 + threshold = 0.7 vad_iterator = VADIterator(model, sampling_rate=sample_rate, threshold=threshold) window_size_samples = 256 if sample_rate == 8000 else 512 if codec == 'opus': @@ -82,13 +82,14 @@ async def _websocket_util( async def receive_audio(socket1, socket2): nonlocal is_speech_active, last_speech_time, decoder, websocket_active - audio_buffer = deque(maxlen=sample_rate * 1) # 1 secs - databuffer = bytearray(b"") REALTIME_RESOLUTION = 0.01 sample_width = 2 # pcm8/16 here is 16 bit byte_rate = sample_width * sample_rate * channels chunk_size = int(byte_rate * REALTIME_RESOLUTION) + audio_buffer = deque(maxlen=byte_rate * 1) # 1 secs + databuffer = bytearray(b"") + prespeech_audio = deque(maxlen=int(byte_rate * 0.5)) # Queue of audio that will included to data (sent to DG) when is_speech_active become True timer_start = time.time() audio_cursor = 0 # For sleep realtime logic @@ -108,8 +109,11 @@ async def receive_audio(socket1, socket2): if len(audio_buffer) >= window_size_samples: tensor_audio = torch.tensor(list(audio_buffer)) # Good alr, but increase the window size to get wider context but server will be slower - if is_speech_present(tensor_audio[-window_size_samples:], vad_iterator, window_size_samples): + if is_speech_present(tensor_audio[-window_size_samples * 4:], vad_iterator, window_size_samples): if not is_speech_active: + for audio in prespeech_audio: + databuffer.extend(audio.int().numpy().tobytes()) + prespeech_audio.clear() print('+Detected speech') is_speech_active = True last_speech_time = time.time() @@ -119,9 +123,11 @@ async def receive_audio(socket1, socket2): # Reset only happens after the speech timeout # Reason : Better to carry vad context for a speech, then reset for any new speech vad_iterator.reset_states() + prespeech_audio.extend(samples) print('-NO Detected speech') continue else: + prespeech_audio.extend(samples) continue elapsed_seconds = time.time() - timer_start From 52fd3403bf01732bfb323691223fc4de01c620fc Mon Sep 17 00:00:00 2001 From: 0xzre Date: Mon, 19 Aug 2024 09:26:12 +0700 Subject: [PATCH 11/13] more accuracy on speech timeout --- backend/routers/transcribe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 8c5ea2c0d..a7a37ad85 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -96,6 +96,7 @@ async def receive_audio(socket1, socket2): try: while websocket_active: data = await websocket.receive_bytes() + recv_time = time.time() if codec == 'opus': decoded_opus = decoder.decode(data, frame_size=320) samples = torch.frombuffer(decoded_opus, dtype=torch.int16).float() / 32768.0 @@ -118,7 +119,7 @@ async def receive_audio(socket1, socket2): is_speech_active = True last_speech_time = time.time() elif is_speech_active: - if time.time() - last_speech_time > speech_timeout: + if recv_time - last_speech_time > speech_timeout: is_speech_active = False # Reset only happens after the speech timeout # Reason : Better to carry vad context for a speech, then reset for any new speech From 2f0f3d81e96e035f5b5e499ce8a6091ac7201584 Mon Sep 17 00:00:00 2001 From: 0xzre Date: Mon, 19 Aug 2024 15:13:46 +0700 Subject: [PATCH 12/13] only using VAD on pcm8/16 now --- backend/routers/transcribe.py | 64 ++++++++++++++++++----------------- 1 file changed, 33 
insertions(+), 31 deletions(-) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index a7a37ad85..7a96c76a5 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -6,7 +6,7 @@ from starlette.websockets import WebSocketState import torch from collections import deque -import opuslib +# import opuslib from database.redis_db import get_user_speech_profile, get_user_speech_profile_duration from utils.stt.streaming import process_audio_dg, send_initial_file @@ -77,11 +77,12 @@ async def _websocket_util( threshold = 0.7 vad_iterator = VADIterator(model, sampling_rate=sample_rate, threshold=threshold) window_size_samples = 256 if sample_rate == 8000 else 512 - if codec == 'opus': - decoder = opuslib.Decoder(sample_rate, channels) + # if codec == 'opus': + # decoder = opuslib.Decoder(sample_rate, channels) async def receive_audio(socket1, socket2): - nonlocal is_speech_active, last_speech_time, decoder, websocket_active + nonlocal is_speech_active, last_speech_time, websocket_active + # nonlocal decoder REALTIME_RESOLUTION = 0.01 sample_width = 2 # pcm8/16 here is 16 bit @@ -98,43 +99,45 @@ async def receive_audio(socket1, socket2): data = await websocket.receive_bytes() recv_time = time.time() if codec == 'opus': - decoded_opus = decoder.decode(data, frame_size=320) - samples = torch.frombuffer(decoded_opus, dtype=torch.int16).float() / 32768.0 + # decoded_opus = decoder.decode(data, frame_size=320) + # samples = torch.frombuffer(decoded_opus, dtype=torch.int16).float() / 32768.0 + pass elif codec in ['pcm8', 'pcm16']: # Both are 16 bit writable_data = bytearray(data) samples = torch.frombuffer(writable_data, dtype=torch.int16).float() / 32768.0 else: raise ValueError(f"Unsupported codec: {codec}") - - audio_buffer.extend(samples) - if len(audio_buffer) >= window_size_samples: - tensor_audio = torch.tensor(list(audio_buffer)) - # Good alr, but increase the window size to get wider context but server will be slower - if is_speech_present(tensor_audio[-window_size_samples * 4:], vad_iterator, window_size_samples): - if not is_speech_active: - for audio in prespeech_audio: - databuffer.extend(audio.int().numpy().tobytes()) - prespeech_audio.clear() - print('+Detected speech') - is_speech_active = True - last_speech_time = time.time() - elif is_speech_active: - if recv_time - last_speech_time > speech_timeout: - is_speech_active = False - # Reset only happens after the speech timeout - # Reason : Better to carry vad context for a speech, then reset for any new speech - vad_iterator.reset_states() + # FIXME: opuslib is not working, so we are not using it + if codec != 'opus': + audio_buffer.extend(samples) + if len(audio_buffer) >= window_size_samples: + tensor_audio = torch.tensor(list(audio_buffer)) + # Good alr, but increase the window size to get wider context but server will be slower + if is_speech_present(tensor_audio[-window_size_samples * 4:], vad_iterator, window_size_samples): + if not is_speech_active: + for audio in prespeech_audio: + databuffer.extend(audio.int().numpy().tobytes()) + prespeech_audio.clear() + print('+Detected speech') + is_speech_active = True + last_speech_time = time.time() + elif is_speech_active: + if recv_time - last_speech_time > speech_timeout: + is_speech_active = False + # Reset only happens after the speech timeout + # Reason : Better to carry vad context for a speech, then reset for any new speech + vad_iterator.reset_states() + prespeech_audio.extend(samples) + print('-NO Detected speech') + continue + else: 
prespeech_audio.extend(samples) - print('-NO Detected speech') continue - else: - prespeech_audio.extend(samples) - continue elapsed_seconds = time.time() - timer_start if elapsed_seconds > duration or not socket2: databuffer.extend(data) - if len(databuffer) >= chunk_size: + if len(databuffer) >= chunk_size or codec == 'opus': # Sleep logic, because naive sleep is not accurate current_time = time.time() elapsed_time = current_time - timer_start @@ -167,7 +170,6 @@ async def send_heartbeat(): try: while websocket_active: await asyncio.sleep(30) - # print('send_heartbeat') if websocket.client_state == WebSocketState.CONNECTED: await websocket.send_json({"type": "ping"}) else: From 4aa433f60c6c6fe52999b4017a0f7639551b9f27 Mon Sep 17 00:00:00 2001 From: 0xzre Date: Mon, 19 Aug 2024 15:43:15 +0700 Subject: [PATCH 13/13] remove sentece log --- backend/utils/stt/streaming.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/utils/stt/streaming.py b/backend/utils/stt/streaming.py index ad76f0b87..ca9661683 100644 --- a/backend/utils/stt/streaming.py +++ b/backend/utils/stt/streaming.py @@ -88,7 +88,7 @@ async def process_audio_dg( def on_message(self, result, **kwargs): # print(f"Received message from Deepgram") # Log when message is received sentence = result.channel.alternatives[0].transcript - print(sentence) + # print(sentence) if len(sentence) == 0: return # print(sentence)
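
A condensed sketch of the VAD gating these patches converge on (PCM paths only; the Opus path bypasses the gate while opuslib is disabled). VADIterator, model, and is_speech_present are the helpers from backend/utils/stt/vad.py shown above, and the constants mirror the series' defaults (threshold 0.7, 256-sample window at 8 kHz, 0.5 s pre-speech buffer, 2.0 s speech timeout); the VadGate class and its feed() method are illustrative, not code from this series.

    # Illustrative only -- not part of this series. Assumes the helpers from
    # backend/utils/stt/vad.py shown in the diffs above.
    import time
    from collections import deque
    from typing import Optional

    import torch

    from utils.stt.vad import VADIterator, model, is_speech_present

    SAMPLE_RATE = 8000
    WINDOW_SIZE_SAMPLES = 256          # 256 @ 8 kHz, 512 @ 16 kHz
    SPEECH_TIMEOUT = 2.0               # seconds of silence before the gate closes
    PRESPEECH_SECONDS = 0.5            # audio replayed from just before speech onset


    class VadGate:
        """Decides which PCM16 chunks are forwarded to the streaming STT socket."""

        def __init__(self):
            self.vad_iterator = VADIterator(model, sampling_rate=SAMPLE_RATE, threshold=0.7)
            self.audio_buffer = deque(maxlen=SAMPLE_RATE)                       # ~1 s of samples
            self.prespeech = deque(maxlen=int(SAMPLE_RATE * PRESPEECH_SECONDS))
            self.is_speech_active = False
            self.last_speech_time = 0.0

        def feed(self, pcm16_bytes: bytes) -> Optional[bytes]:
            """Returns bytes to forward, or None while the gate is closed."""
            samples = torch.frombuffer(bytearray(pcm16_bytes), dtype=torch.int16).float() / 32768.0
            self.audio_buffer.extend(samples)
            if len(self.audio_buffer) < WINDOW_SIZE_SAMPLES:
                return None

            window = torch.tensor(list(self.audio_buffer))[-WINDOW_SIZE_SAMPLES * 4:]
            now = time.time()
            if is_speech_present(window, self.vad_iterator, WINDOW_SIZE_SAMPLES):
                prefix = b""
                if not self.is_speech_active:
                    # Flush the pre-speech ring buffer so the first word is not clipped.
                    prefix = (torch.tensor(list(self.prespeech)) * 32768.0).to(torch.int16).numpy().tobytes()
                    self.prespeech.clear()
                self.is_speech_active = True
                self.last_speech_time = now
                return prefix + pcm16_bytes

            if self.is_speech_active and now - self.last_speech_time <= SPEECH_TIMEOUT:
                # Short pause inside an utterance: keep forwarding.
                return pcm16_bytes

            if self.is_speech_active:
                # Silence outlasted the timeout: reset VAD state for the next utterance.
                self.vad_iterator.reset_states()
                self.is_speech_active = False
            self.prespeech.extend(samples)
            return None

Flushing the pre-speech buffer on the no-speech to speech transition is what keeps the first half second of an utterance from being clipped, while long silences never reach Deepgram.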
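
PATCH 10 also paces buffered audio so it is never pushed to Deepgram faster than real time; the sleep arithmetic is only partly visible in the hunks above, so the loop below is an assumed reconstruction of that pattern. byte_rate and REALTIME_RESOLUTION come from the diff, while drain_paced and the send callable are hypothetical stand-ins for the in-function logic and socket1.send.

    # Hypothetical pacing loop: the buffer is drained in 10 ms slices and the
    # loop sleeps whenever the audio cursor runs ahead of wall-clock time.
    import time

    REALTIME_RESOLUTION = 0.01  # seconds of audio per send


    def drain_paced(databuffer: bytearray, byte_rate: int, send,
                    timer_start: float, audio_cursor: float) -> float:
        chunk_size = int(byte_rate * REALTIME_RESOLUTION)
        while len(databuffer) >= chunk_size:
            elapsed = time.time() - timer_start
            if audio_cursor > elapsed:
                # Audio is ahead of real time: wait out the difference.
                time.sleep(audio_cursor - elapsed)
            send(bytes(databuffer[:chunk_size]))
            del databuffer[:chunk_size]
            audio_cursor += REALTIME_RESOLUTION
        return audio_cursor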
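
fal_whisperx (utils/stt/pre_recorded.py after the rename) folds consecutive chunks from the same speaker into one segment before building TranscriptSegment objects; the merge condition is elided in the hunk above, so this helper is a guess at that step, with merge_same_speaker as a hypothetical name and plain dicts standing in for the Whisper output.

    # Assumed reconstruction of the same-speaker merge; the real condition may
    # differ. Chunks are dicts with 'speaker', 'text', 'start', and 'end' keys.
    from typing import Dict, List


    def merge_same_speaker(chunks: List[Dict]) -> List[Dict]:
        cleaned: List[Dict] = []
        for chunk in chunks:
            if cleaned and cleaned[-1]['speaker'] == chunk['speaker']:
                # Same speaker kept talking: extend the previous segment.
                cleaned[-1]['text'] += ' ' + chunk['text']
                cleaned[-1]['end'] = chunk['end']
            else:
                cleaned.append(dict(chunk))
        return cleaned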