diff --git a/.gitignore b/.gitignore
index e49c2b381..0d1cd6f1e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -153,4 +153,9 @@ app/lib/env/prod_env.g.dart
 /backend/scripts/research/*.md
 /backend/scripts/research/*.json
 /backend/scripts/research/*.csv
-/backend/scripts/research/users
\ No newline at end of file
+/backend/scripts/research/users
+/backend/scripts/stt/_temp
+/backend/scripts/stt/_temp2
+/backend/scripts/stt/pretrained_models
+/backend/scripts/stt/results
+/backend/scripts/stt/diarization.json
diff --git a/app/lib/backend/auth.dart b/app/lib/backend/auth.dart
index d1ed00dab..837b97df2 100644
--- a/app/lib/backend/auth.dart
+++ b/app/lib/backend/auth.dart
@@ -110,7 +110,7 @@ Future signInWithGoogle() async {
   try {
     print('Signing in with Google');
     // Trigger the authentication flow
-    final GoogleSignInAccount? googleUser = await GoogleSignIn().signIn();
+    final GoogleSignInAccount? googleUser = await GoogleSignIn(scopes: ['profile', 'email']).signIn();
     print('Google User: $googleUser');
     // Obtain the auth details from the request
    final GoogleSignInAuthentication? googleAuth = await googleUser?.authentication;
diff --git a/app/lib/backend/http/api/memories.dart b/app/lib/backend/http/api/memories.dart
index d70fcc05b..3ca3eb8b9 100644
--- a/app/lib/backend/http/api/memories.dart
+++ b/app/lib/backend/http/api/memories.dart
@@ -188,11 +188,13 @@ class TranscriptsResponse {
   List<TranscriptSegment> deepgram;
   List<TranscriptSegment> soniox;
   List<TranscriptSegment> whisperx;
+  List<TranscriptSegment> speechmatics;
 
   TranscriptsResponse({
     this.deepgram = const [],
     this.soniox = const [],
     this.whisperx = const [],
+    this.speechmatics = const [],
   });
 
   factory TranscriptsResponse.fromJson(Map<String, dynamic> json) {
@@ -200,6 +202,8 @@ class TranscriptsResponse {
       deepgram: (json['deepgram'] as List).map((segment) => TranscriptSegment.fromJson(segment)).toList(),
       soniox: (json['soniox'] as List).map((segment) => TranscriptSegment.fromJson(segment)).toList(),
       whisperx: (json['whisperx'] as List).map((segment) => TranscriptSegment.fromJson(segment)).toList(),
+      speechmatics:
+          (json['speechmatics'] as List).map((segment) => TranscriptSegment.fromJson(segment)).toList(),
     );
   }
 }
diff --git a/app/lib/backend/http/api/messages.dart b/app/lib/backend/http/api/messages.dart
index ea33c8362..5eb6f2747 100644
--- a/app/lib/backend/http/api/messages.dart
+++ b/app/lib/backend/http/api/messages.dart
@@ -23,6 +23,16 @@ Future<List<ServerMessage>> getMessagesServer() async {
   return [];
 }
 
+Future<List<ServerMessage>> clearChatServer() async {
+  var response = await makeApiCall(url: '${Env.apiBaseUrl}v1/clear-chat', headers: {}, method: 'DELETE', body: '');
+  if (response == null) throw Exception('Failed to delete chat');
+  if (response.statusCode == 200) {
+    return [ServerMessage.fromJson(jsonDecode(response.body))];
+  } else {
+    throw Exception('Failed to delete chat');
+  }
+}
+
 Future sendMessageServer(String text, {String? pluginId}) {
   return makeApiCall(
     url: '${Env.apiBaseUrl}v1/messages?plugin_id=$pluginId',
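[Review note (not part of the patch)] clearChatServer() pairs a DELETE to v1/clear-chat with the server's replacement greeting message. A minimal sketch of how a provider might consume it, assuming a MessageProvider that owns a `messages` list and an `isClearingChat` flag (both names appear in the ChatPage changes later in this diff; the method body here is illustrative, not the app's actual implementation):

    Future<void> clearChat() async {
      isClearingChat = true; // drives the "Deleting your messages..." spinner in ChatPage
      notifyListeners();
      try {
        // Server wipes history and returns the fresh initial message(s).
        messages = await clearChatServer();
      } finally {
        isClearingChat = false;
        notifyListeners();
      }
    }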
diff --git a/app/lib/backend/preferences.dart b/app/lib/backend/preferences.dart
index e31dd9fcc..373002b49 100644
--- a/app/lib/backend/preferences.dart
+++ b/app/lib/backend/preferences.dart
@@ -47,6 +47,8 @@ class SharedPreferencesUtil {
 
   set deviceCodec(BleAudioCodec value) => saveString('deviceCodec', mapCodecToName(value));
 
+  Future setDeviceCodec(BleAudioCodec value) => saveString('deviceCodec', mapCodecToName(value));
+
   BleAudioCodec get deviceCodec => mapNameToCodec(getString('deviceCodec') ?? '');
 
   String get openAIApiKey => getString('openaiApiKey') ?? '';
diff --git a/app/lib/backend/schema/bt_device.dart b/app/lib/backend/schema/bt_device.dart
index b14fb7f7e..999518f0a 100644
--- a/app/lib/backend/schema/bt_device.dart
+++ b/app/lib/backend/schema/bt_device.dart
@@ -3,7 +3,17 @@ import 'package:friend_private/services/device_connections.dart';
 import 'package:friend_private/services/frame_connection.dart';
 import 'package:friend_private/utils/ble/gatt_utils.dart';
 
-enum BleAudioCodec { pcm16, pcm8, mulaw16, mulaw8, opus, unknown }
+enum BleAudioCodec {
+  pcm16,
+  pcm8,
+  mulaw16,
+  mulaw8,
+  opus,
+  unknown;
+
+  @override
+  String toString() => mapCodecToName(this);
+}
 
 String mapCodecToName(BleAudioCodec codec) {
   switch (codec) {
diff --git a/app/lib/backend/schema/transcript_segment.dart b/app/lib/backend/schema/transcript_segment.dart
index 47f4b1398..3c4eff301 100644
--- a/app/lib/backend/schema/transcript_segment.dart
+++ b/app/lib/backend/schema/transcript_segment.dart
@@ -129,6 +129,26 @@ class TranscriptSegment {
     cleanSegments(joinedSimilarSegments);
 
     segments.addAll(joinedSimilarSegments);
+
+    // for i, segment in enumerate(segments):
+    //     segments[i].text = (
+    //         segments[i].text.strip()
+    //         .replace('  ', ' ')
+    //         .replace(' ,', ',')
+    //         .replace(' .', '.')
+    //         .replace(' ?', '?')
+    //     )
+
+    // Speechmatics specific issue with punctuation
+    for (var i = 0; i < segments.length; i++) {
+      segments[i].text = segments[i]
+          .text
+          .replaceAll('  ', ' ')
+          .replaceAll(' ,', ',')
+          .replaceAll(' .', '.')
+          .replaceAll(' ?', '?')
+          .trim();
+    }
   }
 
   static String segmentsAsString(
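[Review note (not part of the patch)] Speechmatics emits a space before punctuation tokens; the loop above first collapses double spaces, then glues punctuation back onto the preceding word. A standalone check of the same chain in plain Dart:

    void main() {
      const raw = 'Hello , world .  How are you ?';
      final cleaned = raw
          .replaceAll('  ', ' ')   // collapse double spaces
          .replaceAll(' ,', ',')   // " ,"  -> ","
          .replaceAll(' .', '.')   // " ."  -> "."
          .replaceAll(' ?', '?')   // " ?"  -> "?"
          .trim();
      print(cleaned); // Hello, world. How are you?
    }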
diff --git a/app/lib/env/env.dart b/app/lib/env/env.dart
index d68f2fee4..8a14a8046 100644
--- a/app/lib/env/env.dart
+++ b/app/lib/env/env.dart
@@ -13,10 +13,10 @@ abstract class Env {
 
   static String? get mixpanelProjectToken => _instance.mixpanelProjectToken;
 
-  // static String? get apiBaseUrl => _instance.apiBaseUrl;
+  static String? get apiBaseUrl => _instance.apiBaseUrl;
 
   // static String? get apiBaseUrl => 'https://based-hardware-development--backened-dev-api.modal.run/';
-  static String? get apiBaseUrl => 'https://camel-lucky-reliably.ngrok-free.app/';
+  // static String? get apiBaseUrl => 'https://camel-lucky-reliably.ngrok-free.app/';
   // static String? get apiBaseUrl => 'https://mutual-fun-boar.ngrok-free.app/';
 
   static String? get growthbookApiKey => _instance.growthbookApiKey;
diff --git a/app/lib/main.dart b/app/lib/main.dart
index e55d2bd69..6d0f924c6 100644
--- a/app/lib/main.dart
+++ b/app/lib/main.dart
@@ -164,16 +164,15 @@ class _MyAppState extends State<MyApp> with WidgetsBindingObserver {
           update: (BuildContext context, value, MessageProvider? previous) =>
               (previous?..updatePluginProvider(value)) ?? MessageProvider(),
         ),
-        ChangeNotifierProvider(create: (context) => WebSocketProvider()),
-        ChangeNotifierProxyProvider3<MemoryProvider, MessageProvider, WebSocketProvider, CaptureProvider>(
+        ChangeNotifierProxyProvider2<MemoryProvider, MessageProvider, CaptureProvider>(
          create: (context) => CaptureProvider(),
-          update: (BuildContext context, memory, message, wsProvider, CaptureProvider? previous) =>
-              (previous?..updateProviderInstances(memory, message, wsProvider)) ?? CaptureProvider(),
+          update: (BuildContext context, memory, message, CaptureProvider? previous) =>
+              (previous?..updateProviderInstances(memory, message)) ?? CaptureProvider(),
         ),
-        ChangeNotifierProxyProvider2<CaptureProvider, WebSocketProvider, DeviceProvider>(
+        ChangeNotifierProxyProvider<CaptureProvider, DeviceProvider>(
           create: (context) => DeviceProvider(),
-          update: (BuildContext context, captureProvider, wsProvider, DeviceProvider? previous) =>
-              (previous?..setProviders(captureProvider, wsProvider)) ?? DeviceProvider(),
+          update: (BuildContext context, captureProvider, DeviceProvider? previous) =>
+              (previous?..setProviders(captureProvider)) ?? DeviceProvider(),
         ),
         ChangeNotifierProxyProvider<DeviceProvider, OnboardingProvider>(
           create: (context) => OnboardingProvider(),
@@ -181,10 +180,10 @@ class _MyAppState extends State<MyApp> with WidgetsBindingObserver {
               (previous?..setDeviceProvider(value)) ?? OnboardingProvider(),
         ),
         ListenableProvider(create: (context) => HomeProvider()),
-        ChangeNotifierProxyProvider3<DeviceProvider, CaptureProvider, WebSocketProvider, SpeechProfileProvider>(
+        ChangeNotifierProxyProvider<DeviceProvider, SpeechProfileProvider>(
           create: (context) => SpeechProfileProvider(),
-          update: (BuildContext context, device, capture, wsProvider, SpeechProfileProvider? previous) =>
-              (previous?..setProviders(device, capture, wsProvider)) ?? SpeechProfileProvider(),
+          update: (BuildContext context, device, SpeechProfileProvider? previous) =>
+              (previous?..setProviders(device)) ?? SpeechProfileProvider(),
         ),
         ChangeNotifierProxyProvider2(
           create: (context) => MemoryDetailProvider(),
@@ -275,6 +274,11 @@ class _DeciderWidgetState extends State<DeciderWidget> {
       if (context.read<ConnectivityProvider>().isConnected) {
        NotificationService.instance.saveNotificationToken();
       }
+
+      if (context.read().user != null) {
+        context.read<MessageProvider>().setMessagesFromCache();
+        context.read<MessageProvider>().refreshMessages();
+      }
     });
     super.initState();
   }
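[Review note (not part of the patch)] With WebSocketProvider removed from the graph, each dependent drops one type parameter (ProxyProvider3 becomes ProxyProvider2, and so on). The `update` callback receives the current dependencies plus the previously built instance; the `(previous?..method(deps)) ?? Ctor()` idiom re-injects dependencies into the existing instance and only constructs on first build. The generic shape of the pattern, with hypothetical names (provider package, not specific to this app):

    ChangeNotifierProxyProvider<DepProvider, OwnProvider>(
      create: (context) => OwnProvider(), // built once
      update: (context, dep, previous) =>
          (previous?..updateDependency(dep)) ?? OwnProvider(), // reuse, re-inject
    )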
diff --git a/app/lib/pages/capture/_page.dart b/app/lib/pages/capture/_page.dart
index b392425d0..75d10b227 100644
--- a/app/lib/pages/capture/_page.dart
+++ b/app/lib/pages/capture/_page.dart
@@ -1,22 +1,6 @@
 import 'package:flutter/material.dart';
-import 'package:flutter/scheduler.dart';
-import 'package:flutter_foreground_task/flutter_foreground_task.dart';
-import 'package:flutter_provider_utilities/flutter_provider_utilities.dart';
-import 'package:friend_private/backend/schema/bt_device.dart';
-import 'package:friend_private/backend/schema/geolocation.dart';
-import 'package:friend_private/pages/capture/widgets/widgets.dart';
-import 'package:friend_private/providers/capture_provider.dart';
-import 'package:friend_private/providers/connectivity_provider.dart';
-import 'package:friend_private/providers/device_provider.dart';
-import 'package:friend_private/providers/onboarding_provider.dart';
-import 'package:friend_private/utils/audio/wav_bytes.dart';
-import 'package:friend_private/utils/ble/communication.dart';
-import 'package:friend_private/utils/enums.dart';
-import 'package:friend_private/widgets/dialog.dart';
-import 'package:provider/provider.dart';
-
-import '../../providers/websocket_provider.dart';
 
+@Deprecated("Capture page is deprecated, use @pages > memories > widgets > capture instead.")
 class CapturePage extends StatefulWidget {
   const CapturePage({
     super.key,
@@ -26,252 +10,9 @@ class CapturePage extends StatefulWidget {
   State<CapturePage> createState() => CapturePageState();
 }
 
-class CapturePageState extends State<CapturePage> with AutomaticKeepAliveClientMixin, WidgetsBindingObserver {
-  @override
-  bool get wantKeepAlive => true;
-
-  /// ----
-
-  // List segments = List.filled(100, '')
-  //     .mapIndexed((i, e) => TranscriptSegment(
-  //           text:
-  //               '''[00:00:00 - 00:02:23] Speaker 0: The tech giants already know these techniques.
-  //               My goal is to unlock their secrets for the benefit of businesses who to design and help users develop healthy habits.
-  //               To that end, there's so much I wanted to put in this book that just didn't fit. Before you reading, please take a moment to download these
-  //               supplementary materials included free with the purchase of this audiobook. Please go to nirandfar.com forward slash hooked.
-  //               Near is spelled like my first name, speck, n I r. Andfar.com/hooked. There you will find the hooked model workbook, an ebook of case studies,
-  //               and a free email course about product psychology. Also, if you'd like to connect with me, you can reach me through my blog at nirafar.com.
-  //               You can schedule office hours to discuss your questions. Look forward to hearing from you as you build habits for good.
-  //
-  //               Introduction. 79% of smartphone owners check their device within 15 minutes of waking up every morning. Perhaps most startling,
-  //               fully 1 third of Americans say they would rather give up sex than lose their cell phones. A 2011 university study suggested people check their
-  //               phones 34 times per day. However, industry insiders believe that number is closer to an astounding 150 daily sessions. We are hooked.
-  //               It's the poll to visit YouTube, Facebook, or Twitter for just a few minutes only to find yourself still capping and scrolling an hour later.
-  //               It's the urge you likely feel throughout your day but hardly notice. Cognitive psychologists define habits as, quote, automatic behaviors triggered
-  //               by situational cues. Things we do with little or no conscious thought. The products and services we use habitually alter our everyday behavior.
-  //               Just as their designers intended. Our actions have been engineered. How do companies producing little more than bits of code displayed on a screen
-  //               seemingly control users' minds? What makes some products so habit forming? Forming habit is imperative for the survival of many products.
-  //
-  //               As infinite distractions compete for our attention, companies are learning to master novel tactics that stay relevant in users' minds.
-  //               Amassing millions of users is no longer good enough. Companies increasingly find that their economic value is a function of the strength of the habits they create.
-  //
-  //               In order to win the loyalty of their users and create a product that's regularly used, companies must learn not only what compels users to click,
-  //               but also what makes them click. Although some companies are just waking up to this new reality, others are already cashing in. By mastering habit
-  //               forming product design, companies profiles in this book make their goods indispensable. First to mind wins. Companies that form strong user habits enjoy
-  //               several benefits to their bottom line. These companies attach their product to internal triggers. A result, users show up without any external prompting.
-  //               Instead of relying on expensive marketing, how did forming companies link their services to users' daily routines and emotions.
-  //               A habit is at work when users feel a tad bored and instantly open Twitter. Feel a hang of loneliness, and before rational thought occurs,
-  //               they're scrolling through their Facebook feeds.''',
-  //           speaker: 'SPEAKER_0${i % 2}',
-  //           isUser: false,
-  //           start: 0,
-  //           end: 10,
-  //         ))
-  //     .toList();
-
-  setHasTranscripts(bool hasTranscripts) {
-    context.read<CaptureProvider>().setHasTranscripts(hasTranscripts);
-  }
-
-  void _onReceiveTaskData(dynamic data) {
-    if (data is Map) {
-      if (data.containsKey('latitude') && data.containsKey('longitude')) {
-        context.read<CaptureProvider>().setGeolocation(Geolocation(
-              latitude: data['latitude'],
-              longitude: data['longitude'],
-              accuracy: data['accuracy'],
-              altitude: data['altitude'],
-              time: DateTime.parse(data['time']),
-            ));
-      } else {
-        if (mounted) {
-          context.read<CaptureProvider>().setGeolocation(null);
-        }
-      }
-    }
-  }
-
-  @override
-  void initState() {
-    WavBytesUtil.clearTempWavFiles();
-
-    FlutterForegroundTask.addTaskDataCallback(_onReceiveTaskData);
-    WidgetsBinding.instance.addObserver(this);
-    SchedulerBinding.instance.addPostFrameCallback((_) async {
-      // await context.read<CaptureProvider>().processCachedTranscript();
-      if (context.read<DeviceProvider>().connectedDevice != null) {
-        context.read<OnboardingProvider>().stopFindDeviceTimer();
-      }
-      // if (await LocationService().displayPermissionsDialog()) {
-      //   await showDialog(
-      //     context: context,
-      //     builder: (c) => getDialog(
-      //       context,
-      //       () => Navigator.of(context).pop(),
-      //       () async {
-      //         await requestLocationPermission();
-      //         await LocationService().requestBackgroundPermission();
-      //         if (mounted) Navigator.of(context).pop();
-      //       },
-      //       'Enable Location? 🌍',
-      //       'Allow location access to tag your memories. Set to "Always Allow" in Settings',
-      //       singleButton: false,
-      //       okButtonText: 'Continue',
-      //     ),
-      //   );
-      // }
-      final connectivityProvider = Provider.of<ConnectivityProvider>(context, listen: false);
-      if (!connectivityProvider.isConnected) {
-        context.read<CaptureProvider>().cancelMemoryCreationTimer();
-      }
-    });
-
-    super.initState();
-  }
-
-  @override
-  void dispose() {
-    WidgetsBinding.instance.removeObserver(this);
-    // context.read().closeWebSocket();
-    super.dispose();
-  }
-
-  // Future requestLocationPermission() async {
-  //   LocationService locationService = LocationService();
-  //   bool serviceEnabled = await locationService.enableService();
-  //   if (!serviceEnabled) {
-  //     debugPrint('Location service not enabled');
-  //     if (mounted) {
-  //       ScaffoldMessenger.of(context).showSnackBar(
-  //         const SnackBar(
-  //           content: Text(
-  //             'Location services are disabled. Enable them for a better experience.',
-  //             style: TextStyle(color: Colors.white, fontSize: 14),
-  //           ),
-  //         ),
-  //       );
-  //     }
-  //   } else {
-  //     PermissionStatus permissionGranted = await locationService.requestPermission();
-  //     SharedPreferencesUtil().locationEnabled = permissionGranted == PermissionStatus.granted;
-  //     MixpanelManager().setUserProperty('Location Enabled', SharedPreferencesUtil().locationEnabled);
-  //     if (permissionGranted == PermissionStatus.denied) {
-  //       debugPrint('Location permission not granted');
-  //     } else if (permissionGranted == PermissionStatus.deniedForever) {
-  //       debugPrint('Location permission denied forever');
-  //       if (mounted) {
-  //         ScaffoldMessenger.of(context).showSnackBar(
-  //           const SnackBar(
-  //             content: Text(
-  //               'If you change your mind, you can enable location services in your device settings.',
-  //               style: TextStyle(color: Colors.white, fontSize: 14),
-  //             ),
-  //           ),
-  //         );
-  //       }
-  //     }
-  //   }
-  // }
-
+class CapturePageState extends State<CapturePage> {
   @override
   Widget build(BuildContext context) {
-    super.build(context);
-    return Consumer2<CaptureProvider, DeviceProvider>(builder: (context, provider, deviceProvider, child) {
-      return MessageListener<CaptureProvider>(
-        showInfo: (info) {
-          // This probably will never be called because this has been handled even before we start the audio stream. But it's here just in case.
-          if (info == 'FIM_CHANGE') {
-            showDialog(
-              context: context,
-              barrierDismissible: false,
-              builder: (c) => getDialog(
-                context,
-                () async {
-                  context.read<WebSocketProvider>().closeWebSocketWithoutReconnect('Firmware change detected');
-                  var connectedDevice = deviceProvider.connectedDevice;
-                  var codec = await getAudioCodec(connectedDevice!.id);
-                  context.read<CaptureProvider>().resetState(restartBytesProcessing: true);
-                  context.read<CaptureProvider>().initiateWebsocket(codec);
-                  if (Navigator.canPop(context)) {
-                    Navigator.pop(context);
-                  }
-                },
-                () => {},
-                'Firmware change detected!',
-                'You are currently using a different firmware version than the one you were using before. Please restart the app to apply the changes.',
-                singleButton: true,
-                okButtonText: 'Restart',
-              ),
-            );
-          }
-        },
-        showError: (error) {
-          ScaffoldMessenger.of(context).showSnackBar(
-            SnackBar(
-              content: Text(
-                error,
-                style: const TextStyle(color: Colors.white, fontSize: 14),
-              ),
-            ),
-          );
-        },
-        child: Stack(
-          children: [
-            ListView(children: [
-              SpeechProfileCardWidget(),
-              ...getConnectionStateWidgets(
-                context,
-                provider.hasTranscripts,
-                deviceProvider.connectedDevice,
-                context.read<WebSocketProvider>().wsConnectionState,
-              ),
-              getTranscriptWidget(
-                provider.memoryCreating,
-                provider.segments,
-                provider.photos,
-                deviceProvider.connectedDevice,
-              ),
-              ...connectionStatusWidgets(
-                context,
-                provider.segments,
-                context.read<WebSocketProvider>().wsConnectionState,
-              ),
-              const SizedBox(height: 16)
-            ]),
-            getPhoneMicRecordingButton(() => _recordingToggled(provider), provider.recordingState),
-          ],
-        ),
-      );
-    });
-  }
-
-  _recordingToggled(CaptureProvider provider) async {
-    var recordingState = provider.recordingState;
-    if (recordingState == RecordingState.record) {
-      provider.stopStreamRecording();
-      provider.updateRecordingState(RecordingState.stop);
-      context.read<CaptureProvider>().cancelMemoryCreationTimer();
-      // await context.read<CaptureProvider>().tryCreateMemoryManually();
-    } else if (recordingState == RecordingState.initialising) {
-      debugPrint('initialising, have to wait');
-    } else {
-      showDialog(
-        context: context,
-        builder: (c) => getDialog(
-          context,
-          () => Navigator.pop(context),
-          () async {
-            provider.updateRecordingState(RecordingState.initialising);
-            context.read<WebSocketProvider>().closeWebSocketWithoutReconnect('Recording with phone mic');
-            await provider.initiateWebsocket(BleAudioCodec.pcm16, 16000);
-            await provider.streamRecording();
-            Navigator.pop(context);
-          },
-          'Limited Capabilities',
-          'Recording with your phone microphone has a few limitations, including but not limited to: speaker profiles, background reliability.',
-          okButtonText: 'Ok, I understand',
-        ),
-      );
-    }
+    return const Text("Deprecated");
   }
 }
diff --git a/app/lib/pages/capture/widgets/widgets.dart b/app/lib/pages/capture/widgets/widgets.dart
index ce16dbd54..d131f5ae7 100644
--- a/app/lib/pages/capture/widgets/widgets.dart
+++ b/app/lib/pages/capture/widgets/widgets.dart
@@ -4,6 +4,7 @@ import 'package:friend_private/backend/schema/bt_device.dart';
 import 'package:friend_private/backend/schema/transcript_segment.dart';
 import 'package:friend_private/pages/capture/connect.dart';
 import 'package:friend_private/pages/speech_profile/page.dart';
+import 'package:friend_private/providers/capture_provider.dart';
 import 'package:friend_private/providers/connectivity_provider.dart';
 import 'package:friend_private/providers/device_provider.dart';
 import 'package:friend_private/providers/home_provider.dart';
@@ -227,8 +228,7 @@ class SpeechProfileCardWidget extends StatelessWidget {
           await routeToPage(context, const SpeechProfilePage());
           if (hasSpeakerProfile != SharedPreferencesUtil().hasSpeakerProfile) {
             if (context.mounted) {
-              // TODO: is the websocket restarting once the user comes back?
-              context.read().restartWebSocket();
+              context.read<CaptureProvider>().onRecordProfileSettingChanged();
             }
           }
         },
@@ -290,7 +290,7 @@ getTranscriptWidget(
 
   return Column(
     children: [
-      if (photos.isNotEmpty) PhotosGridComponent(photos: photos),
+      if (photos.isNotEmpty) PhotosGridComponent(),
       if (segments.isNotEmpty) TranscriptWidget(segments: segments),
     ],
   );
diff --git a/app/lib/pages/chat/page.dart b/app/lib/pages/chat/page.dart
index 223aca73a..5608c3469 100644
--- a/app/lib/pages/chat/page.dart
+++ b/app/lib/pages/chat/page.dart
@@ -2,6 +2,7 @@ import 'dart:io';
 
 import 'package:collection/collection.dart';
 import 'package:flutter/material.dart';
+import 'package:flutter/rendering.dart';
 import 'package:flutter/scheduler.dart';
 import 'package:friend_private/backend/http/api/messages.dart';
 import 'package:friend_private/backend/preferences.dart';
@@ -9,6 +10,7 @@ import 'package:friend_private/backend/schema/memory.dart';
 import 'package:friend_private/backend/schema/message.dart';
 import 'package:friend_private/backend/schema/plugin.dart';
 import 'package:friend_private/pages/chat/widgets/ai_message.dart';
+import 'package:friend_private/pages/chat/widgets/animated_mini_banner.dart';
 import 'package:friend_private/pages/chat/widgets/user_message.dart';
 import 'package:friend_private/providers/connectivity_provider.dart';
 import 'package:friend_private/providers/home_provider.dart';
@@ -29,7 +31,10 @@ class ChatPage extends StatefulWidget {
 
 class ChatPageState extends State<ChatPage> with AutomaticKeepAliveClientMixin {
   TextEditingController textController = TextEditingController();
-  ScrollController scrollController = ScrollController();
+  late ScrollController scrollController;
+
+  bool _showDeleteOption = false;
+  bool isScrollingDown = false;
 
   var prefs = SharedPreferencesUtil();
   late List<Plugin> plugins;
@@ -49,11 +54,28 @@ class ChatPageState extends State<ChatPage> with AutomaticKeepAliveClientMixin {
   @override
   void initState() {
     plugins = prefs.pluginsList;
+    scrollController = ScrollController();
+    scrollController.addListener(() {
+      if (scrollController.position.userScrollDirection == ScrollDirection.reverse) {
+        if (!isScrollingDown) {
+          isScrollingDown = true;
+          _showDeleteOption = true;
+          setState(() {});
+        }
+      }
+
+      if (scrollController.position.userScrollDirection == ScrollDirection.forward) {
+        if (isScrollingDown) {
+          isScrollingDown = false;
+          _showDeleteOption = false;
+          setState(() {});
+        }
+      }
+    });
     SchedulerBinding.instance.addPostFrameCallback((_) async {
-      await context.read<MessageProvider>().refreshMessages();
       scrollToBottom();
     });
-
    ;
     super.initState();
   }
@@ -70,134 +92,216 @@ class ChatPageState extends State<ChatPage> with AutomaticKeepAliveClientMixin {
     print('ChatPage build');
     return Consumer2<MessageProvider, ConnectivityProvider>(
         builder: (context, provider, connectivityProvider, child) {
-      return Stack(
-        children: [
-          Align(
-            alignment: Alignment.topCenter,
-            child: provider.isLoadingMessages
-                ? const Padding(
-                    padding: EdgeInsets.only(top: 32.0),
-                    child: CircularProgressIndicator(
-                      color: Colors.white,
-                    ),
-                  )
-                : (provider.messages.isEmpty)
-                    ? Text(
-                        connectivityProvider.isConnected
-                            ? 'No messages yet!\nWhy don\'t you start a conversation?'
-                            : 'Please check your internet connection and try again',
-                        textAlign: TextAlign.center,
-                        style: const TextStyle(color: Colors.white))
-                    : ListView.builder(
-                        shrinkWrap: true,
-                        reverse: true,
-                        controller: scrollController,
-                        // physics: const NeverScrollableScrollPhysics(),
-                        itemCount: provider.messages.length,
-                        itemBuilder: (context, chatIndex) {
-                          final message = provider.messages[chatIndex];
-                          double topPadding = chatIndex == provider.messages.length - 1 ? 24 : 16;
-                          double bottomPadding = chatIndex == 0
-                              ? Platform.isAndroid
-                                  ? 200
-                                  : 170
-                              : 0;
-                          return Padding(
-                            key: ValueKey(message.id),
-                            padding: EdgeInsets.only(bottom: bottomPadding, left: 18, right: 18, top: topPadding),
-                            child: message.sender == MessageSender.ai
-                                ? AIMessage(
-                                    message: message,
-                                    sendMessage: _sendMessageUtil,
-                                    displayOptions: provider.messages.length <= 1,
-                                    pluginSender: plugins.firstWhereOrNull((e) => e.id == message.pluginId),
-                                    updateMemory: (ServerMemory memory) {
-                                      context.read<MemoryProvider>().updateMemory(memory);
-                                    },
-                                  )
-                                : HumanMessage(message: message),
-                          );
-                        },
-                      ),
-          ),
-          Consumer<HomeProvider>(builder: (context, home, child) {
-            return Align(
-              alignment: Alignment.bottomCenter,
-              child: Container(
-                width: double.maxFinite,
-                padding: const EdgeInsets.symmetric(horizontal: 16, vertical: 2),
-                margin: EdgeInsets.only(left: 32, right: 32, bottom: home.isChatFieldFocused ? 40 : 120),
-                decoration: const BoxDecoration(
-                  color: Colors.black,
-                  borderRadius: BorderRadius.all(Radius.circular(16)),
-                  border: GradientBoxBorder(
-                    gradient: LinearGradient(colors: [
-                      Color.fromARGB(127, 208, 208, 208),
-                      Color.fromARGB(127, 188, 99, 121),
-                      Color.fromARGB(127, 86, 101, 182),
-                      Color.fromARGB(127, 126, 190, 236)
-                    ]),
-                    width: 1,
-                  ),
-                  shape: BoxShape.rectangle,
-                ),
-                child: TextField(
-                  enabled: true,
-                  controller: textController,
-                  // textCapitalization: TextCapitalization.sentences,
-                  obscureText: false,
-                  focusNode: home.chatFieldFocusNode,
-                  // canRequestFocus: true,
-                  textAlign: TextAlign.start,
-                  textAlignVertical: TextAlignVertical.center,
-                  decoration: InputDecoration(
-                    hintText: 'Ask your Friend anything',
-                    hintStyle: const TextStyle(fontSize: 14.0, color: Colors.grey),
-                    focusedBorder: InputBorder.none,
-                    enabledBorder: InputBorder.none,
-                    suffixIcon: IconButton(
-                      splashColor: Colors.transparent,
-                      splashRadius: 1,
-                      onPressed: loading
-                          ? null
-                          : () async {
-                              String message = textController.text;
-                              if (message.isEmpty) return;
-                              if (connectivityProvider.isConnected) {
-                                _sendMessageUtil(message);
-                              } else {
-                                ScaffoldMessenger.of(context).showSnackBar(
-                                  const SnackBar(
-                                    content: Text('Please check your internet connection and try again'),
-                                    duration: Duration(seconds: 2),
-                                  ),
-                                );
-                              }
-                            },
-                      icon: loading
-                          ? const SizedBox(
-                              width: 16,
-                              height: 16,
-                              child: CircularProgressIndicator(
-                                valueColor: AlwaysStoppedAnimation(Colors.white),
-                              ),
-                            )
-                          : const Icon(
-                              Icons.send_rounded,
-                              color: Color(0xFFF7F4F4),
-                              size: 24.0,
-                            ),
-                    ),
-                  ),
-                  // maxLines: 8,
-                  // minLines: 1,
-                  // keyboardType: TextInputType.multiline,
-                  style: TextStyle(fontSize: 14.0, color: Colors.grey.shade200),
-                ),
-              ),
-            );
-          }),
-        ],
-      );
+      return Scaffold(
+        backgroundColor: Theme.of(context).colorScheme.primary,
+        appBar: provider.isLoadingMessages
+            ? AnimatedMiniBanner(
+                showAppBar: provider.isLoadingMessages,
+                child: Container(
+                  width: double.infinity,
+                  height: 10,
+                  color: Colors.green,
+                  child: const Center(
+                    child: Text(
+                      'Syncing messages with server...',
+                      style: TextStyle(color: Colors.white, fontSize: 14),
+                    ),
+                  ),
+                ),
+              )
+            : AnimatedMiniBanner(
+                showAppBar: _showDeleteOption,
+                height: 80,
+                child: Container(
+                  width: double.infinity,
+                  height: 40,
+                  color: Theme.of(context).primaryColor,
+                  child: Row(
+                    children: [
+                      const SizedBox(width: 20),
+                      InkWell(
+                        onTap: () async {
+                          await context.read<MessageProvider>().refreshMessages();
+                        },
+                        child: const Text(
+                          'Refresh Chat',
+                          style: TextStyle(color: Colors.white, fontSize: 14),
+                        ),
+                      ),
+                      const Spacer(),
+                      InkWell(
+                        onTap: () async {
+                          setState(() {
+                            _showDeleteOption = false;
+                          });
+                          await context.read<MessageProvider>().clearChat();
+                        },
+                        child: const Text(
+                          'Clear Chat',
+                          style: TextStyle(color: Colors.white, fontSize: 14),
+                        ),
+                      ),
+                      const SizedBox(width: 20),
+                    ],
+                  ),
+                ),
+              ),
+        body: Stack(
+          children: [
+            Align(
+              alignment: Alignment.topCenter,
+              child: provider.isLoadingMessages && !provider.hasCachedMessages
+                  ? Column(
+                      children: [
+                        const SizedBox(height: 100),
+                        const CircularProgressIndicator(
+                          valueColor: AlwaysStoppedAnimation(Colors.white),
+                        ),
+                        const SizedBox(height: 16),
+                        Text(
+                          provider.firstTimeLoadingText,
+                          style: const TextStyle(color: Colors.white),
+                        ),
+                      ],
+                    )
+                  : provider.isClearingChat
+                      ? const Column(
+                          children: [
+                            SizedBox(height: 100),
+                            CircularProgressIndicator(
+                              valueColor: AlwaysStoppedAnimation(Colors.white),
+                            ),
+                            SizedBox(height: 16),
+                            Text(
+                              "Deleting your messages from Omi's memory...",
+                              style: TextStyle(color: Colors.white),
+                            ),
+                          ],
+                        )
+                      : (provider.messages.isEmpty)
+                          ? Center(
+                              child: Padding(
+                                padding: const EdgeInsets.only(bottom: 32.0),
+                                child: Text(
+                                    connectivityProvider.isConnected
+                                        ? 'No messages yet!\nWhy don\'t you start a conversation?'
+                                        : 'Please check your internet connection and try again',
+                                    textAlign: TextAlign.center,
+                                    style: const TextStyle(color: Colors.white)),
+                              ),
+                            )
+                          : ListView.builder(
+                              shrinkWrap: true,
+                              reverse: true,
+                              controller: scrollController,
+                              // physics: const NeverScrollableScrollPhysics(),
+                              itemCount: provider.messages.length,
+                              itemBuilder: (context, chatIndex) {
+                                final message = provider.messages[chatIndex];
+                                double topPadding = chatIndex == provider.messages.length - 1 ? 24 : 16;
+                                double bottomPadding = chatIndex == 0
+                                    ? Platform.isAndroid
+                                        ? 200
+                                        : 170
+                                    : 0;
+                                return Padding(
+                                  key: ValueKey(message.id),
+                                  padding:
+                                      EdgeInsets.only(bottom: bottomPadding, left: 18, right: 18, top: topPadding),
+                                  child: message.sender == MessageSender.ai
+                                      ? AIMessage(
+                                          message: message,
+                                          sendMessage: _sendMessageUtil,
+                                          displayOptions: provider.messages.length <= 1,
+                                          pluginSender: plugins.firstWhereOrNull((e) => e.id == message.pluginId),
+                                          updateMemory: (ServerMemory memory) {
+                                            context.read<MemoryProvider>().updateMemory(memory);
+                                          },
+                                        )
+                                      : HumanMessage(message: message),
+                                );
+                              },
+                            ),
+            ),
+            Consumer<HomeProvider>(builder: (context, home, child) {
+              return Align(
+                alignment: Alignment.bottomCenter,
+                child: Container(
+                  width: double.maxFinite,
+                  padding: const EdgeInsets.symmetric(horizontal: 16, vertical: 2),
+                  margin: EdgeInsets.only(left: 32, right: 32, bottom: home.isChatFieldFocused ? 40 : 120),
+                  decoration: const BoxDecoration(
+                    color: Colors.black,
+                    borderRadius: BorderRadius.all(Radius.circular(16)),
+                    border: GradientBoxBorder(
+                      gradient: LinearGradient(colors: [
+                        Color.fromARGB(127, 208, 208, 208),
+                        Color.fromARGB(127, 188, 99, 121),
+                        Color.fromARGB(127, 86, 101, 182),
+                        Color.fromARGB(127, 126, 190, 236)
+                      ]),
+                      width: 1,
+                    ),
+                    shape: BoxShape.rectangle,
+                  ),
+                  child: TextField(
+                    enabled: true,
+                    controller: textController,
+                    // textCapitalization: TextCapitalization.sentences,
+                    obscureText: false,
+                    focusNode: home.chatFieldFocusNode,
+                    // canRequestFocus: true,
+                    textAlign: TextAlign.start,
+                    textAlignVertical: TextAlignVertical.center,
+                    decoration: InputDecoration(
+                      hintText: 'Ask your Friend anything',
+                      hintStyle: const TextStyle(fontSize: 14.0, color: Colors.grey),
+                      focusedBorder: InputBorder.none,
+                      enabledBorder: InputBorder.none,
+                      suffixIcon: IconButton(
+                        splashColor: Colors.transparent,
+                        splashRadius: 1,
+                        onPressed: loading
+                            ? null
+                            : () async {
+                                String message = textController.text;
+                                if (message.isEmpty) return;
+                                if (connectivityProvider.isConnected) {
+                                  _sendMessageUtil(message);
+                                } else {
+                                  ScaffoldMessenger.of(context).showSnackBar(
+                                    const SnackBar(
+                                      content: Text('Please check your internet connection and try again'),
+                                      duration: Duration(seconds: 2),
+                                    ),
+                                  );
+                                }
+                              },
+                        icon: loading
+                            ? const SizedBox(
+                                width: 16,
+                                height: 16,
+                                child: CircularProgressIndicator(
+                                  valueColor: AlwaysStoppedAnimation(Colors.white),
+                                ),
+                              )
+                            : const Icon(
+                                Icons.send_rounded,
+                                color: Color(0xFFF7F4F4),
+                                size: 24.0,
+                              ),
+                      ),
+                    ),
+                    // maxLines: 8,
+                    // minLines: 1,
+                    // keyboardType: TextInputType.multiline,
+                    style: TextStyle(fontSize: 14.0, color: Colors.grey.shade200),
+                  ),
+                ),
+              );
+            }),
+          ],
+        ),
+      );
     },
     );
   }
@@ -223,7 +327,9 @@ class ChatPageState extends State<ChatPage> with AutomaticKeepAliveClientMixin {
     changeLoadingState();
     scrollToBottom();
     ServerMessage message = await getInitialPluginMessage(plugin?.id);
-    context.read<MessageProvider>().addMessage(message);
+    if (mounted) {
+      context.read<MessageProvider>().addMessage(message);
+    }
     scrollToBottom();
     changeLoadingState();
   }
diff --git a/app/lib/pages/chat/widgets/ai_message.dart b/app/lib/pages/chat/widgets/ai_message.dart
index 3afac1cdf..95159218a 100644
--- a/app/lib/pages/chat/widgets/ai_message.dart
+++ b/app/lib/pages/chat/widgets/ai_message.dart
@@ -9,7 +9,9 @@ import 'package:friend_private/backend/preferences.dart';
 import 'package:friend_private/backend/schema/memory.dart';
 import 'package:friend_private/backend/schema/message.dart';
 import 'package:friend_private/backend/schema/plugin.dart';
+import 'package:friend_private/pages/memory_detail/memory_detail_provider.dart';
 import 'package:friend_private/pages/memory_detail/page.dart';
+import 'package:friend_private/providers/memory_provider.dart';
 import 'package:friend_private/utils/analytics/mixpanel.dart';
 import 'package:friend_private/providers/connectivity_provider.dart';
 import 'package:friend_private/utils/other/temp.dart';
@@ -116,8 +118,8 @@ class _AIMessageState extends State<AIMessage> {
                 style: TextStyle(fontSize: 15.0, fontWeight: FontWeight.w500, color: Colors.grey.shade300),
               )),
           if (widget.message.id != 1) _getCopyButton(context), // RESTORE ME
-          // if (message.id == 1 && displayOptions) const SizedBox(height: 8),
-          // if (message.id == 1 && displayOptions) ..._getInitialOptions(context),
+          if (widget.displayOptions) const SizedBox(height: 8),
+          if (widget.displayOptions) ..._getInitialOptions(context),
           if (messageMemories.isNotEmpty) ...[
             const SizedBox(height: 16),
             for (var data in messageMemories.indexed) ...[
@@ -127,30 +129,58 @@ class _AIMessageState extends State<AIMessage> {
               onTap: () async {
                 final connectivityProvider = Provider.of<ConnectivityProvider>(context, listen: false);
                 if (connectivityProvider.isConnected) {
-                  if (memoryDetailLoading[data.$1]) return;
-                  setState(() => memoryDetailLoading[data.$1] = true);
+                  var memProvider = Provider.of<MemoryProvider>(context, listen: false);
+                  var idx = memProvider.memoriesWithDates.indexWhere((e) {
+                    if (e.runtimeType == ServerMemory) {
+                      return e.id == data.$2.id;
+                    }
+                    return false;
+                  });
 
-                  ServerMemory? m = await getMemoryById(data.$2.id);
-                  if (m == null) return;
-                  MixpanelManager().chatMessageMemoryClicked(m);
-                  setState(() => memoryDetailLoading[data.$1] = false);
-                  await Navigator.of(context)
-                      .push(MaterialPageRoute(builder: (c) => MemoryDetailPage(memory: m)));
-                  if (SharedPreferencesUtil().modifiedMemoryDetails?.id == m.id) {
-                    ServerMemory modifiedDetails = SharedPreferencesUtil().modifiedMemoryDetails!;
-                    widget.updateMemory(SharedPreferencesUtil().modifiedMemoryDetails!);
-                    var copy = List.from(widget.message.memories);
-                    copy[data.$1] = MessageMemory(
-                        modifiedDetails.id,
-                        modifiedDetails.createdAt,
-                        MessageMemoryStructured(
-                          modifiedDetails.structured.title,
-                          modifiedDetails.structured.emoji,
-                        ));
-                    widget.message.memories.clear();
-                    widget.message.memories.addAll(copy);
-                    SharedPreferencesUtil().modifiedMemoryDetails = null;
-                    setState(() {});
+                  if (idx != -1) {
+                    context.read<MemoryDetailProvider>().updateMemory(idx);
+                    var m = memProvider.memoriesWithDates[idx];
+                    MixpanelManager().chatMessageMemoryClicked(m);
+                    await Navigator.of(context).push(
+                      MaterialPageRoute(
+                        builder: (c) => MemoryDetailPage(
+                          memory: m,
+                        ),
+                      ),
+                    );
+                  } else {
+                    if (memoryDetailLoading[data.$1]) return;
+                    setState(() => memoryDetailLoading[data.$1] = true);
+                    ServerMemory? m = await getMemoryById(data.$2.id);
+                    if (m == null) return;
+                    idx = memProvider.addMemoryWithDate(m);
+                    MixpanelManager().chatMessageMemoryClicked(m);
+                    setState(() => memoryDetailLoading[data.$1] = false);
+                    context.read<MemoryDetailProvider>().updateMemory(idx);
+                    await Navigator.of(context).push(
+                      MaterialPageRoute(
+                        builder: (c) => MemoryDetailPage(
+                          memory: m,
+                        ),
+                      ),
+                    );
+                    //TODO: Not needed anymore I guess because memories are stored in provider and read from there only
+                    if (SharedPreferencesUtil().modifiedMemoryDetails?.id == m.id) {
+                      ServerMemory modifiedDetails = SharedPreferencesUtil().modifiedMemoryDetails!;
+                      widget.updateMemory(SharedPreferencesUtil().modifiedMemoryDetails!);
+                      var copy = List.from(widget.message.memories);
+                      copy[data.$1] = MessageMemory(
+                          modifiedDetails.id,
+                          modifiedDetails.createdAt,
+                          MessageMemoryStructured(
+                            modifiedDetails.structured.title,
+                            modifiedDetails.structured.emoji,
+                          ));
+                      widget.message.memories.clear();
+                      widget.message.memories.addAll(copy);
+                      SharedPreferencesUtil().modifiedMemoryDetails = null;
+                      setState(() {});
+                    }
                   }
                 } else {
                   ScaffoldMessenger.of(context).showSnackBar(
@@ -256,7 +286,7 @@ class _AIMessageState extends State<AIMessage> {
   _getInitialOption(BuildContext context, String optionText) {
     return GestureDetector(
       child: Container(
-        padding: const EdgeInsets.symmetric(horizontal: 12.0, vertical: 8),
+        padding: const EdgeInsets.symmetric(horizontal: 12.0, vertical: 10),
         width: double.maxFinite,
         decoration: BoxDecoration(
           color: Colors.grey.shade900,
@@ -273,11 +303,11 @@ class _AIMessageState extends State<AIMessage> {
   _getInitialOptions(BuildContext context) {
     return [
       const SizedBox(height: 8),
-      _getInitialOption(context, 'What tasks do I have from yesterday?'),
+      _getInitialOption(context, 'What\'s been on my mind a lot?'),
       const SizedBox(height: 8),
-      _getInitialOption(context, 'What conversations did I have with John?'),
+      _getInitialOption(context, 'Did I forget to follow up on something?'),
       const SizedBox(height: 8),
-      _getInitialOption(context, 'What advise have I received about entrepreneurship?'),
+      _getInitialOption(context, 'What\'s the funniest thing I\'ve said lately?'),
     ];
   }
 }
diff --git a/app/lib/pages/chat/widgets/animated_mini_banner.dart b/app/lib/pages/chat/widgets/animated_mini_banner.dart
new file mode 100644
index 000000000..b7289f3b1
--- /dev/null
+++ b/app/lib/pages/chat/widgets/animated_mini_banner.dart
@@ -0,0 +1,21 @@
+import 'package:flutter/material.dart';
+
+class AnimatedMiniBanner extends StatelessWidget implements PreferredSizeWidget {
+  const AnimatedMiniBanner({super.key, required this.showAppBar, required this.child, this.height = 30});
+
+  final bool showAppBar;
+  final Widget child;
+  final double height;
+
+  @override
+  Widget build(BuildContext context) {
+    return AnimatedContainer(
+      height: showAppBar ? kToolbarHeight : 0,
+      duration: const Duration(milliseconds: 300),
+      child: child,
+    );
+  }
+
+  @override
+  Size get preferredSize => Size.fromHeight(height);
+}
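[Review note (not part of the patch)] AnimatedMiniBanner implements PreferredSizeWidget so it can sit in Scaffold.appBar. Note that build() animates to kToolbarHeight regardless of the `height` field, which only feeds preferredSize (layout space reserved by the Scaffold). Usage in the style of the ChatPage hunk above, with `showBanner` standing in for a hypothetical state flag:

    Scaffold(
      appBar: AnimatedMiniBanner(
        showAppBar: showBanner, // toggles the banner in and out
        height: 80,
        child: Container(height: 40, color: Colors.green),
      ),
      body: const SizedBox.shrink(),
    )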
diff --git a/app/lib/pages/home/page.dart b/app/lib/pages/home/page.dart
index 8ebc167d6..dde45dcb6 100644
--- a/app/lib/pages/home/page.dart
+++ b/app/lib/pages/home/page.dart
@@ -13,6 +13,7 @@ import 'package:friend_private/pages/home/device.dart';
 import 'package:friend_private/pages/memories/page.dart';
 import 'package:friend_private/pages/plugins/page.dart';
 import 'package:friend_private/pages/settings/page.dart';
+import 'package:friend_private/providers/capture_provider.dart';
 import 'package:friend_private/providers/connectivity_provider.dart';
 import 'package:friend_private/providers/device_provider.dart';
 import 'package:friend_private/providers/home_provider.dart';
@@ -21,6 +22,7 @@ import 'package:friend_private/providers/memory_provider.dart';
 import 'package:friend_private/providers/message_provider.dart';
 import 'package:friend_private/providers/plugin_provider.dart';
 import 'package:friend_private/services/notification_service.dart';
+import 'package:friend_private/services/services.dart';
 import 'package:friend_private/utils/analytics/mixpanel.dart';
 import 'package:friend_private/utils/audio/foreground.dart';
 import 'package:friend_private/utils/other/temp.dart';
@@ -53,7 +55,6 @@ class _HomePageWrapperState extends State<HomePageWrapper> {
       context.read<DeviceProvider>().periodicConnect('coming from HomePageWrapper');
       await context.read<MemoryProvider>().getInitialMemories();
       context.read().setSelectedChatPluginId(null);
-      await context.read().setupHasSpeakerProfile();
     });
     super.initState();
   }
@@ -133,6 +134,10 @@ class _HomePageState extends State<HomePage> with WidgetsBindingObserver, Ticker
       ForegroundUtil.startForegroundTask();
       if (mounted) {
         await context.read().setUserPeople();
+
+        // Start stream recording
+        await Provider.of<CaptureProvider>(context, listen: false)
+            .streamDeviceRecording(device: context.read<DeviceProvider>().connectedDevice);
       }
     });
@@ -545,7 +550,9 @@ class _HomePageState extends State<HomePage> with WidgetsBindingObserver, Ticker
             if (language != SharedPreferencesUtil().recordingsLanguage ||
                 hasSpeech != SharedPreferencesUtil().hasSpeakerProfile ||
                 transcriptModel != SharedPreferencesUtil().transcriptionModel) {
-              context.read().restartWebSocket();
+              if (context.mounted) {
+                context.read<CaptureProvider>().onRecordProfileSettingChanged();
+              }
             }
           },
         ),
diff --git a/app/lib/pages/memories/widgets/capture.dart b/app/lib/pages/memories/widgets/capture.dart
index d94468dbc..76196a8ef 100644
--- a/app/lib/pages/memories/widgets/capture.dart
+++ b/app/lib/pages/memories/widgets/capture.dart
@@ -34,13 +34,15 @@ class LiteCaptureWidgetState extends State<LiteCaptureWidget>
   void _onReceiveTaskData(dynamic data) {
     if (data is Map) {
       if (data.containsKey('latitude') && data.containsKey('longitude')) {
-        context.read<CaptureProvider>().setGeolocation(Geolocation(
-              latitude: data['latitude'],
-              longitude: data['longitude'],
-              accuracy: data['accuracy'],
-              altitude: data['altitude'],
-              time: DateTime.parse(data['time']),
-            ));
+        if (mounted) {
+          context.read<CaptureProvider>().setGeolocation(Geolocation(
+                latitude: data['latitude'],
+                longitude: data['longitude'],
+                accuracy: data['accuracy'],
+                altitude: data['altitude'],
+                time: DateTime.parse(data['time']),
+              ));
+        }
       } else {
         if (mounted) {
           context.read<CaptureProvider>().setGeolocation(null);
@@ -73,7 +75,6 @@ class LiteCaptureWidgetState extends State<LiteCaptureWidget>
   @override
   void dispose() {
     WidgetsBinding.instance.removeObserver(this);
-    // context.read().closeWebSocket();
     super.dispose();
   }
 
@@ -136,11 +137,9 @@ class LiteCaptureWidgetState extends State<LiteCaptureWidget>
           builder: (c) => getDialog(
             context,
             () async {
-              context.read<WebSocketProvider>().closeWebSocketWithoutReconnect('Firmware change detected');
               var connectedDevice = deviceProvider.connectedDevice;
               var codec = await _getAudioCodec(connectedDevice!.id);
-              context.read<CaptureProvider>().resetState(restartBytesProcessing: true);
-              context.read<CaptureProvider>().initiateWebsocket(codec);
+              await context.read<CaptureProvider>().changeAudioRecordProfile(codec);
               if (Navigator.canPop(context)) {
                 Navigator.pop(context);
               }
diff --git a/app/lib/pages/memories/widgets/processing_capture.dart b/app/lib/pages/memories/widgets/processing_capture.dart
index dfb1e5faf..90a39a1fc 100644
--- a/app/lib/pages/memories/widgets/processing_capture.dart
+++ b/app/lib/pages/memories/widgets/processing_capture.dart
@@ -7,7 +7,6 @@ import 'package:friend_private/pages/memory_capturing/page.dart';
 import 'package:friend_private/providers/capture_provider.dart';
 import 'package:friend_private/providers/connectivity_provider.dart';
 import 'package:friend_private/providers/device_provider.dart';
-import 'package:friend_private/providers/websocket_provider.dart';
 import 'package:friend_private/utils/analytics/mixpanel.dart';
 import 'package:friend_private/utils/enums.dart';
 import 'package:friend_private/utils/other/temp.dart';
@@ -28,6 +27,7 @@ class MemoryCaptureWidget extends StatefulWidget {
 }
 
 class _MemoryCaptureWidgetState extends State<MemoryCaptureWidget> {
+
   @override
   Widget build(BuildContext context) {
     return Consumer3(
@@ -113,7 +113,7 @@ class _MemoryCaptureWidgetState extends State<MemoryCaptureWidget> {
   _toggleRecording(BuildContext context, CaptureProvider provider) async {
     var recordingState = provider.recordingState;
     if (recordingState == RecordingState.record) {
-      provider.stopStreamRecording();
+      await provider.stopStreamRecording();
       context.read<CaptureProvider>().cancelMemoryCreationTimer();
       await context.read<CaptureProvider>().createMemory();
       MixpanelManager().phoneMicRecordingStopped();
@@ -128,8 +128,7 @@ class _MemoryCaptureWidgetState extends State<MemoryCaptureWidget> {
         () async {
           Navigator.pop(context);
           provider.updateRecordingState(RecordingState.initialising);
-          context.read<WebSocketProvider>().closeWebSocketWithoutReconnect('Recording with phone mic');
-          await provider.initiateWebsocket(BleAudioCodec.pcm16, 16000);
+          await provider.changeAudioRecordProfile(BleAudioCodec.pcm16, 16000);
           await provider.streamRecording();
           MixpanelManager().phoneMicRecordingStarted();
         },
@@ -155,9 +154,12 @@
     } else if (captureProvider.memoryCreating) {
       stateText = "Processing";
       isConnected = deviceProvider.connectedDevice != null;
-    } else if (deviceProvider.connectedDevice != null || captureProvider.recordingState == RecordingState.record) {
+    } else if (captureProvider.recordingDeviceServiceReady && captureProvider.transcriptServiceReady) {
       stateText = "Listening";
       isConnected = true;
+    } else if (captureProvider.recordingDeviceServiceReady || captureProvider.transcriptServiceReady) {
+      stateText = "Preparing";
+      isConnected = true;
     }
 
     var isUsingPhoneMic = captureProvider.recordingState == RecordingState.record ||
diff --git a/app/lib/pages/memory_capturing/page.dart b/app/lib/pages/memory_capturing/page.dart
index f9a1968e2..bb5ca9d19 100644
--- a/app/lib/pages/memory_capturing/page.dart
+++ b/app/lib/pages/memory_capturing/page.dart
@@ -77,7 +77,7 @@ class _MemoryCapturingPageState extends State<MemoryCapturingPage> with TickerPr
                   const SizedBox(width: 4),
                   const Text("🎙️"),
                   const SizedBox(width: 4),
-                  const Expanded(child: Text("In Progress")),
+                  const Expanded(child: Text("In progress")),
                 ],
               ),
             ),
diff --git a/app/lib/pages/memory_detail/compare_transcripts.dart b/app/lib/pages/memory_detail/compare_transcripts.dart
index 913f8a677..86a9e0fb5 100644
--- a/app/lib/pages/memory_detail/compare_transcripts.dart
+++ b/app/lib/pages/memory_detail/compare_transcripts.dart
@@ -35,7 +35,7 @@ class _CompareTranscriptsPageState extends State<CompareTranscriptsPage> {
         backgroundColor: Theme.of(context).colorScheme.primary,
       ),
       body: DefaultTabController(
-        length: 3,
+        length: 4,
         initialIndex: 0,
         child: Column(
           children: [
@@ -50,7 +50,12 @@ class _CompareTranscriptsPageState extends State<CompareTranscriptsPage> {
               padding: EdgeInsets.zero,
               indicatorPadding: EdgeInsets.zero,
               labelStyle: Theme.of(context).textTheme.titleLarge!.copyWith(fontSize: 18),
-              tabs: const [Tab(text: 'Deepgram'), Tab(text: 'Soniox'), Tab(text: 'Whisper-x')],
+              tabs: const [
+                Tab(text: 'Deepgram'),
+                Tab(text: 'Soniox'),
+                Tab(text: 'SpeechMatics'),
+                Tab(text: 'Whisper-x'),
+              ],
               indicator: BoxDecoration(color: Colors.transparent, borderRadius: BorderRadius.circular(16)),
             ),
             Expanded(
@@ -84,6 +89,18 @@ class _CompareTranscriptsPageState extends State<CompareTranscriptsPage> {
                     )
                   ],
                 ),
+                ListView(
+                  shrinkWrap: true,
+                  children: [
+                    TranscriptWidget(
+                      segments: transcripts?.speechmatics ?? [],
+                      horizontalMargin: false,
+                      topMargin: false,
+                      canDisplaySeconds: true,
+                      isMemoryDetail: true,
+                    )
+                  ],
+                ),
                 ListView(
                   shrinkWrap: true,
                   children: [
diff --git a/app/lib/pages/memory_detail/page.dart b/app/lib/pages/memory_detail/page.dart
index d538aaa0b..f3375b546 100644
--- a/app/lib/pages/memory_detail/page.dart
+++ b/app/lib/pages/memory_detail/page.dart
@@ -209,17 +209,13 @@ class _MemoryDetailPageState extends State<MemoryDetailPage> with TickerProvider
         return TabBarView(
           physics: const NeverScrollableScrollPhysics(),
           children: [
-            Consumer<MemoryDetailProvider>(
-              builder: (context, provider, child) {
+            Selector<MemoryDetailProvider, MemorySource>(
+              selector: (context, provider) => provider.memory.source,
+              builder: (context, source, child) {
                 return ListView(
                   shrinkWrap: true,
-                  children: provider.memory.source == MemorySource.openglass
-                      ? [
-                          PhotosGridComponent(
-                            photos: provider.photosData,
-                          ),
-                          const SizedBox(height: 32)
-                        ]
+                  children: source == MemorySource.openglass
+                      ? [const PhotosGridComponent(), const SizedBox(height: 32)]
                       : [const TranscriptWidgets()],
                 );
               },
@@ -245,17 +241,18 @@ class SummaryTab extends StatelessWidget {
   @override
   Widget build(BuildContext context) {
     return Selector<MemoryDetailProvider, bool>(
-        selector: (context, provider) => provider.memory.discarded,
-        builder: (context, isDiscaarded, child) {
-          return ListView(
-            shrinkWrap: true,
-            children: [
-              const GetSummaryWidgets(),
-              isDiscaarded ? const ReprocessDiscardedWidget() : const GetPluginsWidgets(),
-              const GetGeolocationWidgets(),
-            ],
-          );
-        });
+      selector: (context, provider) => provider.memory.discarded,
+      builder: (context, isDiscarded, child) {
+        return ListView(
+          shrinkWrap: true,
+          children: [
+            const GetSummaryWidgets(),
+            isDiscarded ? const ReprocessDiscardedWidget() : const GetPluginsWidgets(),
+            const GetGeolocationWidgets(),
+          ],
+        );
+      },
+    );
   }
 }
diff --git a/app/lib/pages/memory_detail/widgets.dart b/app/lib/pages/memory_detail/widgets.dart
index 98266cbff..8b2b5a200 100644
--- a/app/lib/pages/memory_detail/widgets.dart
+++ b/app/lib/pages/memory_detail/widgets.dart
@@ -133,67 +133,19 @@ class GetSummaryWidgets extends StatelessWidget {
           );
         }),
         memory.structured.actionItems.isNotEmpty ? const SizedBox(height: 40) : const SizedBox.shrink(),
-        memory.structured.events.isNotEmpty
-            ? Row(
-                children: [
-                  Icon(Icons.event, color: Colors.grey.shade300),
-                  const SizedBox(width: 8),
-                  Text(
-                    'Events',
-                    style: Theme.of(context).textTheme.titleLarge!.copyWith(fontSize: 26),
-                  )
-                ],
-              )
-            : const SizedBox.shrink(),
+        // memory.structured.events.isNotEmpty && memory.structured.events.where((e) => e.startsAt.isBefore(memory.startedAt!)).isNotEmpty
+        //     ? Row(
+        //         children: [
+        //           Icon(Icons.event, color: Colors.grey.shade300),
+        //           const SizedBox(width: 8),
+        //           Text(
+        //             'Events',
+        //             style: Theme.of(context).textTheme.titleLarge!.copyWith(fontSize: 26),
+        //           )
+        //         ],
+        //       )
+        //     : const SizedBox.shrink(),
         const EventsListWidget(),
-        // ...memory.structured.events.mapIndexed((idx, event) {
-        //   print(event.toJson());
-        //   return ListTile(
-        //     contentPadding: EdgeInsets.zero,
-        //     title: Text(
-        //       event.title,
-        //       style: const TextStyle(color: Colors.white, fontSize: 16, fontWeight: FontWeight.w600),
-        //     ),
-        //     subtitle: Padding(
-        //       padding: const EdgeInsets.only(top: 4.0),
-        //       child: Text(
-        //         '${dateTimeFormat('MMM d, yyyy', event.startsAt)} at ${dateTimeFormat('h:mm a', event.startsAt)} ~ ${event.duration} minutes.',
-        //         style: const TextStyle(color: Colors.grey, fontSize: 15),
-        //       ),
-        //     ),
-        //     trailing: IconButton(
-        //       onPressed: event.created
-        //           ? null
-        //           : () {
-        //               var calEnabled = SharedPreferencesUtil().calendarEnabled;
-        //               var calSelected = SharedPreferencesUtil().calendarId.isNotEmpty;
-        //               if (!calEnabled || !calSelected) {
-        //                 routeToPage(context, const CalendarPage());
-        //                 ScaffoldMessenger.of(context).showSnackBar(SnackBar(
-        //                   content: Text(!calEnabled
-        //                       ? 'Enable calendar integration to add events'
-        //                       : 'Select a calendar to add events to'),
-        //                 ));
-        //                 return;
-        //               }
-        //               context.read().updateEventState(true, idx);
-        //               setMemoryEventsState(memory.id, [idx], [true]);
-        //               CalendarUtil().createEvent(
-        //                 event.title,
-        //                 event.startsAt,
-        //                 event.duration,
-        //                 description: event.description,
-        //               );
-        //               ScaffoldMessenger.of(context).showSnackBar(
-        //                 const SnackBar(
-        //                   content: Text('Event added to calendar'),
-        //                 ),
-        //               );
-        //             },
-        //       icon: Icon(event.created ? Icons.check : Icons.add, color: Colors.white),
-        //     ),
-        //   );
-        // }),
         memory.structured.events.isNotEmpty ? const SizedBox(height: 40) : const SizedBox.shrink(),
       ],
     );
@@ -209,63 +161,99 @@ class EventsListWidget extends StatelessWidget {
   Widget build(BuildContext context) {
     return Consumer<MemoryDetailProvider>(
       builder: (context, provider, child) {
-        return ListView.builder(
-          itemCount: provider.memory.structured.events.length,
-          shrinkWrap: true,
-          itemBuilder: (context, idx) {
-            var event = provider.memory.structured.events[idx];
-            return ListTile(
-              contentPadding: EdgeInsets.zero,
-              title: Text(
-                event.title,
-                style: const TextStyle(color: Colors.white, fontSize: 16, fontWeight: FontWeight.w600),
-              ),
-              subtitle: Padding(
-                padding: const EdgeInsets.only(top: 4.0),
-                child: Text(
-                  '${dateTimeFormat('MMM d, yyyy', event.startsAt)} at ${dateTimeFormat('h:mm a', event.startsAt)} ~ ${event.duration} minutes.',
-                  style: const TextStyle(color: Colors.grey, fontSize: 15),
-                ),
-              ),
-              trailing: IconButton(
-                onPressed: event.created
-                    ? null
-                    : () {
-                        var calEnabled = SharedPreferencesUtil().calendarEnabled;
-                        var calSelected = SharedPreferencesUtil().calendarId.isNotEmpty;
-                        if (!calEnabled || !calSelected) {
-                          routeToPage(context, const CalendarPage());
-                          ScaffoldMessenger.of(context).showSnackBar(SnackBar(
-                            content: Text(!calEnabled
-                                ? 'Enable calendar integration to add events'
-                                : 'Select a calendar to add events to'),
-                          ));
-                          return;
-                        }
-                        context.read<MemoryDetailProvider>().updateEventState(true, idx);
-                        setMemoryEventsState(provider.memory.id, [idx], [true]);
-                        CalendarUtil().createEvent(
-                          event.title,
-                          event.startsAt,
-                          event.duration,
-                          description: event.description,
-                        );
-                        ScaffoldMessenger.of(context).showSnackBar(
-                          const SnackBar(
-                            content: Text('Event added to calendar'),
-                          ),
-                        );
-                      },
-                icon: Icon(event.created ? Icons.check : Icons.add, color: Colors.white),
-              ),
-            );
-          },
+        return Column(
+          mainAxisSize: MainAxisSize.min,
+          children: [
+            provider.memory.structured.events.isNotEmpty &&
+                    !(provider.memory.structured.events
+                        .where((e) =>
+                            e.startsAt.isBefore(provider.memory.startedAt!.add(const Duration(hours: 6))) &&
+                            e.startsAt.add(Duration(minutes: e.duration)).isBefore(provider.memory.startedAt!))
+                        .isNotEmpty)
+                ? Row(
+                    children: [
+                      Icon(Icons.event, color: Colors.grey.shade300),
+                      const SizedBox(width: 8),
+                      Text(
+                        'Events',
+                        style: Theme.of(context).textTheme.titleLarge!.copyWith(fontSize: 26),
+                      )
+                    ],
+                  )
+                : const SizedBox.shrink(),
+            ListView.builder(
+              itemCount: provider.memory.structured.events.length,
+              shrinkWrap: true,
+              itemBuilder: (context, idx) {
+                var event = provider.memory.structured.events[idx];
+                if (event.startsAt.isBefore(provider.memory.startedAt!.add(const Duration(hours: 6))) &&
+                    event.startsAt.add(Duration(minutes: event.duration)).isBefore(provider.memory.startedAt!)) {
+                  return const SizedBox.shrink();
+                }
+                return ListTile(
+                  contentPadding: EdgeInsets.zero,
+                  title: Text(
+                    event.title,
+                    style: const TextStyle(color: Colors.white, fontSize: 16, fontWeight: FontWeight.w600),
+                  ),
+                  subtitle: Padding(
+                    padding: const EdgeInsets.only(top: 4.0),
+                    child: Text(
+                      '${dateTimeFormat('MMM d, yyyy', event.startsAt)} at ${dateTimeFormat('h:mm a', event.startsAt)} ~ ${minutesConversion(event.duration)}.',
+                      style: const TextStyle(color: Colors.grey, fontSize: 15),
+                    ),
+                  ),
+                  trailing: IconButton(
+                    onPressed: event.created
+                        ? null
+                        : () {
+                            var calEnabled = SharedPreferencesUtil().calendarEnabled;
+                            var calSelected = SharedPreferencesUtil().calendarId.isNotEmpty;
+                            if (!calEnabled || !calSelected) {
+                              routeToPage(context, const CalendarPage());
+                              ScaffoldMessenger.of(context).showSnackBar(SnackBar(
+                                content: Text(!calEnabled
+                                    ? 'Enable calendar integration to add events'
+                                    : 'Select a calendar to add events to'),
+                              ));
+                              return;
+                            }
+                            context.read<MemoryDetailProvider>().updateEventState(true, idx);
+                            setMemoryEventsState(provider.memory.id, [idx], [true]);
+                            CalendarUtil().createEvent(
+                              event.title,
+                              event.startsAt,
+                              event.duration,
+                              description: event.description,
+                            );
+                            ScaffoldMessenger.of(context).showSnackBar(
+                              const SnackBar(
+                                content: Text('Event added to calendar'),
+                              ),
+                            );
+                          },
+                    icon: Icon(event.created ? Icons.check : Icons.add, color: Colors.white),
+                  ),
+                );
+              },
+            ),
+          ],
         );
       },
     );
   }
 }
 
+String minutesConversion(int minutes) {
+  if (minutes < 60) {
+    return '$minutes minutes';
+  } else if (minutes < 1440) {
+    return '${minutes / 60} hours';
+  } else {
+    return '${minutes / 1440} days';
+  }
+}
+
 class GetEditTextField extends StatefulWidget {
   final bool enabled;
   final String overview;
@@ -503,6 +491,7 @@ class GetPluginsWidgets extends StatelessWidget {
       },
       child: ListView(
         shrinkWrap: true,
+        physics: const NeverScrollableScrollPhysics(),
         children: [
           const SizedBox(height: 32),
           Text(
@@ -557,6 +546,7 @@ class GetGeolocationWidgets extends StatelessWidget {
         return provider.memory.geolocation;
       }, builder: (context, geolocation, child) {
         return Column(
+          crossAxisAlignment: CrossAxisAlignment.start,
           children: geolocation == null
               ? []
              : [
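[Review note (not part of the patch)] In Dart, `/` on two ints produces a double, so minutesConversion(90) returns '1.5 hours' and minutesConversion(2880) returns '2.0 days'. If whole units were the intent, truncating division (`~/`) is the usual fix; one possible alternative, not what the patch ships:

    String minutesConversionWhole(int minutes) {
      if (minutes < 60) return '$minutes minutes';
      if (minutes < 1440) return '${minutes ~/ 60} hours'; // ~/ truncates toward zero
      return '${minutes ~/ 1440} days';
    }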
diff --git a/app/lib/pages/onboarding/memory_created_widget.dart b/app/lib/pages/onboarding/memory_created_widget.dart
index cb6e21c90..f6e3fc02d 100644
--- a/app/lib/pages/onboarding/memory_created_widget.dart
+++ b/app/lib/pages/onboarding/memory_created_widget.dart
@@ -1,4 +1,5 @@
 import 'package:flutter/material.dart';
+import 'package:friend_private/backend/schema/memory.dart';
 import 'package:friend_private/pages/memories/widgets/memory_list_item.dart';
 import 'package:friend_private/pages/memory_detail/memory_detail_provider.dart';
 import 'package:friend_private/pages/memory_detail/page.dart';
@@ -9,11 +10,31 @@ import 'package:friend_private/utils/other/temp.dart';
 import 'package:gradient_borders/box_borders/gradient_box_border.dart';
 import 'package:provider/provider.dart';
 
-class MemoryCreatedWidget extends StatelessWidget {
+Future updateMemoryDetailProvider(BuildContext context, ServerMemory memory) {
+  return Future.microtask(() {
+    context.read().addMemory(memory);
+    context.read().updateMemory(0);
+  });
+}
+
+class MemoryCreatedWidget extends StatefulWidget {
   final VoidCallback goNext;
 
   const MemoryCreatedWidget({super.key, required this.goNext});
 
+  @override
+  State createState() => _MemoryCreatedWidgetState();
+}
+
+class _MemoryCreatedWidgetState extends State {
+  @override
+  void initState() {
+    WidgetsBinding.instance.addPostFrameCallback((_) async {
+      await updateMemoryDetailProvider(context, context.read().memory!);
+    });
+    super.initState();
+  }
+
   @override
   Widget build(BuildContext context) {
     return Padding(
@@ -54,10 +75,8 @@ class MemoryCreatedWidget extends StatelessWidget {
             ),
             child: MaterialButton(
               padding: const EdgeInsets.symmetric(horizontal: 32, vertical: 16),
-              onPressed: () async {
-                // goNext();
-                context.read().addMemory(provider.memory!);
-                context.read().updateMemory(0);
+              onPressed: () {
+                // updateMemoryDetailProvider(context, provider.memory!);
                 MixpanelManager().memoryListItemClicked(provider.memory!, 0);
                 routeToPage(context, MemoryDetailPage(memory: provider.memory!, isFromOnboarding: true));
               },
diff --git a/app/lib/pages/onboarding/name/name_widget.dart b/app/lib/pages/onboarding/name/name_widget.dart
index f8b3b7db8..73f6691ac 100644
--- a/app/lib/pages/onboarding/name/name_widget.dart
+++ b/app/lib/pages/onboarding/name/name_widget.dart
@@ -47,7 +47,7 @@ class _NameWidgetState extends State {
             textAlign: TextAlign.center,
             textAlignVertical: TextAlignVertical.center,
             decoration: InputDecoration(
-              hintText: 'Enter your given name',
+              hintText: 'What should Omi call you?',
               // label: const Text('What should Omi call you?'),
               hintStyle: const TextStyle(fontSize: 14, color: Colors.grey),
               // border: UnderlineInputBorder(
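For context on the MemoryCreatedWidget change above: mutating a provider synchronously during the first build can fire `notifyListeners()` while that frame is still building, which is why the diff routes the work through `addPostFrameCallback` plus `Future.microtask`. A stripped-down sketch of the pattern in isolation (names are illustrative, not from this diff):

```dart
import 'package:flutter/material.dart';

// Minimal sketch: defer provider mutations out of the first build frame.
class DeferredInitWidget extends StatefulWidget {
  const DeferredInitWidget({super.key});

  @override
  State<DeferredInitWidget> createState() => _DeferredInitWidgetState();
}

class _DeferredInitWidgetState extends State<DeferredInitWidget> {
  @override
  void initState() {
    super.initState();
    // Runs after the current frame, so notifyListeners() inside the
    // callback cannot trigger "markNeedsBuild() called during build".
    WidgetsBinding.instance.addPostFrameCallback((_) {
      // e.g. context.read<SomeProvider>().hydrate();
    });
  }

  @override
  Widget build(BuildContext context) => const SizedBox.shrink();
}
```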
diff --git a/app/lib/pages/onboarding/permissions/permissions_widget.dart b/app/lib/pages/onboarding/permissions/permissions_widget.dart
index 019f2ac35..d82999233 100644
--- a/app/lib/pages/onboarding/permissions/permissions_widget.dart
+++ b/app/lib/pages/onboarding/permissions/permissions_widget.dart
@@ -166,46 +166,44 @@ class _PermissionsWidgetState extends State {
                     await provider.askForBackgroundPermissions();
                   }
                 }
-                await Permission.notification.request().then((value) async {
-                  if (value.isGranted) {
-                    provider.updateNotificationPermission(true);
-                  }
-                  if (await Permission.location.serviceStatus.isEnabled) {
-                    await Permission.locationWhenInUse.request().then((value) async {
-                      if (value.isGranted) {
-                        await Permission.locationAlways.request().then((value) async {
-                          print('Location permission: ${value.isGranted}');
+                await Permission.notification.request().then(
+                  (value) async {
+                    if (value.isGranted) {
+                      provider.updateNotificationPermission(true);
+                    }
+                    if (await Permission.location.serviceStatus.isEnabled) {
+                      await Permission.locationWhenInUse.request().then(
+                        (value) async {
                           if (value.isGranted) {
-                            provider.setLoading(false);
-                            if (Platform.isAndroid) {
-                              widget.goNext();
-                            }
-                            provider.updateLocationPermission(true);
-                          }
-                          value.isGranted
-                              ? () {
+                            await Permission.locationAlways.request().then(
+                              (value) async {
+                                if (value.isGranted) {
+                                  provider.updateLocationPermission(true);
                                   widget.goNext();
                                   provider.setLoading(false);
+                                } else {
+                                  Future.delayed(const Duration(milliseconds: 2500), () async {
+                                    if (await Permission.locationAlways.status.isGranted) {
+                                      provider.updateLocationPermission(true);
+                                    }
+                                    widget.goNext();
+                                    provider.setLoading(false);
+                                  });
                                 }
-                              : Future.delayed(const Duration(milliseconds: 2500), () async {
-                                  if (await Permission.locationAlways.status.isGranted) {
-                                    provider.updateLocationPermission(true);
-                                  }
-                                  print('Location permission222222: ${value.isGranted}');
-                                  widget.goNext();
-                                  provider.setLoading(false);
-                                });
-                        });
-                      } else {
-                        widget.goNext();
-                        provider.setLoading(false);
-                      }
-                    });
-                  } else {
-                    widget.goNext();
-                    provider.setLoading(false);
-                  }
-                });
+                              },
+                            );
+                          } else {
+                            widget.goNext();
+                            provider.setLoading(false);
+                          }
+                        },
+                      );
+                    } else {
+                      widget.goNext();
+                      provider.setLoading(false);
+                    }
+                  },
+                );
               },
               child: const Text(
                 'Continue',
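The permissions_widget.dart hunk above is mostly a re-indent of a deeply nested `.then()` chain. Since every callback is already `async`, the same flow reads much flatter with `await`; a simplified sketch against the same permission_handler API (the `goNext()`/`setLoading()` bookkeeping of the real widget is omitted, so this is not a drop-in replacement):

```dart
import 'package:permission_handler/permission_handler.dart';

// Sketch: the notification -> location permission chain with await
// instead of nested then() callbacks.
Future<void> requestPermissions() async {
  final notification = await Permission.notification.request();
  if (notification.isGranted) {
    // provider.updateNotificationPermission(true);
  }
  if (!await Permission.location.serviceStatus.isEnabled) return;

  final whenInUse = await Permission.locationWhenInUse.request();
  if (!whenInUse.isGranted) return;

  final always = await Permission.locationAlways.request();
  if (!always.isGranted) {
    // iOS can grant "always" with a delay; re-check shortly after,
    // mirroring the 2500 ms fallback in the widget above.
    await Future.delayed(const Duration(milliseconds: 2500));
  }
}
```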
diff --git a/app/lib/pages/onboarding/speech_profile_widget.dart b/app/lib/pages/onboarding/speech_profile_widget.dart
index acbb51f98..e05a4f8b8 100644
--- a/app/lib/pages/onboarding/speech_profile_widget.dart
+++ b/app/lib/pages/onboarding/speech_profile_widget.dart
@@ -3,7 +3,10 @@ import 'dart:async';
 import 'package:flutter/material.dart';
 import 'package:flutter_provider_utilities/flutter_provider_utilities.dart';
 import 'package:friend_private/backend/preferences.dart';
+import 'package:friend_private/backend/schema/bt_device.dart';
+import 'package:friend_private/providers/capture_provider.dart';
 import 'package:friend_private/providers/speech_profile_provider.dart';
+import 'package:friend_private/services/services.dart';
 import 'package:friend_private/widgets/dialog.dart';
 import 'package:gradient_borders/box_borders/gradient_box_border.dart';
 import 'package:provider/provider.dart';
@@ -49,13 +52,35 @@ class _SpeechProfileWidgetState extends State with TickerPr
   @override
   Widget build(BuildContext context) {
+    Future restartDeviceRecording() async {
+      debugPrint("restartDeviceRecording $mounted");
+
+      // Restart device recording, clear transcripts
+      if (mounted) {
+        Provider.of(context, listen: false).clearTranscripts();
+        Provider.of(context, listen: false).streamDeviceRecording(
+          device: Provider.of(context, listen: false).deviceProvider?.connectedDevice,
+        );
+      }
+    }
+
+    Future stopDeviceRecording() async {
+      debugPrint("stopDeviceRecording $mounted");
+
+      // Stop device recording
+      if (mounted) {
+        await Provider.of(context, listen: false).stopStreamDeviceRecording();
+      }
+    }
+
     return PopScope(
       canPop: true,
-      onPopInvoked: (didPop) {
+      onPopInvoked: (didPop) async {
         context.read().close();
+        restartDeviceRecording();
       },
-      child: Consumer(
-        builder: (context, provider, child) {
+      child: Consumer2(
+        builder: (context, provider, _, child) {
           return MessageListener(
             showInfo: (info) {
               if (info == 'SCROLL_DOWN') {
@@ -204,10 +229,11 @@
                       ),
                       child: TextButton(
                         onPressed: () async {
-                          await provider.initialise(true);
+                          await stopDeviceRecording();
+                          await provider.initialise(true, finalizedCallback: restartDeviceRecording);
                           provider.forceCompletionTimer =
                               Timer(Duration(seconds: provider.maxDuration), () async {
-                            provider.finalize(true);
+                            provider.finalize();
                           });
                           provider.updateStartedRecording(true);
                         },
diff --git a/app/lib/pages/onboarding/wrapper.dart b/app/lib/pages/onboarding/wrapper.dart
index 239e2eb76..b47b9c0f1 100644
--- a/app/lib/pages/onboarding/wrapper.dart
+++ b/app/lib/pages/onboarding/wrapper.dart
@@ -12,6 +12,7 @@ import 'package:friend_private/pages/onboarding/name/name_widget.dart';
 import 'package:friend_private/pages/onboarding/permissions/permissions_widget.dart';
 import 'package:friend_private/pages/onboarding/speech_profile_widget.dart';
 import 'package:friend_private/pages/onboarding/welcome/page.dart';
+import 'package:friend_private/providers/home_provider.dart';
 import 'package:friend_private/providers/onboarding_provider.dart';
 import 'package:friend_private/providers/speech_profile_provider.dart';
 import 'package:friend_private/services/services.dart';
@@ -39,6 +40,7 @@ class _OnboardingWrapperState extends State with TickerProvid
     WidgetsBinding.instance.addPostFrameCallback((_) async {
       if (isSignedIn()) {
         // && !SharedPreferencesUtil().onboardingCompleted
+        context.read().setupHasSpeakerProfile();
         _goNext();
       }
     });
@@ -51,7 +53,14 @@ class _OnboardingWrapperState extends State with TickerProvid
     super.dispose();
   }
 
-  _goNext() => _controller!.animateTo(_controller!.index + 1);
+  _goNext() {
+    if (_controller!.index < _controller!.length - 1) {
+      _controller!.animateTo(_controller!.index + 1);
+    } else {
+      routeToPage(context, const HomePageWrapper(), replace: true);
+    }
+    // _controller!.animateTo(_controller!.index + 1);
+  }
 
   // TODO: use connection directly
   Future _getAudioCodec(String deviceId) async {
@@ -69,6 +78,7 @@ class _OnboardingWrapperState extends State with TickerProvid
       AuthComponent(
         onSignIn: () {
           MixpanelManager().onboardingStepCompleted('Auth');
+          context.read().setupHasSpeakerProfile();
           if (SharedPreferencesUtil().onboardingCompleted) {
             // previous users
             // Not needed anymore, because AuthProvider already does this
@@ -101,7 +111,7 @@ class _OnboardingWrapperState extends State with TickerProvid
         },
         goNext: () async {
          var provider = context.read();
-          if (hasSpeechProfile) {
+          if (context.read().hasSpeakerProfile) {
             // previous users
             routeToPage(context, const HomePageWrapper(), replace: true);
           } else {
diff --git a/app/lib/pages/settings/change_name_widget.dart b/app/lib/pages/settings/change_name_widget.dart
new file mode 100644
index 000000000..41060c291
--- /dev/null
+++ b/app/lib/pages/settings/change_name_widget.dart
@@ -0,0 +1,117 @@
+import 'dart:io';
+
+import 'package:firebase_auth/firebase_auth.dart';
+import 'package:flutter/cupertino.dart';
+import 'package:flutter/material.dart';
+import 'package:friend_private/backend/auth.dart';
+import 'package:friend_private/backend/preferences.dart';
+import 'package:friend_private/utils/alerts/app_snackbar.dart';
+
+class ChangeNameWidget extends StatefulWidget {
+  const ChangeNameWidget({super.key});
+
+  @override
+  State createState() => _ChangeNameWidgetState();
+}
+
+class _ChangeNameWidgetState extends State {
+  late TextEditingController nameController;
+  User? user;
+  bool isSaving = false;
+
+  @override
+  void initState() {
+    user = getFirebaseUser();
+    nameController = TextEditingController(text: user?.displayName ?? '');
+    super.initState();
+  }
+
+  @override
+  Widget build(BuildContext context) {
+    if (Platform.isIOS) {
+      return CupertinoAlertDialog(
+        content: Padding(
+          padding: const EdgeInsets.all(8.0),
+          child: Column(
+            children: [
+              const Text('What should Omi call you?'),
+              const SizedBox(height: 8),
+              CupertinoTextField(
+                controller: nameController,
+                placeholderStyle: const TextStyle(color: Colors.white54),
+                style: const TextStyle(color: Colors.white),
+              ),
+            ],
+          ),
+        ),
+        actions: [
+          CupertinoDialogAction(
+            textStyle: const TextStyle(color: Colors.white),
+            onPressed: () {
+              Navigator.of(context).pop();
+            },
+            child: const Text('Cancel'),
+          ),
+          CupertinoDialogAction(
+            textStyle: const TextStyle(color: Colors.white),
+            onPressed: () {
+              if (nameController.text.isEmpty || nameController.text.trim().isEmpty) {
+                AppSnackbar.showSnackbarError('Name cannot be empty');
+                return;
+              }
+              SharedPreferencesUtil().givenName = nameController.text;
+              updateGivenName(nameController.text);
+              AppSnackbar.showSnackbar('Name updated successfully!');
+              Navigator.of(context).pop();
+            },
+            child: const Text('Save'),
+          ),
+        ],
+      );
+    } else {
+      return AlertDialog(
+        content: Padding(
+          padding: const EdgeInsets.all(8.0),
+          child: Column(
+            mainAxisSize: MainAxisSize.min,
+            children: [
+              const Text('What should Omi call you?'),
+              const SizedBox(height: 8),
+              TextField(
+                controller: nameController,
+                style: const TextStyle(color: Colors.white),
+              ),
+            ],
+          ),
+        ),
+        actions: [
+          TextButton(
+            onPressed: () {
+              Navigator.of(context).pop();
+            },
+            child: const Text(
+              'Cancel',
+              style: TextStyle(color: Colors.white),
+            ),
+          ),
+          TextButton(
+            onPressed: () {
+              if (nameController.text.isEmpty || nameController.text.trim().isEmpty) {
+                AppSnackbar.showSnackbarError('Name cannot be empty');
+                return;
+              }
+              SharedPreferencesUtil().givenName = nameController.text;
+              updateGivenName(nameController.text);
+              AppSnackbar.showSnackbar('Name updated successfully!');
+              Navigator.of(context).pop();
+            },
+            child: const Text(
+              'Save',
+              style: TextStyle(color: Colors.white),
+            ),
+          ),
+        ],
+      );
+    }
+  }
+}
diff --git a/app/lib/pages/settings/developer.dart b/app/lib/pages/settings/developer.dart
index 7fc52e8bb..3f8bafa05 100644
--- a/app/lib/pages/settings/developer.dart
+++ b/app/lib/pages/settings/developer.dart
@@ -127,7 +127,7 @@ class __DeveloperSettingsPageState extends State<_DeveloperSettingsPage> {
                 underline: Container(height: 0, color: Colors.white),
                 isExpanded: true,
                 itemHeight: 48,
-                items: ['deepgram', 'soniox'].map>((String value) {
+                items: ['deepgram', 'soniox', 'speechmatics'].map>((String value) {
                   return DropdownMenuItem(
                     value: value,
                     child: Text(
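One observation on the new ChangeNameWidget above: the validate-and-save handler is duplicated verbatim between the Cupertino and Material branches, and neither trims the name before persisting it. A hypothetical shared helper (not in the diff) could carry both:

```dart
// Hypothetical shared handler for both dialog variants; assumes the same
// nameController, SharedPreferencesUtil, updateGivenName and AppSnackbar
// members/imports as the widget above.
void _saveName(BuildContext context) {
  final name = nameController.text.trim();
  if (name.isEmpty) {
    AppSnackbar.showSnackbarError('Name cannot be empty');
    return;
  }
  SharedPreferencesUtil().givenName = name;
  updateGivenName(name);
  AppSnackbar.showSnackbar('Name updated successfully!');
  Navigator.of(context).pop();
}
```

Trimming before saving also fixes the subtle difference with the onboarding flow, which stores the raw text as typed.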
diff --git a/app/lib/pages/settings/personal_details.dart b/app/lib/pages/settings/personal_details.dart
deleted file mode 100644
index 7594afe76..000000000
--- a/app/lib/pages/settings/personal_details.dart
+++ /dev/null
@@ -1,141 +0,0 @@
-import 'package:firebase_auth/firebase_auth.dart';
-import 'package:flutter/material.dart';
-import 'package:friend_private/backend/auth.dart';
-import 'package:friend_private/utils/alerts/app_snackbar.dart';
-import 'package:gradient_borders/gradient_borders.dart';
-
-class PersonalDetails extends StatefulWidget {
-  const PersonalDetails({super.key});
-
-  @override
-  State createState() => _PersonalDetailsState();
-}
-
-class _PersonalDetailsState extends State {
-  late TextEditingController nameController;
-  User? user;
-  bool isSaving = false;
-
-  @override
-  void initState() {
-    user = getFirebaseUser();
-    nameController = TextEditingController(text: user?.displayName ?? '');
-    super.initState();
-  }
-
-  @override
-  Widget build(BuildContext context) {
-    return Scaffold(
-      backgroundColor: Theme.of(context).colorScheme.primary,
-      appBar: AppBar(
-        backgroundColor: Theme.of(context).colorScheme.primary,
-        actions: [
-          MaterialButton(
-            onPressed: () async {
-              if (nameController.text.isEmpty || nameController.text.trim().isEmpty) {
-                AppSnackbar.showSnackbarError('Name cannot be empty');
-                return;
-              }
-              setState(() => isSaving = true);
-              await updateGivenName(nameController.text);
-              setState(() => isSaving = false);
-              AppSnackbar.showSnackbar('Name updated successfully!');
-              Navigator.of(context).pop();
-            },
-            color: Colors.transparent,
-            elevation: 0,
-            child: isSaving
-                ? const Center(
-                    child: Padding(
-                      padding: EdgeInsets.all(8.0),
-                      child: CircularProgressIndicator(
-                        color: Colors.white,
-                      ),
-                    ),
-                  )
-                : const Padding(
-                    padding: EdgeInsets.symmetric(horizontal: 4.0),
-                    child: Text(
-                      'Save',
-                      style: TextStyle(color: Colors.deepPurple, fontWeight: FontWeight.w600, fontSize: 16),
-                    ),
-                  ),
-          )
-        ],
-      ),
-      body: Padding(
-        padding: const EdgeInsets.only(left: 18, right: 18),
-        child: Column(
-          crossAxisAlignment: CrossAxisAlignment.start,
-          children: [
-            const SizedBox(
-              height: 30,
-            ),
-            TextFormField(
-              enabled: true,
-              controller: nameController,
-              // textCapitalization: TextCapitalization.sentences,
-              obscureText: false,
-              // canRequestFocus: true,
-              textAlign: TextAlign.start,
-              textAlignVertical: TextAlignVertical.center,
-              decoration: InputDecoration(
-                hintText: 'Enter your full name',
-                hintStyle: const TextStyle(fontSize: 14.0, color: Colors.grey),
-                floatingLabelBehavior: FloatingLabelBehavior.always,
-                label: Text(
-                  'Given Name',
-                  style: TextStyle(
-                    color: Colors.grey.shade200,
-                    fontSize: 16,
-                  ),
-                ),
-                border: GradientOutlineInputBorder(
-                  borderRadius: BorderRadius.circular(8),
-                  gradient: const LinearGradient(
-                    colors: [
-                      Color.fromARGB(255, 202, 201, 201),
-                      Color.fromARGB(255, 159, 158, 158),
-                    ],
-                  ),
-                ),
-              ),
-              style: TextStyle(fontSize: 14.0, color: Colors.grey.shade200),
-            ),
-            const SizedBox(
-              height: 30,
-            ),
-            // TextFormField(
-            //   enabled: true,
-            //   obscureText: false,
-            //   textAlign: TextAlign.start,
-            //   textAlignVertical: TextAlignVertical.center,
-            //   readOnly: true,
-            //   initialValue: user?.email,
-            //   decoration: InputDecoration(
-            //     floatingLabelBehavior: FloatingLabelBehavior.always,
-            //     label: Text(
-            //       'Email Address',
-            //       style: TextStyle(
-            //         color: Colors.grey.shade200,
-            //         fontSize: 16,
-            //       ),
-            //     ),
-            //     border: GradientOutlineInputBorder(
-            //       borderRadius: BorderRadius.circular(8),
-            //       gradient: const LinearGradient(
-            //         colors: [
-            //           Color.fromARGB(255, 202, 201, 201),
-            //           Color.fromARGB(255, 159, 158, 158),
-            //         ],
-            //       ),
-            //     ),
-            //   ),
-            //   style: TextStyle(fontSize: 14.0, color: Colors.grey.shade200),
-            // ),
-          ],
-        ),
-      ),
-    );
-  }
-}
diff --git a/app/lib/pages/settings/profile.dart b/app/lib/pages/settings/profile.dart
index 8c776eba6..631325ec6 100644
--- a/app/lib/pages/settings/profile.dart
+++ b/app/lib/pages/settings/profile.dart
@@ -2,7 +2,7 @@ import 'package:flutter/material.dart';
 import 'package:flutter/services.dart';
 import 'package:friend_private/backend/preferences.dart';
 import 'package:friend_private/pages/facts/page.dart';
-import 'package:friend_private/pages/settings/personal_details.dart';
+import 'package:friend_private/pages/settings/change_name_widget.dart';
 import 'package:friend_private/pages/settings/privacy.dart';
 import 'package:friend_private/pages/settings/recordings_storage_permission.dart';
 import 'package:friend_private/pages/speech_profile/page.dart';
@@ -66,14 +66,14 @@ class _ProfilePageState extends State {
               ),
               subtitle: Text(SharedPreferencesUtil().givenName.isEmpty ? 'Not set' : SharedPreferencesUtil().givenName),
               trailing: const Icon(Icons.person, size: 20),
-              onTap: () {
-                // TODO: change to a dialog
+              onTap: () async {
                 MixpanelManager().pageOpened('Profile Change Name');
-                showModalBottomSheet(
-                    context: context,
-                    builder: (c) {
-                      return const PersonalDetails();
-                    });
+                await showDialog(
+                  context: context,
+                  builder: (BuildContext context) {
+                    return const ChangeNameWidget();
+                  },
+                ).whenComplete(() => setState(() {}));
              },
            ),
            const SizedBox(height: 24),
diff --git a/app/lib/pages/speech_profile/page.dart b/app/lib/pages/speech_profile/page.dart
index 99a66ef9a..198c27234 100644
--- a/app/lib/pages/speech_profile/page.dart
+++ b/app/lib/pages/speech_profile/page.dart
@@ -6,6 +6,7 @@ import 'package:friend_private/backend/preferences.dart';
 import 'package:friend_private/backend/schema/bt_device.dart';
 import 'package:friend_private/pages/home/page.dart';
 import 'package:friend_private/pages/speech_profile/user_speech_samples.dart';
+import 'package:friend_private/providers/capture_provider.dart';
 import 'package:friend_private/providers/speech_profile_provider.dart';
 import 'package:friend_private/services/services.dart';
 import 'package:friend_private/utils/other/temp.dart';
@@ -30,6 +31,7 @@
     WidgetsBinding.instance.addPostFrameCallback((timeStamp) async {
       await context.read().updateDevice();
     });
+    super.initState();
   }
 
@@ -65,16 +67,36 @@
   @override
   Widget build(BuildContext context) {
+    Future restartDeviceRecording() async {
+      debugPrint("restartDeviceRecording $mounted");
+      if (mounted) {
+        Provider.of(context, listen: false).clearTranscripts();
+        Provider.of(context, listen: false).streamDeviceRecording(
+          device: Provider.of(context, listen: false).deviceProvider?.connectedDevice,
+        );
+      }
+    }
+
+    Future stopDeviceRecording() async {
+      debugPrint("stopDeviceRecording $mounted");
+      if (mounted) {
+        await Provider.of(context, listen: false).stopStreamDeviceRecording();
+      }
+    }
+
     return PopScope(
       canPop: true,
       onPopInvoked: (didPop) {
         if (context.read().isInitialised) {
           WidgetsBinding.instance.addPostFrameCallback((timeStamp) async {
             await context.read().close();
+
+            // Restart device recording
+            restartDeviceRecording();
           });
         }
       },
-      child: Consumer(builder: (context, provider, child) {
+      child: Consumer2(builder: (context, provider, _, child) {
         return MessageListener(
           showInfo: (info) {
             if (info == 'SCROLL_DOWN') {
@@ -320,12 +342,13 @@
                         );
                         return;
                       }
-                      await provider.initialise(false);
-                      // provider.initiateWebsocket(false);
+
+                      await stopDeviceRecording();
+                      await provider.initialise(false, finalizedCallback: restartDeviceRecording);
                       // 1.5 minutes seems reasonable
                       provider.forceCompletionTimer = Timer(Duration(seconds: provider.maxDuration), () {
-                        provider.finalize(false);
+                        provider.finalize();
                       });
                       provider.updateStartedRecording(true);
                     },
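Note that `restartDeviceRecording` and `stopDeviceRecording` are now copy-pasted between onboarding/speech_profile_widget.dart and speech_profile/page.dart above. A sketch of hoisting them into one shared extension, assuming the CaptureProvider API introduced in the capture_provider.dart diff below (hypothetical and simplified; the real widgets resolve the device via SpeechProfileProvider.deviceProvider):

```dart
import 'package:flutter/widgets.dart';
import 'package:provider/provider.dart';
// Assumes imports for CaptureProvider and DeviceProvider from this app.

// Hypothetical shared helpers so both speech-profile screens stop
// duplicating the restart/stop logic.
extension DeviceRecordingControls on BuildContext {
  Future<void> restartDeviceRecording() async {
    final capture = read<CaptureProvider>();
    capture.clearTranscripts();
    await capture.streamDeviceRecording(
      device: read<DeviceProvider>().connectedDevice,
    );
  }

  Future<void> stopDeviceRecording() =>
      read<CaptureProvider>().stopStreamDeviceRecording();
}
```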
diff --git a/app/lib/providers/capture_provider.dart b/app/lib/providers/capture_provider.dart
index 09c80c79b..60ffc8e88 100644
--- a/app/lib/providers/capture_provider.dart
+++ b/app/lib/providers/capture_provider.dart
@@ -20,7 +20,6 @@ import 'package:friend_private/backend/schema/transcript_segment.dart';
 import 'package:friend_private/pages/capture/logic/openglass_mixin.dart';
 import 'package:friend_private/providers/memory_provider.dart';
 import 'package:friend_private/providers/message_provider.dart';
-import 'package:friend_private/providers/websocket_provider.dart';
 import 'package:friend_private/services/services.dart';
 import 'package:friend_private/services/notification_service.dart';
 import 'package:friend_private/utils/analytics/growthbook.dart';
@@ -32,23 +31,26 @@ import 'package:friend_private/utils/features/calendar.dart';
 import 'package:friend_private/utils/logger.dart';
 import 'package:friend_private/utils/memories/integrations.dart';
 import 'package:friend_private/utils/memories/process.dart';
-import 'package:friend_private/utils/websockets.dart';
+import 'package:friend_private/utils/pure_socket.dart';
 import 'package:permission_handler/permission_handler.dart';
 import 'package:uuid/uuid.dart';
 
-class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifierMixin {
+class CaptureProvider extends ChangeNotifier
+    with OpenGlassMixin, MessageNotifierMixin
+    implements ITransctipSegmentSocketServiceListener {
   MemoryProvider? memoryProvider;
   MessageProvider? messageProvider;
-  WebSocketProvider? webSocketProvider;
+  TranscripSegmentSocketService? _socket;
 
-  void updateProviderInstances(MemoryProvider? mp, MessageProvider? p, WebSocketProvider? wsProvider) {
+  Timer? _keepAliveTimer;
+
+  void updateProviderInstances(MemoryProvider? mp, MessageProvider? p) {
     memoryProvider = mp;
     messageProvider = p;
-    webSocketProvider = wsProvider;
     notifyListeners();
   }
 
-  BTDeviceStruct? connectedDevice;
+  BTDeviceStruct? _recordingDevice;
   bool isGlasses = false;
 
   List segments = [];
@@ -70,6 +72,11 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie
 
   RecordingState recordingState = RecordingState.stop;
 
+  bool _transcriptServiceReady = false;
+  bool get transcriptServiceReady => _transcriptServiceReady;
+
+  bool get recordingDeviceServiceReady => _recordingDevice != null || recordingState == RecordingState.record;
+
   // -----------------------
   // Memory creation variables
   double? streamStartedAtSecond;
@@ -95,14 +102,8 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie
   String? processingMemoryId;
 
   String btConnectedTime = "";
-  bool resetStateAlreadyCalled = false;
   String dateTimeStorageString = "";
 
-  void setResetStateAlreadyCalled(bool value) {
-    resetStateAlreadyCalled = value;
-    notifyListeners();
-  }
-
   void setHasTranscripts(bool value) {
     hasTranscripts = value;
     notifyListeners();
@@ -119,7 +120,7 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie
   }
 
   void setMemoryCreating(bool value) {
-    print('set memory creating ${value}');
+    debugPrint('set memory creating ${value}');
     memoryCreating = value;
     notifyListeners();
   }
@@ -140,9 +141,9 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie
     notifyListeners();
   }
 
-  void updateConnectedDevice(BTDeviceStruct? device) {
-    debugPrint('connected device changed from ${connectedDevice?.id} to ${device?.id}');
-    connectedDevice = device;
+  void _updateRecordingDevice(BTDeviceStruct? device) {
+    debugPrint('connected device changed from ${_recordingDevice?.id} to ${device?.id}');
+    _recordingDevice = device;
     notifyListeners();
   }
 
@@ -163,7 +164,7 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie
       emotionalFeedback: GrowthbookUtil().isOmiFeedbackEnabled(),
     );
     if (result?.result == null) {
-      print("Can not update processing memory, result null");
+      debugPrint("Can not update processing memory, result null");
    }
   }
 
@@ -174,7 +175,7 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie
 
   Future _onMemoryCreated(ServerMessageEvent event) async {
     if (event.memory == null) {
-      print("Memory is not found, processing memory ${event.processingMemoryId}");
+      debugPrint("Memory is not found, processing memory ${event.processingMemoryId}");
       return;
     }
     _processOnMemoryCreated(event.memory, event.messages ?? []);
@@ -187,7 +188,7 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie
   Future _onMemoryPostProcessSuccess(String memoryId) async {
     var memory = await getMemoryById(memoryId);
     if (memory == null) {
-      print("Memory is not found $memoryId");
+      debugPrint("Memory is not found $memoryId");
       return;
     }
 
@@ -197,7 +198,7 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie
   Future _onMemoryPostProcessFailed(String memoryId) async {
     var memory = await getMemoryById(memoryId);
     if (memory == null) {
-      print("Memory is not found $memoryId");
+      debugPrint("Memory is not found $memoryId");
       return;
     }
 
@@ -228,6 +229,7 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie
       // Notify
       setMemoryCreating(false);
       setHasTranscripts(false);
+      _handleCalendarCreation(memory);
       notifyListeners();
       return;
     }
@@ -314,7 +316,7 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie
     return true;
   }
 
-  void _cleanNew() async {
+  Future _clean() async {
     segments = [];
 
     audioStorage?.clearAudioBytes();
@@ -327,11 +329,15 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie
     photos = [];
     conversationId = const Uuid().v4();
     processingMemoryId = null;
+  }
+
+  Future _cleanNew() async {
+    _clean();
 
     // Create new socket session
     // Warn: should have a better solution to keep the socket alived
-    await webSocketProvider?.closeWebSocketWithoutReconnect('reset new memory session');
-    await initiateWebsocket();
+    debugPrint("_cleanNew");
+    await _initiateWebsocket(force: true);
   }
 
   _handleCalendarCreation(ServerMemory memory) {
@@ -344,6 +350,7 @@
     List indexes = events.mapIndexed((index, e) => index).toList();
     setMemoryEventsState(memory.id, indexes, indexes.map((_) => true).toList());
     for (var i = 0; i < events.length; i++) {
+      if (events[i].created) continue;
      events[i].created = true;
       CalendarUtil().createEvent(
         events[i].title,
@@ -354,128 +361,48 @@
     }
   }
 
-  Future initiateWebsocket([
+  Future onRecordProfileSettingChanged() async {
+    await _resetState(restartBytesProcessing: true);
+  }
+
+  Future changeAudioRecordProfile([
     BleAudioCodec? audioCodec,
     int? sampleRate,
   ]) async {
-    // setWebSocketConnecting(true);
-    print('initiateWebsocket in capture_provider');
-    BleAudioCodec codec = audioCodec ?? SharedPreferencesUtil().deviceCodec;
-    sampleRate ??= (codec == BleAudioCodec.opus ? 16000 : 8000);
-    print('is ws null: ${webSocketProvider == null}');
-    await webSocketProvider?.initWebSocket(
-      codec: codec,
-      sampleRate: sampleRate,
-      includeSpeechProfile: true,
-      newMemoryWatch: true,
-      // Warn: need clarify about initiateWebsocket
-      onConnectionSuccess: () {
-        print('inside onConnectionSuccess');
-        if (segments.isNotEmpty) {
-          // means that it was a reconnection, so we need to reset
-          streamStartedAtSecond = null;
-          secondsMissedOnReconnect = (DateTime.now().difference(firstStreamReceivedAt!).inSeconds);
-        }
-        print('bottom in onConnectionSuccess');
-        notifyListeners();
-      },
-      onConnectionFailed: (err) {
-        print('inside onConnectionFailed');
-        print('err: $err');
-        notifyListeners();
-      },
-      onConnectionClosed: (int? closeCode, String? closeReason) {
-        print('inside onConnectionClosed');
-        print('closeCode: $closeCode');
-        // connection was closed, either on resetState, or by backend, or by some other reason.
-        // setState(() {});
-      },
-      onConnectionError: (err) {
-        print('inside onConnectionError');
-        print('err: $err');
-        // connection was okay, but then failed.
-        notifyListeners();
-      },
-      onMessageEventReceived: (ServerMessageEvent event) {
-        if (event.type == MessageEventType.newMemoryCreating) {
-          _onMemoryCreating();
-          return;
-        }
-
-        if (event.type == MessageEventType.newMemoryCreated) {
-          _onMemoryCreated(event);
-          return;
-        }
+    debugPrint("changeAudioRecordProfile");
+    await _resetState(restartBytesProcessing: true);
+    await _initiateWebsocket(audioCodec: audioCodec, sampleRate: sampleRate);
+  }
 
-        if (event.type == MessageEventType.newMemoryCreateFailed) {
-          _onMemoryCreateFailed();
-          return;
-        }
+  Future _initiateWebsocket({
+    BleAudioCodec? audioCodec,
+    int? sampleRate,
+    bool force = false,
+  }) async {
+    debugPrint('initiateWebsocket in capture_provider');
 
-        if (event.type == MessageEventType.newProcessingMemoryCreated) {
-          if (event.processingMemoryId == null) {
-            print("New processing memory created message event is invalid");
-            return;
-          }
-          _onProcessingMemoryCreated(event.processingMemoryId!);
-          return;
-        }
+    BleAudioCodec codec = audioCodec ?? SharedPreferencesUtil().deviceCodec;
+    sampleRate ??= (codec == BleAudioCodec.opus ? 16000 : 8000);
 
-        if (event.type == MessageEventType.memoryPostProcessingSuccess) {
-          if (event.memoryId == null) {
-            print("Post proccess message event is invalid");
-            return;
-          }
-          _onMemoryPostProcessSuccess(event.memoryId!);
-          return;
-        }
+    debugPrint('is ws null: ${_socket == null}');
 
-        if (event.type == MessageEventType.memoryPostProcessingFailed) {
-          if (event.memoryId == null) {
-            print("Post proccess message event is invalid");
-            return;
-          }
-          _onMemoryPostProcessFailed(event.memoryId!);
-          return;
-        }
-      },
-      onMessageReceived: (List newSegments) {
-        if (newSegments.isEmpty) return;
-
-        if (segments.isEmpty) {
-          debugPrint('newSegments: ${newSegments.last}');
-          // TODO: small bug -> when memory A creates, and memory B starts, memory B will clean a lot more seconds than available,
-          // losing from the audio the first part of the recording. All other parts are fine.
-          FlutterForegroundTask.sendDataToTask(jsonEncode({'location': true}));
-          var currentSeconds = (audioStorage?.frames.length ?? 0) ~/ 100;
-          var removeUpToSecond = newSegments[0].start.toInt();
-          audioStorage?.removeFramesRange(fromSecond: 0, toSecond: min(max(currentSeconds - 5, 0), removeUpToSecond));
-          firstStreamReceivedAt = DateTime.now();
-        }
+    // Get memory socket
+    _socket = await ServiceManager.instance().socket.memory(codec: codec, sampleRate: sampleRate, force: force);
+    if (_socket == null) {
+      throw Exception("Can not create new memory socket");
+    }
+    _socket?.subscribe(this, this);
+    _transcriptServiceReady = true;
 
-        streamStartedAtSecond ??= newSegments[0].start;
-        TranscriptSegment.combineSegments(
-          segments,
-          newSegments,
-          toRemoveSeconds: streamStartedAtSecond ?? 0,
-          toAddSeconds: secondsMissedOnReconnect ?? 0,
-        );
-        triggerTranscriptSegmentReceivedEvents(newSegments, conversationId, sendMessageToChat: (v) {
-          messageProvider?.addMessage(v);
-        });
-
-        debugPrint('Memory creation timer restarted');
-        _memoryCreationTimer?.cancel();
-        _memoryCreationTimer =
-            Timer(const Duration(seconds: quietSecondsForMemoryCreation), () => _createPhotoCharacteristicMemory());
-        setHasTranscripts(true);
-        notifyListeners();
-      },
-    );
+    if (segments.isNotEmpty) {
+      // means that it was a reconnection, so we need to reset
+      streamStartedAtSecond = null;
+      secondsMissedOnReconnect = (DateTime.now().difference(firstStreamReceivedAt!).inSeconds);
+    }
   }
 
   Future streamAudioToWs(String id, BleAudioCodec codec) async {
-    print('streamAudioToWs in capture_provider');
+    debugPrint('streamAudioToWs in capture_provider');
     audioStorage = WavBytesUtil(codec: codec);
     if (_bleBytesStream != null) {
       _bleBytesStream?.cancel();
@@ -488,8 +415,8 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie
         final trimmedValue = value.sublist(3);
         // TODO: if this (0,3) is not removed, deepgram can't seem to be able to detect the audio.
         // https://developers.deepgram.com/docs/determining-your-audio-format-for-live-streaming-audio
-        if (webSocketProvider?.wsConnectionState == WebsocketConnectionStatus.connected) {
-          webSocketProvider?.websocketChannel?.sink.add(trimmedValue);
+        if (_socket?.state == SocketServiceState.connected) {
+          _socket?.send(trimmedValue);
         }
       },
     );
@@ -595,19 +522,20 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie
   Future getFileFromDevice(int fileNum,int offset) async {
     storageUtil.fileNum = fileNum;
     int command = 0;
-    writeToStorage(connectedDevice!.id, storageUtil.fileNum, command,offset);
+    writeToStorage(_recordingDevice!.id, storageUtil.fileNum, command,offset);
+
   }
 
   Future clearFileFromDevice(int fileNum) async {
     storageUtil.fileNum = fileNum;
     int command = 1;
-    writeToStorage(connectedDevice!.id, storageUtil.fileNum, command,0);
+    writeToStorage(_recordingDevice!.id, storageUtil.fileNum, command,0);
   }
 
   Future pauseFileFromDevice(int fileNum) async {
     storageUtil.fileNum = fileNum;
     int command = 3;
-    writeToStorage(connectedDevice!.id, storageUtil.fileNum, command,0);
+    writeToStorage(_recordingDevice!.id, storageUtil.fileNum, command,0);
   }
 
   void setStorageIsReady(bool value) {
@@ -623,40 +551,30 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie
 
   Future resetForSpeechProfile() async {
     closeBleStream();
-    await webSocketProvider?.closeWebSocketWithoutReconnect('reset for speech profile');
+    await _socket?.stop(reason: 'reset for speech profile');
     setAudioBytesConnected(false);
     notifyListeners();
   }
 
-  Future resetState({
+  Future _resetState({
     bool restartBytesProcessing = true,
-    bool isFromSpeechProfile = false,
-    BTDeviceStruct? btDevice,
   }) async {
-    if (resetStateAlreadyCalled) {
-      debugPrint('resetState already called');
-      return;
-    }
-    setResetStateAlreadyCalled(true);
-    debugPrint('resetState: restartBytesProcessing=$restartBytesProcessing, isFromSpeechProfile=$isFromSpeechProfile');
+    debugPrint('resetState: restartBytesProcessing=$restartBytesProcessing');
 
     _cleanupCurrentState();
-    await startOpenGlass();
-    if (!isFromSpeechProfile) {
-      await _handleMemoryCreation(restartBytesProcessing);
-    }
-
-    bool codecChanged = await _checkCodecChange();
+    await _handleMemoryCreation(restartBytesProcessing);
 
-    if (restartBytesProcessing || codecChanged) {
-      await _manageWebSocketConnection(codecChanged, isFromSpeechProfile);
-    }
+    await _recheckCodecChange();
+    await _ensureSocketConnection(force: true);
 
-    await initiateFriendAudioStreaming(isFromSpeechProfile);
+    await startOpenGlass();
+    await _initiateFriendAudioStreaming();
 
     // TODO: Commenting this for now as DevKit 2 is not yet used in production
-    await initiateStorageBytesStreaming();
+
+    // await initiateStorageBytesStreaming(); //enable when ready
 
-    setResetStateAlreadyCalled(false);
+
     notifyListeners();
   }
@@ -733,44 +651,44 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie
     return connection.hasPhotoStreamingCharacteristic();
   }
 
-  Future _checkCodecChange() async {
-    if (connectedDevice != null) {
-      BleAudioCodec newCodec = await _getAudioCodec(connectedDevice!.id);
+  Future _recheckCodecChange() async {
+    if (_recordingDevice != null) {
+      BleAudioCodec newCodec = await _getAudioCodec(_recordingDevice!.id);
       if (SharedPreferencesUtil().deviceCodec != newCodec) {
         debugPrint('Device codec changed from ${SharedPreferencesUtil().deviceCodec} to $newCodec');
-        SharedPreferencesUtil().deviceCodec = newCodec;
+        await SharedPreferencesUtil().setDeviceCodec(newCodec);
         return true;
       }
     }
 
     return false;
   }
 
-  Future _manageWebSocketConnection(bool codecChanged, bool isFromSpeechProfile) async {
-    if (codecChanged || webSocketProvider?.wsConnectionState != WebsocketConnectionStatus.connected) {
-      await webSocketProvider?.closeWebSocketWithoutReconnect('reset state $isFromSpeechProfile');
-      // if (!isFromSpeechProfile) {
-      await initiateWebsocket();
-      // }
+  Future _ensureSocketConnection({bool force = false}) async {
+    debugPrint("_ensureSocketConnection");
+    var codec = SharedPreferencesUtil().deviceCodec;
+    if (force || (codec != _socket?.codec || _socket?.state != SocketServiceState.connected)) {
+      await _socket?.stop(reason: 'reset state, force $force');
+      await _initiateWebsocket(force: force);
     }
   }
 
-  Future initiateFriendAudioStreaming(bool isFromSpeechProfile) async {
-    print('connectedDevice: $connectedDevice in initiateFriendAudioStreaming');
-    if (connectedDevice == null) return;
+  Future _initiateFriendAudioStreaming() async {
+    debugPrint('_recordingDevice: $_recordingDevice in initiateFriendAudioStreaming');
+    if (_recordingDevice == null) return;
 
-    BleAudioCodec codec = await _getAudioCodec(connectedDevice!.id);
+    BleAudioCodec codec = await _getAudioCodec(_recordingDevice!.id);
     if (SharedPreferencesUtil().deviceCodec != codec) {
       debugPrint('Device codec changed from ${SharedPreferencesUtil().deviceCodec} to $codec');
       SharedPreferencesUtil().deviceCodec = codec;
       notifyInfo('FIM_CHANGE');
-      await _manageWebSocketConnection(true, isFromSpeechProfile);
+      await _ensureSocketConnection();
     }
 
-    // Why is the connectedDevice null at this point?
+    // Why is the _recordingDevice null at this point?
     if (!audioBytesConnected) {
-      if (connectedDevice != null) {
-        await streamAudioToWs(connectedDevice!.id, codec);
-        //here we will start the sd card websocket
+
+      if (_recordingDevice != null) {
+        await streamAudioToWs(_recordingDevice!.id, codec);
       } else {
         // Is the app in foreground when this happens?
         Logger.handle(Exception('Device Not Connected'), StackTrace.current,
@@ -783,8 +701,9 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie
 
   Future initiateStorageBytesStreaming() async {
     debugPrint('initiateStorageBytesStreaming');
-    if (connectedDevice == null) return;
-    currentStorageFiles = await _getStorageList(connectedDevice!.id);
+
+    if (_recordingDevice == null) return;
+    currentStorageFiles = await _getStorageList(_recordingDevice!.id);
     if (currentStorageFiles.isEmpty) {
       debugPrint('No storage files found');
       return;
@@ -827,11 +746,11 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie
   }
 
   Future startOpenGlass() async {
-    if (connectedDevice == null) return;
-    isGlasses = await _hasPhotoStreamingCharacteristic(connectedDevice!.id);
+    if (_recordingDevice == null) return;
+    isGlasses = await _hasPhotoStreamingCharacteristic(_recordingDevice!.id);
     if (!isGlasses) return;
-    await openGlassProcessing(connectedDevice!, (p) {}, setHasTranscripts);
-    webSocketProvider?.closeWebSocketWithoutReconnect('reset state open glass');
+    await openGlassProcessing(_recordingDevice!, (p) {}, setHasTranscripts);
+    _socket?.stop(reason: 'reset state open glass');
     notifyListeners();
   }
 
@@ -849,6 +768,8 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie
   void dispose() {
     _bleBytesStream?.cancel();
     _memoryCreationTimer?.cancel();
+    _socket?.unsubscribe(this);
+    _keepAliveTimer?.cancel();
     super.dispose();
   }
 
@@ -862,8 +783,8 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie
 
     // record
     await ServiceManager.instance().mic.start(onByteReceived: (bytes) {
-      if (webSocketProvider?.wsConnectionState == WebsocketConnectionStatus.connected) {
-        webSocketProvider?.websocketChannel?.sink.add(bytes);
+      if (_socket?.state == SocketServiceState.connected) {
+        _socket?.send(bytes);
       }
     }, onRecording: () {
       updateRecordingState(RecordingState.record);
@@ -877,4 +798,150 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie
   stopStreamRecording() {
     ServiceManager.instance().mic.stop();
   }
+
+  Future streamDeviceRecording({
+    BTDeviceStruct? device,
+    bool restartBytesProcessing = true,
+  }) async {
+    debugPrint("streamDeviceRecording ${device} ${restartBytesProcessing}");
+    if (device != null) {
+      _updateRecordingDevice(device);
+    }
+
+    await _resetState(
+      restartBytesProcessing: restartBytesProcessing,
+    );
+  }
+
+  Future stopStreamDeviceRecording({bool cleanDevice = false}) async {
+    if (cleanDevice) {
+      _updateRecordingDevice(null);
+    }
+    _cleanupCurrentState();
+    await _socket?.stop(reason: 'stop stream device recording');
+    await _handleMemoryCreation(false);
+  }
+
+  // Socket handling
+
+  @override
+  void onClosed() {
+    _transcriptServiceReady = false;
+    debugPrint('[Provider] Socket is closed');
+
+    _clean();
+
+    // Notify
+    setMemoryCreating(false);
+    setHasTranscripts(false);
+    notifyListeners();
+
+    // Keep alive
+    _startKeepAlivedServices();
+  }
+
+  void _startKeepAlivedServices() {
+    if (_recordingDevice != null && _socket?.state != SocketServiceState.connected) {
+      _keepAliveTimer?.cancel();
+      _keepAliveTimer = Timer.periodic(const Duration(seconds: 15), (t) async {
+        debugPrint("[Provider] keep alive...");
+
+        if (_recordingDevice == null || _socket?.state == SocketServiceState.connected) {
+          t.cancel();
+          return;
+        }
+
+        await _initiateWebsocket();
+      });
+    }
+  }
+
+  @override
+  void onError(Object err) {
+    _transcriptServiceReady = false;
+    debugPrint('err: $err');
+    notifyListeners();
+
+    // Keep alive
+    _startKeepAlivedServices();
+  }
+
+  @override
+  void onMessageEventReceived(ServerMessageEvent event) {
+    if (event.type == MessageEventType.newMemoryCreating) {
+      _onMemoryCreating();
+      return;
+    }
+
+    if (event.type == MessageEventType.newMemoryCreated) {
+      _onMemoryCreated(event);
+      return;
+    }
+
+    if (event.type == MessageEventType.newMemoryCreateFailed) {
+      _onMemoryCreateFailed();
+      return;
+    }
+
+    if (event.type == MessageEventType.newProcessingMemoryCreated) {
+      if (event.processingMemoryId == null) {
+        debugPrint("New processing memory created message event is invalid");
+        return;
+      }
+      _onProcessingMemoryCreated(event.processingMemoryId!);
+      return;
+    }
+
+    if (event.type == MessageEventType.memoryPostProcessingSuccess) {
+      if (event.memoryId == null) {
+        debugPrint("Post process message event is invalid");
+        return;
+      }
+      _onMemoryPostProcessSuccess(event.memoryId!);
+      return;
+    }
+
+    if (event.type == MessageEventType.memoryPostProcessingFailed) {
+      if (event.memoryId == null) {
+        debugPrint("Post process message event is invalid");
+        return;
+      }
+      _onMemoryPostProcessFailed(event.memoryId!);
+      return;
+    }
+  }
+
+  @override
+  void onSegmentReceived(List newSegments) {
+    if (newSegments.isEmpty) return;
+
+    if (segments.isEmpty) {
+      debugPrint('newSegments: ${newSegments.last}');
+      // TODO: small bug -> when memory A creates, and memory B starts, memory B will clean a lot more seconds than available,
+      // losing from the audio the first part of the recording. All other parts are fine.
+      FlutterForegroundTask.sendDataToTask(jsonEncode({'location': true}));
+      var currentSeconds = (audioStorage?.frames.length ?? 0) ~/ 100;
+      var removeUpToSecond = newSegments[0].start.toInt();
+      audioStorage?.removeFramesRange(fromSecond: 0, toSecond: min(max(currentSeconds - 5, 0), removeUpToSecond));
+      firstStreamReceivedAt = DateTime.now();
+    }
+
+    streamStartedAtSecond ??= newSegments[0].start;
+    TranscriptSegment.combineSegments(
+      segments,
+      newSegments,
+      toRemoveSeconds: streamStartedAtSecond ?? 0,
+      toAddSeconds: secondsMissedOnReconnect ?? 0,
+    );
+    triggerTranscriptSegmentReceivedEvents(newSegments, conversationId, sendMessageToChat: (v) {
+      messageProvider?.addMessage(v);
+    });
+
+    debugPrint('Memory creation timer restarted');
+    _memoryCreationTimer?.cancel();
+    _memoryCreationTimer =
+        Timer(const Duration(seconds: quietSecondsForMemoryCreation), () => _createPhotoCharacteristicMemory());
+    setHasTranscripts(true);
+    notifyListeners();
+  }
 }
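The keep-alive path in CaptureProvider above retries a dead socket on a fixed 15-second timer, which will poll the backend at full rate for the whole duration of an outage. A sketch of the same loop with capped exponential backoff, in case that ever becomes a problem (illustrative only, not a drop-in for `_startKeepAlivedServices`):

```dart
import 'dart:async';
import 'dart:math';

// Sketch: reconnect with capped exponential backoff instead of a fixed
// 15s period. `connect` stands in for _initiateWebsocket(), and
// `isConnected` for the _socket/_recordingDevice checks in the provider.
class BackoffReconnector {
  BackoffReconnector(this.connect, this.isConnected);

  final Future<void> Function() connect;
  final bool Function() isConnected;
  Timer? _timer;
  int _attempt = 0;

  void schedule() {
    if (isConnected()) return;
    _timer?.cancel();
    // 15s, 30s, 60s, 120s, then stays at 120s.
    final seconds = min(15 * pow(2, _attempt).toInt(), 120);
    _timer = Timer(Duration(seconds: seconds), () async {
      if (isConnected()) return;
      _attempt++;
      await connect();
      schedule(); // keep trying until isConnected() flips
    });
  }

  void reset() {
    _attempt = 0;
    _timer?.cancel();
  }
}
```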
diff --git a/app/lib/providers/device_provider.dart b/app/lib/providers/device_provider.dart
index 0494f8d59..ce064d659 100644
--- a/app/lib/providers/device_provider.dart
+++ b/app/lib/providers/device_provider.dart
@@ -3,7 +3,6 @@ import 'dart:async';
 import 'package:flutter/material.dart';
 import 'package:friend_private/backend/preferences.dart';
 import 'package:friend_private/backend/schema/bt_device.dart';
-import 'package:friend_private/providers/websocket_provider.dart';
 import 'package:friend_private/providers/capture_provider.dart';
 import 'package:friend_private/services/devices.dart';
 import 'package:friend_private/services/notification_service.dart';
@@ -13,7 +12,6 @@ import 'package:instabug_flutter/instabug_flutter.dart';
 
 class DeviceProvider extends ChangeNotifier implements IDeviceServiceSubsciption {
   CaptureProvider? captureProvider;
-  WebSocketProvider? webSocketProvider;
 
   bool isConnecting = false;
   bool isConnected = false;
@@ -25,16 +23,18 @@ class DeviceProvider extends ChangeNotifier implements IDeviceServiceSubsciption
 
   Timer? _disconnectNotificationTimer;
 
-  void setProviders(CaptureProvider provider, WebSocketProvider wsProvider) {
+  DeviceProvider() {
+    ServiceManager.instance().device.subscribe(this, this);
+  }
+
+  void setProviders(CaptureProvider provider) {
     captureProvider = provider;
-    webSocketProvider = wsProvider;
     notifyListeners();
   }
 
   void setConnectedDevice(BTDeviceStruct? device) {
     connectedDevice = device;
     print('setConnectedDevice: $device');
-    captureProvider?.updateConnectedDevice(device);
     notifyListeners();
   }
 
@@ -118,7 +118,6 @@ class DeviceProvider extends ChangeNotifier implements IDeviceServiceSubsciption
   }
 
   Future scanAndConnectToDevice() async {
-    ServiceManager.instance().device.subscribe(this, this);
     updateConnectingStatus(true);
     if (isConnected) {
       if (connectedDevice == null) {
@@ -149,22 +148,7 @@ class DeviceProvider extends ChangeNotifier implements IDeviceServiceSubsciption
     if (isConnected) {
       await initiateBleBatteryListener();
     }
-    await captureProvider?.resetState(restartBytesProcessing: true, btDevice: connectedDevice);
-    // if (captureProvider?.webSocketConnected == false) {
-    //   restartWebSocket();
-    // }
-
-    notifyListeners();
-  }
-
-  Future restartWebSocket() async {
-    debugPrint('restartWebSocket');
-    await webSocketProvider?.closeWebSocketWithoutReconnect('Restarting WebSocket');
-    if (connectedDevice == null) {
-      return;
-    }
-    await captureProvider?.resetState(restartBytesProcessing: true);
 
     notifyListeners();
   }
@@ -197,7 +181,7 @@ class DeviceProvider extends ChangeNotifier implements IDeviceServiceSubsciption
       setConnectedDevice(null);
       setIsConnected(false);
       updateConnectingStatus(false);
-      await captureProvider?.resetState(restartBytesProcessing: false);
+      await captureProvider?.stopStreamDeviceRecording(cleanDevice: true);
       captureProvider?.setAudioBytesConnected(false);
 
       print('after resetState inside initiateConnectionListener');
@@ -224,7 +208,7 @@ class DeviceProvider extends ChangeNotifier implements IDeviceServiceSubsciption
         setConnectedDevice(device);
         setIsConnected(true);
         updateConnectingStatus(false);
-        await captureProvider?.resetState(restartBytesProcessing: true, btDevice: connectedDevice);
+        await captureProvider?.streamDeviceRecording(restartBytesProcessing: true, device: device);
         // initiateBleBatteryListener();
         // The device is still disconnected for some reason
         if (connectedDevice != null) {
diff --git a/app/lib/providers/memory_provider.dart b/app/lib/providers/memory_provider.dart
index b0b12c02b..e69f0efed 100644
--- a/app/lib/providers/memory_provider.dart
+++ b/app/lib/providers/memory_provider.dart
@@ -1,8 +1,11 @@
+import 'package:collection/collection.dart';
 import 'package:flutter/foundation.dart';
 import 'package:friend_private/backend/http/api/memories.dart';
 import 'package:friend_private/backend/preferences.dart';
 import 'package:friend_private/backend/schema/memory.dart';
+import 'package:friend_private/backend/schema/structured.dart';
 import 'package:friend_private/utils/analytics/mixpanel.dart';
+import 'package:friend_private/utils/features/calendar.dart';
 
 class MemoryProvider extends ChangeNotifier {
   List memories = [];
@@ -88,11 +91,22 @@ class MemoryProvider extends ChangeNotifier {
     setLoadingMemories(true);
     var mem = await getMemories();
     memories = mem;
+    createEventsForMemories();
     setLoadingMemories(false);
     notifyListeners();
     return memories;
   }
 
+  void createEventsForMemories() {
+    for (var memory in memories) {
+      if (memory.structured.events.isNotEmpty &&
+          !memory.structured.events.first.created &&
+          memory.startedAt!.isAfter(DateTime.now().add(const Duration(days: -1)))) {
+        _handleCalendarCreation(memory);
+      }
+    }
+  }
+
   Future getMoreMemoriesFromServer() async {
     if (memories.length % 50 != 0) return;
     if (isLoadingMemories) return;
@@ -106,8 +120,34 @@ class MemoryProvider extends ChangeNotifier {
 
   void addMemory(ServerMemory memory) {
     memories.insert(0, memory);
-    filterMemories('');
+    initFilteredMemories();
+    notifyListeners();
+  }
+
+  int addMemoryWithDate(ServerMemory memory) {
+    int idx;
+    var date = memoriesWithDates.indexWhere((element) =>
+        element is DateTime &&
+        element.day == memory.createdAt.day &&
+        element.month == memory.createdAt.month &&
+        element.year == memory.createdAt.year);
+    if (date != -1) {
+      var hour = memoriesWithDates[date + 1].createdAt.hour;
+      var newHour = memory.createdAt.hour;
+      if (newHour > hour) {
+        memoriesWithDates.insert(date + 1, memory);
+        idx = date + 1;
+      } else {
+        memoriesWithDates.insert(date + 2, memory);
+        idx = date + 2;
+      }
+    } else {
+      memoriesWithDates.add(memory.createdAt);
+      memoriesWithDates.add(memory);
+      idx = memoriesWithDates.length - 1;
+    }
     notifyListeners();
+    return idx;
   }
 
   void updateMemory(ServerMemory memory, [int? index]) {
@@ -123,6 +163,28 @@
     notifyListeners();
   }
 
+  _handleCalendarCreation(ServerMemory memory) {
+    if (!SharedPreferencesUtil().calendarEnabled) return;
+    if (SharedPreferencesUtil().calendarType != 'auto') return;
+
+    List events = memory.structured.events;
+    if (events.isEmpty) return;
+
+    List indexes = events.mapIndexed((index, e) => index).toList();
+    setMemoryEventsState(memory.id, indexes, indexes.map((_) => true).toList());
+    for (var i = 0; i < events.length; i++) {
+      print('Creating event: ${events[i].title}');
+      if (events[i].created) continue;
+      events[i].created = true;
+      CalendarUtil().createEvent(
+        events[i].title,
+        events[i].startsAt,
+        events[i].duration,
+        description: events[i].description,
+      );
+    }
+  }
+
   /////////////////////////////////////////////////////////////////
   ////////// Delete Memory With Undo Functionality ///////////////
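A caveat on `addMemoryWithDate` above: it inspects only `memoriesWithDates[date + 1]` and compares bare hours, so same-hour memories and days with more than two entries can land out of order, and `date + 2` is not bounds-checked. A sketch that walks the whole day group instead (hypothetical, assuming the same heterogeneous list of DateTime headers followed by ServerMemory items, newest first):

```dart
// Sketch: insert a memory inside its day group ordered by createdAt,
// falling back to appending a new date header. Assumes ServerMemory
// exposes createdAt, as in the diff above.
int addMemoryWithDateSketch(List<dynamic> memoriesWithDates, ServerMemory memory) {
  final day = memoriesWithDates.indexWhere((e) =>
      e is DateTime &&
      e.year == memory.createdAt.year &&
      e.month == memory.createdAt.month &&
      e.day == memory.createdAt.day);
  if (day == -1) {
    memoriesWithDates
      ..add(memory.createdAt)
      ..add(memory);
    return memoriesWithDates.length - 1;
  }
  var idx = day + 1;
  // Skip every entry in this day that is newer than the new memory.
  while (idx < memoriesWithDates.length &&
      memoriesWithDates[idx] is ServerMemory &&
      (memoriesWithDates[idx] as ServerMemory).createdAt.isAfter(memory.createdAt)) {
    idx++;
  }
  memoriesWithDates.insert(idx, memory);
  return idx;
}
```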
diff --git a/app/lib/providers/message_provider.dart b/app/lib/providers/message_provider.dart
index 7e747cafc..ebb1c8d10 100644
--- a/app/lib/providers/message_provider.dart
+++ b/app/lib/providers/message_provider.dart
@@ -10,11 +10,25 @@ class MessageProvider extends ChangeNotifier {
   List messages = [];
 
   bool isLoadingMessages = false;
+  bool hasCachedMessages = false;
+  bool isClearingChat = false;
+
+  String firstTimeLoadingText = '';
 
   void updatePluginProvider(PluginProvider p) {
     pluginProvider = p;
   }
 
+  void setHasCachedMessages(bool value) {
+    hasCachedMessages = value;
+    notifyListeners();
+  }
+
+  void setClearingChat(bool value) {
+    isClearingChat = value;
+    notifyListeners();
+  }
+
   void setLoadingMessages(bool value) {
     isLoadingMessages = value;
     notifyListeners();
@@ -22,25 +36,53 @@ class MessageProvider extends ChangeNotifier {
 
   Future refreshMessages() async {
     setLoadingMessages(true);
+    if (SharedPreferencesUtil().cachedMessages.isNotEmpty) {
+      setHasCachedMessages(true);
+    }
     messages = await getMessagesFromServer();
     if (messages.isEmpty) {
       messages = SharedPreferencesUtil().cachedMessages;
     } else {
       SharedPreferencesUtil().cachedMessages = messages;
+      setHasCachedMessages(true);
     }
     setLoadingMessages(false);
     notifyListeners();
   }
 
+  void setMessagesFromCache() {
+    if (SharedPreferencesUtil().cachedMessages.isNotEmpty) {
+      setHasCachedMessages(true);
+      messages = SharedPreferencesUtil().cachedMessages;
+    }
+    notifyListeners();
+  }
+
   Future> getMessagesFromServer() async {
+    if (!hasCachedMessages) {
+      firstTimeLoadingText = 'Reading your memories...';
+      notifyListeners();
+    }
     setLoadingMessages(true);
     var mes = await getMessagesServer();
+    if (!hasCachedMessages) {
+      firstTimeLoadingText = 'Learning from your memories...';
+      notifyListeners();
+    }
     messages = mes;
     setLoadingMessages(false);
     notifyListeners();
     return messages;
   }
 
+  Future clearChat() async {
+    setClearingChat(true);
+    var mes = await clearChatServer();
+    messages = mes;
+    setClearingChat(false);
+    notifyListeners();
+  }
+
   void addMessage(ServerMessage message) {
     messages.insert(0, message);
     notifyListeners();
diff --git a/app/lib/providers/onboarding_provider.dart b/app/lib/providers/onboarding_provider.dart
index 70d7a183d..7ddcd703f 100644
--- a/app/lib/providers/onboarding_provider.dart
+++ b/app/lib/providers/onboarding_provider.dart
@@ -250,6 +250,14 @@
     return connection?.device;
   }
 
+  Future _getAudioCodec(String deviceId) async {
+    var connection = await ServiceManager.instance().device.ensureConnection(deviceId);
+    if (connection == null) {
+      return BleAudioCodec.pcm8;
+    }
+    return connection.getAudioCodec();
+  }
+
   @override
   void dispose() {
     //TODO: This does not get called when the page is popped
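For completeness, the new `clearChat()`/`isClearingChat` pair in MessageProvider is presumably meant to back a "Clear Chat" action in the UI; a hypothetical call site (not in this diff) that uses the flag to block double taps while the DELETE request is in flight:

```dart
import 'package:flutter/material.dart';
import 'package:provider/provider.dart';
// Assumes an import for MessageProvider from this app.

// Sketch: a button wired to the new clearChat() API; isClearingChat
// disables the button until the server round-trip completes.
Widget clearChatButton() {
  return Consumer<MessageProvider>(
    builder: (context, provider, _) => TextButton(
      onPressed: provider.isClearingChat ? null : () => provider.clearChat(),
      child: Text(provider.isClearingChat ? 'Clearing...' : 'Clear Chat'),
    ),
  );
}
```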
memory; @@ -71,10 +76,8 @@ class SpeechProfileProvider extends ChangeNotifier with MessageNotifierMixin imp notifyListeners(); } - void setProviders(DeviceProvider provider, CaptureProvider captureProvider, WebSocketProvider wsProvider) { + void setProviders(DeviceProvider provider) { deviceProvider = provider; - this.captureProvider = captureProvider; - webSocketProvider = wsProvider; notifyListeners(); } @@ -86,15 +89,15 @@ class SpeechProfileProvider extends ChangeNotifier with MessageNotifierMixin imp notifyListeners(); } - Future initialise(bool isFromOnboarding) async { + Future initialise(bool isFromOnboarding, {Function? finalizedCallback}) async { + _isFromOnboarding = isFromOnboarding; + _finalizedCallback = finalizedCallback; setInitialising(true); device = deviceProvider?.connectedDevice; - await captureProvider!.resetForSpeechProfile(); - await initiateWebsocket(isFromOnboarding); + await _initiateWebsocket(force: true); - // _bleBytesStream = captureProvider?.bleBytesStream; if (device != null) await initiateFriendAudioStreaming(); - if (webSocketProvider?.wsConnectionState != WebsocketConnectionStatus.connected) { + if (_socket?.state != SocketServiceState.connected) { // wait for websocket to connect await Future.delayed(Duration(seconds: 2)); } @@ -120,96 +123,72 @@ class SpeechProfileProvider extends ChangeNotifier with MessageNotifierMixin imp ServiceManager.instance().device.subscribe(this, this); } - Future initiateWebsocket(bool isFromOnboarding) async { - await webSocketProvider?.initWebSocket( - codec: BleAudioCodec.opus, - sampleRate: 16000, - includeSpeechProfile: false, - newMemoryWatch: false, - onConnectionSuccess: () { - print('Websocket connected in speech profile'); - notifyListeners(); - }, - onConnectionFailed: (err) { - notifyError('WS_ERR'); - }, - onConnectionClosed: (int? closeCode, String? closeReason) {}, - onConnectionError: (err) { - notifyError('WS_ERR'); - }, - onMessageReceived: (List newSegments) { - if (newSegments.isEmpty) return; - if (segments.isEmpty) { - audioStorage.removeFramesRange(fromSecond: 0, toSecond: newSegments[0].start.toInt()); - } - streamStartedAtSecond ??= newSegments[0].start; - - TranscriptSegment.combineSegments( - segments, - newSegments, - toRemoveSeconds: streamStartedAtSecond ?? 0, - ); - updateProgressMessage(); - _validateSingleSpeaker(); - _handleCompletion(isFromOnboarding); - notifyInfo('SCROLL_DOWN'); - debugPrint('Memory creation timer restarted'); - }, - ); + Future _initiateWebsocket({bool force = false}) async { + _socket = await ServiceManager.instance() + .socket + .speechProfile(codec: BleAudioCodec.opus, sampleRate: 16000, force: force); + if (_socket == null) { + throw Exception("Can not create new speech profile socket"); + } + _socket?.subscribe(this, this); } - _handleCompletion(bool isFromOnboarding) async { + _handleCompletion() async { if (uploadingProfile || profileCompleted) return; String text = segments.map((e) => e.text).join(' ').trim(); int wordsCount = text.split(' ').length; percentageCompleted = (wordsCount / targetWordsCount).clamp(0, 1); notifyListeners(); if (percentageCompleted == 1) { - await finalize(isFromOnboarding); + await finalize(); } notifyListeners(); } - Future finalize(bool isFromOnboarding) async { - if (uploadingProfile || profileCompleted) return; - - int duration = segments.isEmpty ? 
0 : segments.last.end.toInt(); - if (duration < 5 || duration > 120) { - notifyError('INVALID_RECORDING'); - } + Future finalize() async { + try { + if (uploadingProfile || profileCompleted) return; - String text = segments.map((e) => e.text).join(' ').trim(); - if (text.split(' ').length < (targetWordsCount / 2)) { - // 25 words - notifyError('TOO_SHORT'); - } - uploadingProfile = true; - notifyListeners(); - await webSocketProvider?.closeWebSocketWithoutReconnect('finalizing'); - forceCompletionTimer?.cancel(); - connectionStateListener?.cancel(); - _bleBytesStream?.cancel(); + int duration = segments.isEmpty ? 0 : segments.last.end.toInt(); + if (duration < 5 || duration > 120) { + notifyError('INVALID_RECORDING'); + } - updateLoadingText('Memorizing your voice...'); - List> raw = List.from(audioStorage.rawPackets); - var data = await audioStorage.createWavFile(filename: 'speaker_profile.wav'); - try { - await uploadProfile(data.item1); - await uploadProfileBytes(raw, duration); - } catch (e) {} - - updateLoadingText('Personalizing your experience...'); - SharedPreferencesUtil().hasSpeakerProfile = true; - if (isFromOnboarding) { - await createMemory(); - captureProvider?.clearTranscripts(); + String text = segments.map((e) => e.text).join(' ').trim(); + if (text.split(' ').length < (targetWordsCount / 2)) { + // 25 words + notifyError('TOO_SHORT'); + } + uploadingProfile = true; + notifyListeners(); + await _socket?.stop(reason: 'finalizing'); + forceCompletionTimer?.cancel(); + connectionStateListener?.cancel(); + _bleBytesStream?.cancel(); + + updateLoadingText('Memorizing your voice...'); + List> raw = List.from(audioStorage.rawPackets); + var data = await audioStorage.createWavFile(filename: 'speaker_profile.wav'); + try { + await uploadProfile(data.item1); + await uploadProfileBytes(raw, duration); + } catch (e) {} + + updateLoadingText('Personalizing your experience...'); + SharedPreferencesUtil().hasSpeakerProfile = true; + if (_isFromOnboarding) { + await createMemory(); + } + uploadingProfile = false; + profileCompleted = true; + text = ''; + updateLoadingText("You're all set!"); + notifyListeners(); + } finally { + if (_finalizedCallback != null) { + _finalizedCallback!(); + } } - await captureProvider?.resetState(restartBytesProcessing: true); - uploadingProfile = false; - profileCompleted = true; - text = ''; - updateLoadingText("You're all set!"); - notifyListeners(); } // TODO: use connection directly @@ -230,9 +209,10 @@ class SpeechProfileProvider extends ChangeNotifier with MessageNotifierMixin imp onAudioBytesReceived: (List value) { if (value.isEmpty) return; audioStorage.storeFramePacket(value); + value.removeRange(0, 3); - if (webSocketProvider?.wsConnectionState == WebsocketConnectionStatus.connected) { - webSocketProvider?.websocketChannel?.sink.add(value); + if (_socket?.state == SocketServiceState.connected) { + _socket?.send(value); } }, ); @@ -296,8 +276,7 @@ class SpeechProfileProvider extends ChangeNotifier with MessageNotifierMixin imp percentageCompleted = 0; uploadingProfile = false; profileCompleted = false; - await webSocketProvider?.closeWebSocketWithoutReconnect('closing'); - await captureProvider?.resetState(restartBytesProcessing: true, isFromSpeechProfile: true); + await _socket?.stop(reason: 'closing'); notifyListeners(); } @@ -355,7 +334,10 @@ class SpeechProfileProvider extends ChangeNotifier with MessageNotifierMixin imp connectionStateListener?.cancel(); _bleBytesStream?.cancel(); forceCompletionTimer?.cancel(); + _finalizedCallback = null; + 
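+    // (unsubscribe is keyed by this provider's hashCode, mirroring the subscribe(this, this) call in _initiateWebsocket)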
_socket?.unsubscribe(this); ServiceManager.instance().device.unsubscribe(this); + super.dispose(); } @@ -386,4 +368,39 @@ class SpeechProfileProvider extends ChangeNotifier with MessageNotifierMixin imp @override void onStatusChanged(DeviceServiceStatus status) {} + + @override + void onClosed() { + // TODO: implement onClosed + } + + @override + void onError(Object err) { + notifyError('WS_ERR'); + } + + @override + void onMessageEventReceived(ServerMessageEvent event) { + // TODO: implement onMessageEventReceived + } + + @override + void onSegmentReceived(List newSegments) { + if (newSegments.isEmpty) return; + if (segments.isEmpty) { + audioStorage.removeFramesRange(fromSecond: 0, toSecond: newSegments[0].start.toInt()); + } + streamStartedAtSecond ??= newSegments[0].start; + + TranscriptSegment.combineSegments( + segments, + newSegments, + toRemoveSeconds: streamStartedAtSecond ?? 0, + ); + updateProgressMessage(); + _validateSingleSpeaker(); + _handleCompletion(); + notifyInfo('SCROLL_DOWN'); + debugPrint('Memory creation timer restarted'); + } } diff --git a/app/lib/providers/websocket_provider.dart b/app/lib/providers/websocket_provider.dart index 87d8e13c3..2d94a4083 100644 --- a/app/lib/providers/websocket_provider.dart +++ b/app/lib/providers/websocket_provider.dart @@ -10,6 +10,7 @@ import 'package:friend_private/utils/websockets.dart'; import 'package:internet_connection_checker_plus/internet_connection_checker_plus.dart'; import 'package:web_socket_channel/io.dart'; +@Deprecated("Use the socket service") class WebSocketProvider with ChangeNotifier { WebsocketConnectionStatus wsConnectionState = WebsocketConnectionStatus.notConnected; bool websocketReconnecting = false; diff --git a/app/lib/services/device_connections.dart b/app/lib/services/device_connections.dart index 29c4298f4..2fe1b24e3 100644 --- a/app/lib/services/device_connections.dart +++ b/app/lib/services/device_connections.dart @@ -45,6 +45,8 @@ abstract class DeviceConnection { DeviceConnectionState get connectionState => _connectionState; + Function(String deviceId, DeviceConnectionState state)? _connectionStateChangedCallback; + DateTime? get pongAt => _pongAt; late StreamSubscription _connectionStateSubscription; @@ -62,8 +64,9 @@ abstract class DeviceConnection { } // Connect + _connectionStateChangedCallback = onConnectionStateChanged; _connectionStateSubscription = bleDevice.connectionState.listen((BluetoothConnectionState state) async { - _onBleConnectionStateChanged(state, onConnectionStateChanged); + _onBleConnectionStateChanged(state); }); await FlutterBluePlus.adapterState.where((val) => val == BluetoothAdapterState.on).first; @@ -82,26 +85,26 @@ abstract class DeviceConnection { _services = await bleDevice.discoverServices(); } - void _onBleConnectionStateChanged( - BluetoothConnectionState state, Function(String deviceId, DeviceConnectionState state)? 
callback) async { + void _onBleConnectionStateChanged(BluetoothConnectionState state) async { if (state == BluetoothConnectionState.disconnected && _connectionState == DeviceConnectionState.connected) { _connectionState = DeviceConnectionState.disconnected; - await disconnect(callback: callback); + await disconnect(); return; } if (state == BluetoothConnectionState.connected && _connectionState == DeviceConnectionState.disconnected) { _connectionState = DeviceConnectionState.connected; - if (callback != null) { - callback(device.id, _connectionState); + if (_connectionStateChangedCallback != null) { + _connectionStateChangedCallback!(device.id, _connectionState); } } } - Future disconnect({Function(String deviceId, DeviceConnectionState state)? callback}) async { + Future disconnect() async { _connectionState = DeviceConnectionState.disconnected; - if (callback != null) { - callback(device.id, _connectionState); + if (_connectionStateChangedCallback != null) { + _connectionStateChangedCallback!(device.id, _connectionState); + _connectionStateChangedCallback = null; } await bleDevice.disconnect(); _connectionStateSubscription.cancel(); diff --git a/app/lib/services/services.dart b/app/lib/services/services.dart index dc489b15e..127fbeb3d 100644 --- a/app/lib/services/services.dart +++ b/app/lib/services/services.dart @@ -6,10 +6,12 @@ import 'package:flutter/material.dart'; import 'package:flutter_background_service/flutter_background_service.dart'; import 'package:flutter_sound/flutter_sound.dart'; import 'package:friend_private/services/devices.dart'; +import 'package:friend_private/services/sockets.dart'; class ServiceManager { late IMicRecorderService _mic; late IDeviceService _device; + late ISocketService _socket; static ServiceManager? _instance; @@ -19,6 +21,7 @@ class ServiceManager { runner: BackgroundService(), ); sm._device = DeviceService(); + sm._socket = SocketServicePool(); return sm; } @@ -35,6 +38,8 @@ class ServiceManager { IDeviceService get device => _device; + ISocketService get socket => _socket; + static void init() { if (_instance != null) { throw Exception("Service manager is initiated"); diff --git a/app/lib/services/sockets.dart b/app/lib/services/sockets.dart new file mode 100644 index 000000000..4cf3e0a34 --- /dev/null +++ b/app/lib/services/sockets.dart @@ -0,0 +1,77 @@ +import 'package:flutter/material.dart'; +import 'package:friend_private/backend/schema/bt_device.dart'; +import 'package:friend_private/utils/pure_socket.dart'; + +abstract class ISocketService { + void start(); + void stop(); + + Future memory( + {required BleAudioCodec codec, required int sampleRate, bool force = false}); + Future speechProfile( + {required BleAudioCodec codec, required int sampleRate, bool force = false}); +} + +abstract interface class ISocketServiceSubsciption {} + +class SocketServicePool extends ISocketService { + TranscripSegmentSocketService? 
_socket; + + @override + void start() { + // TODO: implement start + } + + @override + void stop() async { + await _socket?.stop(); + } + + // Warn: Should use a better solution to prevent race conditions + bool mutex = false; + Future socket( + {required BleAudioCodec codec, required int sampleRate, bool force = false}) async { + while (mutex) { + await Future.delayed(const Duration(milliseconds: 50)); + } + mutex = true; + + try { + if (!force && + _socket?.codec == codec && + _socket?.sampleRate == sampleRate && + _socket?.state == SocketServiceState.connected) { + return _socket; + } + + // new socket + await _socket?.stop(); + + _socket = MemoryTranscripSegmentSocketService.create(sampleRate, codec); + await _socket?.start(); + if (_socket?.state != SocketServiceState.connected) { + return null; + } + + return _socket; + } finally { + mutex = false; + } + + return null; + } + + @override + Future memory( + {required BleAudioCodec codec, required int sampleRate, bool force = false}) async { + debugPrint("socket memory > $codec $sampleRate $force"); + return await socket(codec: codec, sampleRate: sampleRate, force: force); + } + + @override + Future speechProfile( + {required BleAudioCodec codec, required int sampleRate, bool force = false}) async { + debugPrint("socket speech profile > $codec $sampleRate $force"); + return await socket(codec: codec, sampleRate: sampleRate, force: force); + } +} diff --git a/app/lib/utils/pure_socket.dart b/app/lib/utils/pure_socket.dart new file mode 100644 index 000000000..bf6adfbb5 --- /dev/null +++ b/app/lib/utils/pure_socket.dart @@ -0,0 +1,429 @@ +import 'dart:async'; +import 'dart:convert'; +import 'dart:io'; +import 'dart:math'; + +import 'package:flutter/material.dart'; +import 'package:friend_private/backend/preferences.dart'; +import 'package:friend_private/backend/schema/bt_device.dart'; +import 'package:friend_private/backend/schema/message_event.dart'; +import 'package:friend_private/backend/schema/transcript_segment.dart'; +import 'package:friend_private/env/env.dart'; +import 'package:friend_private/services/notification_service.dart'; +import 'package:instabug_flutter/instabug_flutter.dart'; +import 'package:internet_connection_checker_plus/internet_connection_checker_plus.dart'; +import 'package:web_socket_channel/io.dart'; +import 'package:web_socket_channel/status.dart' as socket_channel_status; +import 'package:web_socket_channel/web_socket_channel.dart'; + +enum PureSocketStatus { notConnected, connecting, connected, disconnected } + +abstract class IPureSocketListener { + void onMessage(dynamic message); + void onClosed(); + void onError(Object err, StackTrace trace); + + void onInternetConnectionFailed() {} + + void onMaxRetriesReach() {} +} + +abstract class IPureSocket { + Future connect(); + Future disconnect(); + void send(dynamic message); + + void onInternetSatusChanged(InternetStatus status); + + void onMessage(dynamic message); + void onClosed(); + void onError(Object err, StackTrace trace); +} + +class PureSocketMessage { + String? raw; +} + +class PureCore { + late InternetConnection internetConnection; + + factory PureCore() => _instance; + + /// The singleton instance of [PureCore]. 
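+  /// Dart initializes static fields lazily, so the underlying InternetConnection
+  /// checker is only created on first access.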
+ static final _instance = PureCore.createInstance(); + + PureCore.createInstance() { + internetConnection = InternetConnection.createInstance( + /* + customCheckOptions: [ + InternetCheckOption( + uri: Uri.parse(Env.apiBaseUrl!), + timeout: const Duration( + seconds: 30, + ), + responseStatusFn: (resp) { + return resp.statusCode < 500; + }, + ), + ], + */ + ); + } +} + +class PureSocket implements IPureSocket { + StreamSubscription? _internetStatusListener; + InternetStatus? _internetStatus; + Timer? _internetLostDelayTimer; + + WebSocketChannel? _channel; + WebSocketChannel get channel { + if (_channel == null) { + throw Exception('Socket is not connected'); + } + return _channel!; + } + + PureSocketStatus _status = PureSocketStatus.notConnected; + PureSocketStatus get status => _status; + + IPureSocketListener? _listener; + + int _retries = 0; + + String url; + + PureSocket(this.url) { + _internetStatusListener = PureCore().internetConnection.onStatusChange.listen((InternetStatus status) { + onInternetSatusChanged(status); + }); + } + + void setListener(IPureSocketListener listener) { + _listener = listener; + } + + @override + Future connect() async { + return await _connect(); + } + + Future _connect() async { + if (_status == PureSocketStatus.connecting || _status == PureSocketStatus.connected) { + return false; + } + + _channel = IOWebSocketChannel.connect( + url, + pingInterval: const Duration(seconds: 10), + connectTimeout: const Duration(seconds: 30), + ); + if (_channel?.ready == null) { + return false; + } + + _status = PureSocketStatus.connecting; + dynamic err; + try { + await channel.ready; + } on SocketException catch (e) { + err = e; + } on WebSocketChannelException catch (e) { + err = e; + } + if (err != null) { + print("Error: $err"); + _status = PureSocketStatus.notConnected; + return false; + } + _status = PureSocketStatus.connected; + _retries = 0; + + final that = this; + + _channel?.stream.listen( + (message) { + that.onMessage(message); + }, + onError: (err, trace) { + that.onError(err, trace); + }, + onDone: () { + that.onClosed(); + }, + cancelOnError: true, + ); + + return true; + } + + @override + Future disconnect() async { + if (_status == PureSocketStatus.connected) { + // Warn: should not use await cause dead end by socket closed. 
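+      // (awaiting sink.close() can hang if the peer never completes the close
+      // handshake, so it is fired without await)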
+ _channel?.sink.close(socket_channel_status.normalClosure); + } + _status = PureSocketStatus.disconnected; + onClosed(); + } + + Future _cleanUp() async { + _internetLostDelayTimer?.cancel(); + _internetStatusListener?.cancel(); + } + + Future stop() async { + await disconnect(); + await _cleanUp(); + } + + @override + void onClosed() { + _status = PureSocketStatus.disconnected; + debugPrint("Socket closed"); + _listener?.onClosed(); + } + + @override + void onError(Object err, StackTrace trace) { + _status = PureSocketStatus.disconnected; + print("Error: ${err}"); + debugPrintStack(stackTrace: trace); + + _listener?.onError(err, trace); + + CrashReporting.reportHandledCrash(err, trace, level: NonFatalExceptionLevel.error); + } + + @override + void onMessage(dynamic message) { + debugPrint("[Socket] Message $message"); + _listener?.onMessage(message); + } + + @override + void send(message) { + _channel?.sink.add(message); + } + + void _reconnect() async { + debugPrint("[Socket] reconnect...${_retries + 1}..."); + const int initialBackoffTimeMs = 1000; // 1 second + const double multiplier = 1.5; + const int maxRetries = 7; + + if (_status == PureSocketStatus.connecting || _status == PureSocketStatus.connected) { + debugPrint("[Socket] Can not reconnect, because socket is $_status"); + return; + } + + await _cleanUp(); + + var ok = await _connect(); + if (ok) { + return; + } + + // retry + int waitInMilliseconds = pow(multiplier, _retries).toInt() * initialBackoffTimeMs; + await Future.delayed(Duration(milliseconds: waitInMilliseconds)); + _retries++; + if (_retries >= maxRetries) { + debugPrint("[Socket] Reach max retries $maxRetries"); + _listener?.onMaxRetriesReach(); + return; + } + _reconnect(); + } + + @override + void onInternetSatusChanged(InternetStatus status) { + debugPrint("[Socket] Internet connection changed $status socket $_status"); + _internetStatus = status; + switch (status) { + case InternetStatus.connected: + if (_status == PureSocketStatus.connected || _status == PureSocketStatus.connecting) { + return; + } + _reconnect(); + break; + case InternetStatus.disconnected: + var that = this; + _internetLostDelayTimer?.cancel(); + _internetLostDelayTimer = Timer(const Duration(seconds: 60), () async { + if (_internetStatus != InternetStatus.disconnected) { + return; + } + + await that.disconnect(); + _listener?.onInternetConnectionFailed(); + }); + + break; + } + } +} + +abstract interface class ITransctipSegmentSocketServiceListener { + void onMessageEventReceived(ServerMessageEvent event); + void onSegmentReceived(List segments); + void onError(Object err); + void onClosed(); +} + +class SpeechProfileTranscripSegmentSocketService extends TranscripSegmentSocketService { + SpeechProfileTranscripSegmentSocketService.create(super.sampleRate, super.codec) + : super.create(includeSpeechProfile: false, newMemoryWatch: false); +} + +class MemoryTranscripSegmentSocketService extends TranscripSegmentSocketService { + MemoryTranscripSegmentSocketService.create(super.sampleRate, super.codec) + : super.create(includeSpeechProfile: true, newMemoryWatch: true); +} + +enum SocketServiceState { + connected, + disconnected, +} + +class TranscripSegmentSocketService implements IPureSocketListener { + late PureSocket _socket; + final Map _listeners = {}; + + SocketServiceState get state => + _socket.status == PureSocketStatus.connected ? 
SocketServiceState.connected : SocketServiceState.disconnected; + + int sampleRate; + BleAudioCodec codec; + bool includeSpeechProfile; + bool newMemoryWatch; + + TranscripSegmentSocketService.create( + this.sampleRate, + this.codec, { + this.includeSpeechProfile = false, + this.newMemoryWatch = true, + }) { + final recordingsLanguage = SharedPreferencesUtil().recordingsLanguage; + var params = '?language=$recordingsLanguage&sample_rate=$sampleRate&codec=$codec&uid=${SharedPreferencesUtil().uid}' + '&include_speech_profile=$includeSpeechProfile&new_memory_watch=$newMemoryWatch&stt_service=${SharedPreferencesUtil().transcriptionModel}'; + String url = '${Env.apiBaseUrl!.replaceAll('https', 'wss')}listen$params'; + + _socket = PureSocket(url); + _socket.setListener(this); + } + + void subscribe(Object context, ITransctipSegmentSocketServiceListener listener) { + _listeners.remove(context.hashCode); + _listeners.putIfAbsent(context.hashCode, () => listener); + } + + void unsubscribe(Object context) { + _listeners.remove(context.hashCode); + } + + Future start() async { + bool ok = await _socket.connect(); + if (!ok) { + debugPrint("Can not connect to websocket"); + } + } + + Future stop({String? reason}) async { + await _socket.stop(); + _listeners.clear(); + + if (reason != null) { + debugPrint(reason); + } + } + + Future send(dynamic message) async { + _socket.send(message); + return; + } + + @override + void onClosed() { + _listeners.forEach((k, v) { + v.onClosed(); + }); + } + + @override + void onError(Object err, StackTrace trace) { + _listeners.forEach((k, v) { + v.onError(err); + }); + } + + @override + void onMessage(event) { + debugPrint("[TranscriptSegmentService] onMessage ${event}"); + if (event == 'ping') return; + + // Decode json + dynamic jsonEvent; + try { + jsonEvent = jsonDecode(event); + } on FormatException catch (e) { + debugPrint(e.toString()); + } + if (jsonEvent == null) { + debugPrint("Can not decode message event json $event"); + return; + } + + // Transcript segments + if (jsonEvent is List) { + var segments = jsonEvent; + if (segments.isEmpty) { + return; + } + _listeners.forEach((k, v) { + v.onSegmentReceived(segments.map((e) => TranscriptSegment.fromJson(e)).toList()); + }); + return; + } + + debugPrint(event); + + // Message event + if (jsonEvent.containsKey("type")) { + var event = ServerMessageEvent.fromJson(jsonEvent); + _listeners.forEach((k, v) { + v.onMessageEventReceived(event); + }); + return; + } + + debugPrint(event.toString()); + } + + @override + void onInternetConnectionFailed() { + debugPrint("onInternetConnectionFailed"); + + // Send notification + NotificationService.instance.clearNotification(3); + NotificationService.instance.createNotification( + notificationId: 3, + title: 'Internet Connection Lost', + body: 'Your device is offline. Transcription is paused until connection is restored.', + ); + } + + @override + void onMaxRetriesReach() { + debugPrint("onMaxRetriesReach"); + + // Send notification + NotificationService.instance.clearNotification(2); + NotificationService.instance.createNotification( + notificationId: 2, + title: 'Connection Issue 🚨', + body: 'Unable to connect to the transcript service.' 
+ ' Please restart the app or contact support if the problem persists.', + ); + } +} diff --git a/app/lib/widgets/photos_grid.dart b/app/lib/widgets/photos_grid.dart index 4cb620a97..276f5e9e0 100644 --- a/app/lib/widgets/photos_grid.dart +++ b/app/lib/widgets/photos_grid.dart @@ -1,48 +1,53 @@ import 'dart:convert'; import 'package:flutter/material.dart'; +import 'package:friend_private/pages/memory_detail/memory_detail_provider.dart'; import 'package:friend_private/widgets/dialog.dart'; +import 'package:provider/provider.dart'; import 'package:tuple/tuple.dart'; class PhotosGridComponent extends StatelessWidget { - final List> photos; - const PhotosGridComponent({super.key, required this.photos}); + const PhotosGridComponent({super.key}); @override Widget build(BuildContext context) { - return GridView.builder( - padding: EdgeInsets.zero, - shrinkWrap: true, - scrollDirection: Axis.vertical, - itemCount: photos.length, - physics: const NeverScrollableScrollPhysics(), - itemBuilder: (context, idx) { - return GestureDetector( - onTap: () { - showDialog( - context: context, - builder: (c) { - return getDialog( - context, - () => Navigator.pop(context), - () => Navigator.pop(context), - 'Description', - photos[idx].item2, - singleButton: true, - ); - }); - }, - child: Container( - decoration: BoxDecoration(border: Border.all(color: Colors.grey.shade600, width: 0.5)), - child: Image.memory(base64Decode(photos[idx].item1), fit: BoxFit.cover), - ), - ); - }, - gridDelegate: const SliverGridDelegateWithFixedCrossAxisCount( - crossAxisCount: 3, - crossAxisSpacing: 8, - mainAxisSpacing: 8, - ), - ); + return Selector>>( + selector: (context, provider) => provider.photosData, + builder: (context, photos, child) { + return GridView.builder( + padding: EdgeInsets.zero, + shrinkWrap: true, + scrollDirection: Axis.vertical, + itemCount: photos.length, + physics: const NeverScrollableScrollPhysics(), + itemBuilder: (context, idx) { + return GestureDetector( + onTap: () { + showDialog( + context: context, + builder: (c) { + return getDialog( + context, + () => Navigator.pop(context), + () => Navigator.pop(context), + 'Description', + photos[idx].item2, + singleButton: true, + ); + }); + }, + child: Container( + decoration: BoxDecoration(border: Border.all(color: Colors.grey.shade600, width: 0.5)), + child: Image.memory(base64Decode(photos[idx].item1), fit: BoxFit.cover), + ), + ); + }, + gridDelegate: const SliverGridDelegateWithFixedCrossAxisCount( + crossAxisCount: 3, + crossAxisSpacing: 8, + mainAxisSpacing: 8, + ), + ); + }); } } diff --git a/app/pubspec.yaml b/app/pubspec.yaml index 4990f937a..4c568deb1 100644 --- a/app/pubspec.yaml +++ b/app/pubspec.yaml @@ -3,7 +3,7 @@ description: A new Flutter project. 
publish_to: 'none' # Remove this line if you wish to publish to pub.dev -version: 1.0.38+134 +version: 1.0.39+138 environment: sdk: ">=3.0.0 <4.0.0" diff --git a/backend/database/chat.py b/backend/database/chat.py index 3c4c13329..dcf55bffe 100644 --- a/backend/database/chat.py +++ b/backend/database/chat.py @@ -59,6 +59,10 @@ def get_messages(uid: str, limit: int = 20, offset: int = 0, include_memories: b # Fetch messages and collect memory IDs for doc in messages_ref.stream(): message = doc.to_dict() + + if message.get('deleted') is True: + continue + messages.append(message) memories_id.update(message.get('memories_id', [])) @@ -82,3 +86,45 @@ def get_messages(uid: str, limit: int = 20, offset: int = 0, include_memories: b ] return messages + + +def batch_delete_messages(parent_doc_ref, batch_size=450): + messages_ref = parent_doc_ref.collection('messages') + last_doc = None # For pagination + + while True: + if last_doc: + docs = messages_ref.limit(batch_size).start_after(last_doc).stream() + else: + docs = messages_ref.limit(batch_size).stream() + + docs_list = list(docs) + + if not docs_list: + print("No more messages to delete") + break + + batch = db.batch() + + for doc in docs_list: + batch.update(doc.reference, {'deleted': True}) + + batch.commit() + + if len(docs_list) < batch_size: + print("Processed all messages") + break + + last_doc = docs_list[-1] + + +def clear_chat(uid: str): + try: + user_ref = db.collection('users').document(uid) + print(f"Deleting messages for user: {uid}") + if not user_ref.get().exists: + return {"message": "User not found"} + batch_delete_messages(user_ref) + return None + except Exception as e: + return {"message": str(e)} \ No newline at end of file diff --git a/backend/database/memories.py b/backend/database/memories.py index f12ca4257..7bc30f07e 100644 --- a/backend/database/memories.py +++ b/backend/database/memories.py @@ -56,11 +56,13 @@ def delete_memory(uid, memory_id): def filter_memories_by_date(uid, start_date, end_date): + # TODO: check utc comparison or not? 
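+    # Note: Firestore stores timestamps in UTC, so the range filter below is only
+    # correct when start_date/end_date are timezone-aware UTC datetimes.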
user_ref = db.collection('users').document(uid) query = ( user_ref.collection('memories') .where(filter=FieldFilter('created_at', '>=', start_date)) .where(filter=FieldFilter('created_at', '<=', end_date)) + .where(filter=FieldFilter('deleted', '==', False)) .where(filter=FieldFilter('discarded', '==', False)) .order_by('created_at', direction=firestore.Query.DESCENDING) ) @@ -85,7 +87,10 @@ def get_memories_by_id(uid, memory_ids): memories = [] for doc in docs: if doc.exists: - memories.append(doc.to_dict()) + data = doc.to_dict() + if data.get('deleted') or data.get('discarded'): + continue + memories.append(data) return memories @@ -117,11 +122,13 @@ def get_memory_transcripts_by_model(uid: str, memory_id: str): memory_ref = user_ref.collection('memories').document(memory_id) deepgram_ref = memory_ref.collection('deepgram_streaming') soniox_ref = memory_ref.collection('soniox_streaming') + speechmatics_ref = memory_ref.collection('speechmatics_streaming') whisperx_ref = memory_ref.collection('fal_whisperx') return { 'deepgram': list(sorted([doc.to_dict() for doc in deepgram_ref.stream()], key=lambda x: x['start'])), 'soniox': list(sorted([doc.to_dict() for doc in soniox_ref.stream()], key=lambda x: x['start'])), + 'speechmatics': list(sorted([doc.to_dict() for doc in speechmatics_ref.stream()], key=lambda x: x['start'])), 'whisperx': list(sorted([doc.to_dict() for doc in whisperx_ref.stream()], key=lambda x: x['start'])), } @@ -131,6 +138,11 @@ def update_memory_events(uid: str, memory_id: str, events: List[dict]): memory_ref = user_ref.collection('memories').document(memory_id) memory_ref.update({'structured.events': events}) +def update_memory_finished_at(uid: str, memory_id: str, finished_at: datetime): + user_ref = db.collection('users').document(uid) + memory_ref = user_ref.collection('memories').document(memory_id) + memory_ref.update({'finished_at': finished_at}) + # VISBILITY diff --git a/backend/database/processing_memories.py b/backend/database/processing_memories.py index 7d839e2f1..db1a4bd76 100644 --- a/backend/database/processing_memories.py +++ b/backend/database/processing_memories.py @@ -32,6 +32,10 @@ def get_processing_memories_by_id(uid, processing_memory_ids): memories.append(doc.to_dict()) return memories +def get_processing_memory_by_id(uid, processing_memory_id): + memory_ref = db.collection('users').document(uid).collection('processing_memories').document(processing_memory_id) + return memory_ref.get().to_dict() + def update_processing_memory_segments(uid: str, id: str, segments: List[dict]): user_ref = db.collection('users').document(uid) memory_ref = user_ref.collection('processing_memories').document(id) diff --git a/backend/database/vector_db.py b/backend/database/vector_db.py index 26bf92157..eb18ae053 100644 --- a/backend/database/vector_db.py +++ b/backend/database/vector_db.py @@ -42,7 +42,7 @@ def upsert_vectors( print('upsert_vectors', res) -def query_vectors(query: str, uid: str, starts_at: int = None, ends_at: int = None, k:int = 5) -> List[str]: +def query_vectors(query: str, uid: str, starts_at: int = None, ends_at: int = None, k: int = 5) -> List[str]: filter_data = {'uid': uid} if starts_at is not None: filter_data['created_at'] = {'$gte': starts_at, '$lte': ends_at} @@ -55,4 +55,6 @@ def query_vectors(query: str, uid: str, starts_at: int = None, ends_at: int = No def delete_vector(memory_id: str): - index.delete(ids=[memory_id], namespace="ns1") + # TODO: does this work? 
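+    # (a successful delete typically returns an empty response, so the print below
+    # only confirms the call completed, not that the id existed; hence the TODO)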
+    result = index.delete(ids=[memory_id], namespace="ns1")
+    print('delete_vector', result)
diff --git a/backend/main.py b/backend/main.py
index ab21cdb0a..c3fa83b06 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -6,8 +6,7 @@
 from modal import Image, App, asgi_app, Secret, Cron
 
 from routers import workflow, chat, firmware, plugins, memories, transcribe, notifications, speech_profile, \
-    agents, facts, users, postprocessing, processing_memories, trends,sdcard
-
+    agents, facts, users, postprocessing, processing_memories, trends, sdcard
 from utils.other.notifications import start_cron_job
 
 if os.environ.get('SERVICE_ACCOUNT_JSON'):
@@ -70,3 +69,30 @@ def api():
 @modal_app.function(image=image, schedule=Cron('* * * * *'))
 async def notifications_cronjob():
     await start_cron_job()
+
+
+@app.post('/webhook')
+async def webhook(data: dict):
+    diarization = data['output']['diarization']
+    joined = []
+    for speaker in diarization:
+        if not joined:
+            joined.append(speaker)
+        else:
+            if speaker['speaker'] == joined[-1]['speaker']:
+                joined[-1]['end'] = speaker['end']
+            else:
+                joined.append(speaker)
+
+    print(data['jobId'], json.dumps(joined))
+    # open scripts/stt/diarization.json, look up the memory id stored under this jobId,
+    # drop the jobId entry, and save the joined diarization under that memory id instead
+    with open('scripts/stt/diarization.json', 'r') as f:
+        diarization_data = json.loads(f.read())
+
+    memory_id = diarization_data.get(data['jobId'])
+    if memory_id:
+        diarization_data[memory_id] = joined
+        del diarization_data[data['jobId']]
+        with open('scripts/stt/diarization.json', 'w') as f:
+            json.dump(diarization_data, f, indent=2)
+    return 'ok'
diff --git a/backend/modal/speech_profile_modal.py b/backend/modal/speech_profile_modal.py
index 89ed38513..521b3be68 100644
--- a/backend/modal/speech_profile_modal.py
+++ b/backend/modal/speech_profile_modal.py
@@ -130,8 +130,8 @@ def endpoint(uid: str, audio_file: UploadFile = File(...), segments: str = Form(
     segments_data = json.loads(segments)
     transcript_segments = [TranscriptSegment(**segment) for segment in segments_data]
 
-    people = get_people_with_speech_samples(uid)
-
+    # people = get_people_with_speech_samples(uid)
+    people = []
     try:
         result = classify_segments(audio_file.filename, profile_path, people, transcript_segments)
         print(result)
diff --git a/backend/models/processing_memory.py b/backend/models/processing_memory.py
index 7eb1f6656..11d812068 100644
--- a/backend/models/processing_memory.py
+++ b/backend/models/processing_memory.py
@@ -14,6 +14,7 @@ class ProcessingMemory(BaseModel):
     audio_url: Optional[str] = None
     created_at: datetime
     timer_start: float
+    timer_segment_start: Optional[float] = None
    timer_starts: List[float] = []
     language: Optional[str] = None  # applies only to Friend # TODO: once released migrate db to default 'en'
     transcript_segments: List[TranscriptSegment] = []
@@ -23,6 +24,13 @@
     memory_id: Optional[str] = None
     message_ids: List[str] = []
 
+class BasicProcessingMemory(BaseModel):
+    id: str
+    timer_start: float
+    created_at: datetime
+    geolocation: Optional[Geolocation] = None
+    emotional_feedback: Optional[bool] = False
+
+
 class UpdateProcessingMemory(BaseModel):
     id: Optional[str] = None
@@ -31,4 +39,4 @@
 
 class UpdateProcessingMemoryResponse(BaseModel):
-    result: ProcessingMemory
+    result: BasicProcessingMemory
diff --git a/backend/models/transcript_segment.py b/backend/models/transcript_segment.py
index 6fbb99d0a..617de0dd2 100644
--- a/backend/models/transcript_segment.py
+++ b/backend/models/transcript_segment.py
@@ -42,6 +42,46 @@ def can_display_seconds(segments):
                 return False
         return True
 
+    @staticmethod
+    def combine_segments(segments: [], new_segments: [], delta_seconds: int = 0):
+        if not new_segments or len(new_segments) == 0:
+            return segments
+
+        joined_similar_segments = []
+        for new_segment in new_segments:
+            if delta_seconds > 0:
+                new_segment.start += delta_seconds
+                new_segment.end += delta_seconds
+
+            if (joined_similar_segments and
+                    (joined_similar_segments[-1].speaker == new_segment.speaker or
+                     (joined_similar_segments[-1].is_user and new_segment.is_user))):
+                joined_similar_segments[-1].text += f' {new_segment.text}'
+                joined_similar_segments[-1].end = new_segment.end
+            else:
+                joined_similar_segments.append(new_segment)
+
+        if (segments and
+                (segments[-1].speaker == joined_similar_segments[0].speaker or
+                 (segments[-1].is_user and joined_similar_segments[0].is_user)) and
+                (joined_similar_segments[0].start - segments[-1].end < 30)):
+            segments[-1].text += f' {joined_similar_segments[0].text}'
+            segments[-1].end = joined_similar_segments[0].end
+            joined_similar_segments.pop(0)
+
+        segments.extend(joined_similar_segments)
+
+        # Speechmatics specific issue with punctuation
+        for i, segment in enumerate(segments):
+            segments[i].text = (
+                segments[i].text.strip()
+                .replace('  ', ' ')
+                .replace(' ,', ',')
+                .replace(' .', '.')
+                .replace(' ?', '?')
+            )
+        return segments
+
 
 class ImprovedTranscriptSegment(BaseModel):
     speaker_id: int = Field(..., description='The correctly assigned speaker id')
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 91ea5daa1..8d86dd6e0 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -9,6 +9,7 @@ altair==5.4.0
 annotated-types==0.7.0
 antlr4-python3-runtime==4.9.3
 anyio==4.4.0
+assemblyai==0.33.0
 asteroid-filterbanks==0.4.0
 asttokens==2.4.1
 attrs==24.1.0
@@ -33,11 +34,14 @@ decorator==5.1.1
 deepgram-sdk==3.4.0
 deprecation==2.1.0
 distro==1.9.0
+dnspython==2.6.1
 docopt==0.6.2
 einops==0.8.0
+email_validator==2.2.0
 executing==2.0.1
 fal_client==0.4.1
 fastapi==0.112.0
+fastapi-cli==0.0.5
 fastapi-utilities==0.2.0
 filelock==3.15.4
 firebase-admin==6.5.0
@@ -60,12 +64,14 @@ googleapis-common-protos==1.63.2
 groq==0.9.0
 grpcio==1.65.4
 grpcio-status==1.62.3
+grpcio-tools==1.62.3
 grpclib==0.4.7
 h11==0.14.0
 h2==4.1.0
 hpack==4.0.0
 httpcore==1.0.5
 httplib2==0.22.0
+httptools==0.6.1
 httpx==0.27.0
 httpx-sse==0.4.0
 huggingface-hub==0.24.5
@@ -76,6 +82,7 @@ idna==3.7
 ipython==8.26.0
 jedi==0.19.1
 Jinja2==3.1.4
+jiwer==3.0.4
 joblib==1.4.2
 jsonpatch==1.33
 jsonpointer==3.0.0
@@ -105,6 +112,7 @@ matplotlib-inline==0.1.7
 mdurl==0.1.2
 modal==0.64.7
 monotonic==1.6
+more-itertools==10.5.0
 mplcursors==0.5.3
 mpld3==0.5.10
 mpmath==1.3.0
@@ -121,6 +129,7 @@ omegaconf==2.3.0
 onnxruntime==1.19.0
 openai==1.39.0
 optuna==3.6.1
+opuslib==3.0.1
 orjson==3.10.6
 packaging==24.1
 pandas==2.2.2
@@ -132,6 +141,7 @@ pinecone-plugin-inference==1.0.3
 pinecone-plugin-interface==0.0.7
 platformdirs==4.2.2
 plotly==5.23.0
+polling2==0.5.0
 pooch==1.8.2
 portalocker==2.10.1
 posthog==3.5.2
@@ -157,6 +167,7 @@ pydub==0.25.1
 Pygments==2.18.0
 PyJWT==2.9.0
 pynndescent==0.5.13
+PyOgg @ git+https://github.com/TeamPyOgg/PyOgg@6871a4f234e8a3a346c4874a12509bfa02c4c63a
 pyparsing==3.1.2
 python-dateutil==2.9.0.post0
 python-dotenv==1.0.1
@@ -166,6 +177,7 @@ pytorch-metric-learning==2.6.0
 pytz==2024.1
 PyYAML==6.0.1
 qdrant-client==1.11.0
+rapidfuzz==3.9.7
 redis==5.0.8
 referencing==0.35.1
 regex==2024.7.24
@@ -184,10 +196,13 @@ sigtools==4.0.1
six==1.16.0 smmap==5.0.1 sniffio==1.3.1 +soniox==1.10.1 sortedcontainers==2.4.0 +SoundCard==0.4.3 soundfile==0.12.1 soxr==0.4.0 speechbrain==1.0.0 +speechmatics-python==2.0.1 SQLAlchemy==2.0.32 stack-data==0.6.3 starlette==0.37.2 @@ -195,7 +210,7 @@ streamlit==1.37.1 sympy==1.13.1 synchronicity==0.6.7 tabulate==0.9.0 -tenacity==8.5.0 +tenacity==8.2.3 tensorboardX==2.6.2.2 threadpoolctl==3.5.0 tiktoken==0.7.0 @@ -220,9 +235,9 @@ umap-learn==0.5.6 uritemplate==4.1.1 urllib3==2.2.2 uvicorn==0.30.5 +uvloop==0.20.0 watchdog==4.0.2 watchfiles==0.22.0 wcwidth==0.2.13 websockets==12.0 yarl==1.9.4 -pyogg @ git+https://github.com/TeamPyOgg/PyOgg@6871a4f diff --git a/backend/routers/chat.py b/backend/routers/chat.py index 7b427fb36..089d560f2 100644 --- a/backend/routers/chat.py +++ b/backend/routers/chat.py @@ -2,10 +2,11 @@ from datetime import datetime, timezone from typing import List, Optional -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, HTTPException import database.chat as chat_db from models.chat import Message, SendMessageRequest, MessageSender +from utils.chat.chat import clear_user_chat_message from utils.llm import qa_rag, initial_chat_message from utils.other import endpoints as auth from utils.plugins import get_plugin_by_id @@ -29,7 +30,8 @@ def filter_messages(messages, plugin_id): def send_message( data: SendMessageRequest, plugin_id: Optional[str] = None, uid: str = Depends(auth.get_current_user_uid) ): - message = Message(id=str(uuid.uuid4()), text=data.text, created_at=datetime.now(timezone.utc), sender='human', type='text') + message = Message(id=str(uuid.uuid4()), text=data.text, created_at=datetime.now(timezone.utc), sender='human', + type='text') chat_db.add_message(uid, message.dict()) plugin = get_plugin_by_id(plugin_id) @@ -56,6 +58,14 @@ def send_message( return ai_message +@router.delete('/v1/clear-chat', tags=['chat'], response_model=Message) +def clear_chat(uid: str = Depends(auth.get_current_user_uid)): + err = clear_user_chat_message(uid) + if err: + raise HTTPException(status_code=500, detail='Failed to clear chat') + return initial_message_util(uid) + + def initial_message_util(uid: str, plugin_id: Optional[str] = None): plugin = get_plugin_by_id(plugin_id) text = initial_chat_message(uid, plugin) diff --git a/backend/routers/memories.py b/backend/routers/memories.py index 8599ac194..5d7913819 100644 --- a/backend/routers/memories.py +++ b/backend/routers/memories.py @@ -113,6 +113,7 @@ def get_memory_transcripts_by_models(memory_id: str, uid: str = Depends(auth.get @router.delete("/v1/memories/{memory_id}", status_code=204, tags=['memories']) def delete_memory(memory_id: str, uid: str = Depends(auth.get_current_user_uid)): + print('delete_memory', memory_id, uid) memories_db.delete_memory(uid, memory_id) delete_vector(memory_id) return {"status": "Ok"} diff --git a/backend/routers/postprocessing.py b/backend/routers/postprocessing.py index 7567764cb..55eac26e7 100644 --- a/backend/routers/postprocessing.py +++ b/backend/routers/postprocessing.py @@ -1,22 +1,9 @@ -import asyncio -import os -import threading -import time from fastapi import APIRouter, Depends, HTTPException, UploadFile -from pydub import AudioSegment -import database.memories as memories_db -from database.users import get_user_store_recording_permission from models.memory import * -from routers.memories import _get_memory_by_id -from utils.memories.process_memory import process_memory, process_user_emotion +from utils.memories.postprocess_memory import 
postprocess_memory as postprocess_memory_util from utils.other import endpoints as auth -from utils.other.storage import upload_postprocessing_audio, \ - delete_postprocessing_audio, upload_memory_recording -from utils.stt.pre_recorded import fal_whisperx, fal_postprocessing -from utils.stt.speech_profile import get_speech_profile_matching_predictions -from utils.stt.vad import vad_is_empty router = APIRouter() @@ -39,223 +26,17 @@ def postprocess_memory( TODO: should consider storing non beautified segments, and beautify on read? TODO: post llm process here would be great, sometimes whisper x outputs without punctuation """ - memory_data = _get_memory_by_id(uid, memory_id) - memory = Memory(**memory_data) - if memory.discarded: - print('postprocess_memory: Memory is discarded') - raise HTTPException(status_code=400, detail="Memory is discarded") - if memory.postprocessing is not None and memory.postprocessing.status != PostProcessingStatus.not_started: - print(f'postprocess_memory: Memory can\'t be post-processed again {memory.postprocessing.status}') - raise HTTPException(status_code=400, detail="Memory can't be post-processed again") + # TODO: this pipeline vs groq+pyannote diarization 3.1, probably the latter is better. + # Save file file_path = f"_temp/{memory_id}_{file.filename}" with open(file_path, 'wb') as f: f.write(file.file.read()) - aseg = AudioSegment.from_wav(file_path) - if aseg.duration_seconds < 10: # TODO: validate duration more accurately, segment.last.end - segment.first.start - 10 - # TODO: fix app, sometimes audio uploaded is wrong, is too short. - print('postprocess_memory: Audio duration is too short, seems wrong.') - memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.canceled) - raise HTTPException(status_code=500, detail="Audio duration is too short, seems wrong.") + # Process + status_code, result = postprocess_memory_util(memory_id=memory_id, uid=uid, file_path=file_path, emotional_feedback=emotional_feedback, streaming_model="deepgram_streaming") + if status_code != 200: + raise HTTPException(status_code=status_code, detail=result) - memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.in_progress) - - try: - # Calling VAD to avoid processing empty parts and getting hallucinations from whisper. 
- vad_segments = vad_is_empty(file_path, return_segments=True) - if vad_segments: - start = vad_segments[0]['start'] - end = vad_segments[-1]['end'] - aseg = AudioSegment.from_wav(file_path) - aseg = aseg[max(0, (start - 1) * 1000):min((end + 1) * 1000, aseg.duration_seconds * 1000)] - aseg.export(file_path, format="wav") - except Exception as e: - print(e) - - try: - aseg = AudioSegment.from_wav(file_path) - signed_url = upload_postprocessing_audio(file_path) - threading.Thread(target=_delete_postprocessing_audio, args=(file_path,)).start() - - if aseg.frame_rate == 16000 and get_user_store_recording_permission(uid): - upload_memory_recording(file_path, uid, memory_id) - - speakers_count = len(set([segment.speaker for segment in memory.transcript_segments])) - words = fal_whisperx(signed_url, speakers_count) - fal_segments = fal_postprocessing(words, aseg.duration_seconds) - - # if new transcript is 90% shorter than the original, cancel post-processing, smth wrong with audio or FAL - count = len(''.join([segment.text.strip() for segment in memory.transcript_segments])) - new_count = len(''.join([segment.text.strip() for segment in fal_segments])) - print('Prev characters count:', count, 'New characters count:', new_count) - - fal_failed = not fal_segments or new_count < (count * 0.85) - - if fal_failed: - _handle_segment_embedding_matching(uid, file_path, memory.transcript_segments, aseg) - else: - _handle_segment_embedding_matching(uid, file_path, fal_segments, aseg) - - # Store both models results. - memories_db.store_model_segments_result(uid, memory.id, 'deepgram_streaming', memory.transcript_segments) - memories_db.store_model_segments_result(uid, memory.id, 'fal_whisperx', fal_segments) - - if not fal_failed: - memory.transcript_segments = fal_segments - - memories_db.upsert_memory(uid, memory.dict()) # Store transcript segments at least if smth fails later - if fal_failed: - # TODO: FAL fails too much and is fucking expensive. Remove it. - fail_reason = 'FAL empty segments' if not fal_segments else f'FAL transcript too short ({new_count} vs {count})' - memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.failed, fail_reason=fail_reason) - memory.postprocessing = MemoryPostProcessing( - status=PostProcessingStatus.failed, model=PostProcessingModel.fal_whisperx, fail_reason=fail_reason, - ) - # TODO: consider doing process_memory, if any segment still matched to user or people - return memory - - # Reprocess memory with improved transcription - result: Memory = process_memory(uid, memory.language, memory, force_process=True) - - # Process users emotion, async - if emotional_feedback: - asyncio.run(_process_user_emotion(uid, memory.language, memory, [signed_url])) - except Exception as e: - print(e) - memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.failed, fail_reason=str(e)) - raise HTTPException(status_code=500, detail=str(e)) - - memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.completed) - result.postprocessing = MemoryPostProcessing( - status=PostProcessingStatus.completed, model=PostProcessingModel.fal_whisperx) return result - - -# TODO: Move to util -def postprocess_memory_util(memory_id: str, file_path: str, uid: str, emotional_feedback: bool, streaming_model: str): - """ - The objective of this endpoint, is to get the best possible transcript from the audio file. - Instead of storing the initial deepgram result, doing a full post-processing with whisper-x. 
- This increases the quality of transcript by at least 20%. - Which also includes a better summarization. - Which helps us create better vectors for the memory. - And improves the overall experience of the user. - - TODO: Try Nvidia Nemo ASR as suggested by @jhonnycombs https://huggingface.co/spaces/hf-audio/open_asr_leaderboard - That + pyannote diarization 3.1, is as good as it gets. Then is only hardware improvements. - TODO: should consider storing non beautified segments, and beautify on read? - TODO: post llm process here would be great, sometimes whisper x outputs without punctuation - """ - memory_data = _get_memory_by_id(uid, memory_id) - memory = Memory(**memory_data) - if memory.discarded: - print('postprocess_memory: Memory is discarded') - return 400, "Memory is discarded" - - if memory.postprocessing is not None and memory.postprocessing.status != PostProcessingStatus.not_started: - print(f'postprocess_memory: Memory can\'t be post-processed again {memory.postprocessing.status}') - return 400, "Memory can't be post-processed again" - - aseg = AudioSegment.from_wav(file_path) - if aseg.duration_seconds < 10: # TODO: validate duration more accurately, segment.last.end - segment.first.start - 10 - # TODO: fix app, sometimes audio uploaded is wrong, is too short. - print('postprocess_memory: Audio duration is too short, seems wrong.') - memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.canceled) - return (500, "Audio duration is too short, seems wrong.") - - memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.in_progress) - - try: - # Calling VAD to avoid processing empty parts and getting hallucinations from whisper. - vad_segments = vad_is_empty(file_path, return_segments=True) - if vad_segments: - start = vad_segments[0]['start'] - end = vad_segments[-1]['end'] - aseg = AudioSegment.from_wav(file_path) - aseg = aseg[max(0, (start - 1) * 1000):min((end + 1) * 1000, aseg.duration_seconds * 1000)] - aseg.export(file_path, format="wav") - except Exception as e: - print(e) - - try: - aseg = AudioSegment.from_wav(file_path) - signed_url = upload_postprocessing_audio(file_path) - threading.Thread(target=_delete_postprocessing_audio, args=(file_path,)).start() - - if aseg.frame_rate == 16000 and get_user_store_recording_permission(uid): - upload_memory_recording(file_path, uid, memory_id) - - speakers_count = len(set([segment.speaker for segment in memory.transcript_segments])) - words = fal_whisperx(signed_url, speakers_count) - fal_segments = fal_postprocessing(words, aseg.duration_seconds) - - # if new transcript is 90% shorter than the original, cancel post-processing, smth wrong with audio or FAL - count = len(''.join([segment.text.strip() for segment in memory.transcript_segments])) - new_count = len(''.join([segment.text.strip() for segment in fal_segments])) - print('Prev characters count:', count, 'New characters count:', new_count) - - fal_failed = not fal_segments or new_count < (count * 0.85) - - if fal_failed: - _handle_segment_embedding_matching(uid, file_path, memory.transcript_segments, aseg) - else: - _handle_segment_embedding_matching(uid, file_path, fal_segments, aseg) - - # Store both models results. 
- memories_db.store_model_segments_result(uid, memory.id, streaming_model, memory.transcript_segments) - memories_db.store_model_segments_result(uid, memory.id, 'fal_whisperx', fal_segments) - - if not fal_failed: - memory.transcript_segments = fal_segments - - memories_db.upsert_memory(uid, memory.dict()) # Store transcript segments at least if smth fails later - if fal_failed: - # TODO: FAL fails too much and is fucking expensive. Remove it. - fail_reason = 'FAL empty segments' if not fal_segments else f'FAL transcript too short ({new_count} vs {count})' - memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.failed, fail_reason=fail_reason) - memory.postprocessing = MemoryPostProcessing( - status=PostProcessingStatus.failed, model=PostProcessingModel.fal_whisperx, fail_reason=fail_reason, - ) - # TODO: consider doing process_memory, if any segment still matched to user or people - return (200, memory) - - # Reprocess memory with improved transcription - result: Memory = process_memory(uid, memory.language, memory, force_process=True) - - # Process users emotion, async - if emotional_feedback: - asyncio.run(_process_user_emotion(uid, memory.language, memory, [signed_url])) - except Exception as e: - print(e) - memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.failed, fail_reason=str(e)) - return 500, str(e) - - memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.completed) - result.postprocessing = MemoryPostProcessing( - status=PostProcessingStatus.completed, model=PostProcessingModel.fal_whisperx) - - return 200, result - - -def _delete_postprocessing_audio(file_path): - time.sleep(300) # 5 min - delete_postprocessing_audio(file_path) - os.remove(file_path) - - -async def _process_user_emotion(uid: str, language_code: str, memory: Memory, urls: [str]): - if not any(segment.is_user for segment in memory.transcript_segments): - print(f"_process_user_emotion skipped for {memory.id}") - return - - process_user_emotion(uid, language_code, memory, urls) - - -def _handle_segment_embedding_matching(uid: str, file_path: str, segments: List[TranscriptSegment], aseg: AudioSegment): - if aseg.frame_rate == 16000: - matches = get_speech_profile_matching_predictions(uid, file_path, [s.dict() for s in segments]) - for i, segment in enumerate(segments): - segment.is_user = matches[i]['is_user'] - segment.person_id = matches[i].get('person_id') diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 3d582c5b4..ddeb68f06 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -1,3 +1,4 @@ +import math import threading import uuid from datetime import datetime, timezone @@ -7,16 +8,16 @@ import requests from fastapi import APIRouter from fastapi.websockets import WebSocketDisconnect, WebSocket +from pydub import AudioSegment from starlette.websockets import WebSocketState import database.memories as memories_db import database.processing_memories as processing_memories_db -from database.redis_db import get_user_speech_profile from models.memory import Memory, TranscriptSegment from models.message_event import NewMemoryCreated, MessageEvent, NewProcessingMemoryCreated from models.processing_memory import ProcessingMemory -from routers.postprocessing import postprocess_memory_util from utils.audio import create_wav_from_bytes, merge_wav_files +from utils.memories.postprocess_memory import postprocess_memory as postprocess_memory_util from utils.memories.process_memory import 
process_memory from utils.other.storage import upload_postprocessing_audio from utils.processing_memories import create_memory_by_processing_memory @@ -67,40 +68,22 @@ # def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: # -def _combine_segments(segments: [], new_segments: [], delta_seconds: int = 0): - if not new_segments or len(new_segments) == 0: - return segments - - joined_similar_segments = [] - for new_segment in new_segments: - if delta_seconds > 0: - new_segment.start += delta_seconds - new_segment.end += delta_seconds - - if (joined_similar_segments and - (joined_similar_segments[-1].speaker == new_segment.speaker or - (joined_similar_segments[-1].is_user and new_segment.is_user))): - joined_similar_segments[-1].text += f' {new_segment.text}' - joined_similar_segments[-1].end = new_segment.end - else: - joined_similar_segments.append(new_segment) - - if (segments and - (segments[-1].speaker == joined_similar_segments[0].speaker or - (segments[-1].is_user and joined_similar_segments[0].is_user)) and - (joined_similar_segments[0].start - segments[-1].end < 30)): - segments[-1].text += f' {joined_similar_segments[0].text}' - segments[-1].end = joined_similar_segments[0].end - joined_similar_segments.pop(0) - - segments.extend(joined_similar_segments) - - return segments - class STTService(str, Enum): deepgram = "deepgram" soniox = "soniox" + speechmatics = "speechmatics" + + # auto = "auto" + + @staticmethod + def get_model_name(value): + if value == STTService.deepgram: + return 'deepgram_streaming' + elif value == STTService.soniox: + return 'soniox_streaming' + elif value == STTService.speechmatics: + return 'speechmatics_streaming' async def _websocket_util( @@ -108,13 +91,18 @@ async def _websocket_util( channels: int = 1, include_speech_profile: bool = True, new_memory_watch: bool = False, stt_service: STTService = STTService.deepgram, ): - print('websocket_endpoint', uid, language, sample_rate, codec, channels, include_speech_profile, new_memory_watch) + print('websocket_endpoint', uid, language, sample_rate, codec, channels, include_speech_profile, new_memory_watch, + stt_service) + + if stt_service == STTService.soniox and language not in soniox_valid_languages: + stt_service = STTService.deepgram # defaults to deepgram - if stt_service == STTService.soniox and ( - sample_rate != 16000 or codec != 'opus' or language not in soniox_valid_languages): + if stt_service == STTService.speechmatics: # defaults to deepgram (no credits + 10 connections max limit) stt_service = STTService.deepgram - # Check: Why do we need try-catch around websocket.accept? + # TODO: if language english, use soniox + # TODO: else deepgram, if speechmatics credits, prob this for both? 
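+    # A sketch of that routing, assuming the Speechmatics credit/connection limits get resolved:
+    #   stt_service = STTService.soniox if language == 'en' else STTService.deepgram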
+
     try:
         await websocket.accept()
     except RuntimeError as e:
@@ -141,8 +129,8 @@ async def _websocket_util(
     speech_profile_stream_id = 2
     loop = asyncio.get_event_loop()
 
-    # Soft timeout
-    timeout_seconds = 420  # 7m ~ MODAL_TIME_OUT - 3m
+    # Soft timeout, should be < MODAL_TIME_OUT - 3m
+    timeout_seconds = 420  # 7m
     started_at = time.time()
 
     def stream_transcript(segments, stream_id):
@@ -150,10 +138,27 @@ def stream_transcript(segments, stream_id):
         nonlocal processing_memory
         nonlocal processing_memory_synced
         nonlocal memory_transcript_segements
+        nonlocal segment_start
+        nonlocal segment_end
 
         if not segments or len(segments) == 0:
             return
 
+        # Align the start, end segment
+        if not segment_start:
+            start = segments[0]["start"]
+            segment_start = start
+
+        # end
+        end = segments[-1]["end"]
+        if not segment_end or segment_end < end:
+            segment_end = end
+
+        for i, segment in enumerate(segments):
+            segment["start"] -= segment_start
+            segment["end"] -= segment_start
+            segments[i] = segment
+
         asyncio.run_coroutine_threadsafe(websocket.send_json(segments), loop)
         threading.Thread(target=process_segments, args=(uid, segments)).start()
@@ -163,7 +168,7 @@ def stream_transcript(segments, stream_id):
         delta_seconds = 0
         if processing_memory and processing_memory.timer_start > 0:
             delta_seconds = timer_start - processing_memory.timer_start
-        memory_transcript_segements = _combine_segments(
+        memory_transcript_segements = TranscriptSegment.combine_segments(
             memory_transcript_segements, list(map(lambda m: TranscriptSegment(**m), segments)), delta_seconds
         )
@@ -182,41 +187,54 @@ def stream_audio(audio_buffer):
         processing_audio_frames.append(audio_buffer)
 
     soniox_socket = None
+    speechmatics_socket = None
     deepgram_socket = None
     deepgram_socket2 = None
 
     websocket_active = True
     websocket_close_code = 1001  # Going Away, don't close with good from backend
     timer_start = None
+    segment_start = None
+    segment_end = None
+    audio_frames_per_sec = 100
     # audio_buffer = None
     duration = 0
     try:
-        # Soniox
+        file_path, duration = None, 0
+        # TODO: how well does the speech profile work for languages other than English?
+        if language == 'en' and (codec == 'opus' or codec == 'pcm16') and include_speech_profile:
+            file_path = get_profile_audio_if_exists(uid)
+            print(f'deepgram-obns3: file_path {file_path}')
+            duration = AudioSegment.from_wav(file_path).duration_seconds + 5 if file_path else 0
+
+        # DEEPGRAM
         if stt_service == STTService.deepgram:
-            if language == 'en' and codec == 'opus' and include_speech_profile:
-                second_per_frame: float = 0.01
-                speech_profile = get_user_speech_profile(uid)
-                duration = len(speech_profile) * second_per_frame
-                print('speech_profile', len(speech_profile), duration)
-                if duration:
-                    duration += 10
-            else:
-                speech_profile, duration = [], 0
-
             deepgram_socket = await process_audio_dg(
-                stream_transcript, memory_stream_id, language, sample_rate, codec, channels, preseconds=duration
+                stream_transcript, memory_stream_id, language, sample_rate, channels, preseconds=duration
             )
             if duration:
                 deepgram_socket2 = await process_audio_dg(
-                    stream_transcript, speech_profile_stream_id, language, sample_rate, codec, channels
+                    stream_transcript, speech_profile_stream_id, language, sample_rate, channels
                 )
-                await send_initial_file(speech_profile, deepgram_socket)
+                print(f'deepgram-obns3: send_initial_file_path > deepgram_socket {deepgram_socket}')
+                async def deepgram_socket_send(data):
+                    return deepgram_socket.send(data)
+                await send_initial_file_path(file_path, deepgram_socket_send)
+        # SONIOX
         elif stt_service == STTService.soniox:
soniox_socket = await process_audio_soniox( - stream_transcript, speech_profile_stream_id, language, uid if include_speech_profile else None + stream_transcript, speech_profile_stream_id, sample_rate, language, + uid if include_speech_profile else None ) - + # SPEECHMATICS + elif stt_service == STTService.speechmatics: + speechmatics_socket = await process_audio_speechmatics( + stream_transcript, speech_profile_stream_id, sample_rate, language, preseconds=duration + ) + if duration: + await send_initial_file_path(file_path, speechmatics_socket.send) + print('speech_profile speechmatics duration', duration) except Exception as e: print(f"Initial processing error: {e}") @@ -229,7 +247,7 @@ def stream_audio(audio_buffer): decoder = opuslib.Decoder(sample_rate, channels) - async def receive_audio(dg_socket1, dg_socket2, soniox_socket): + async def receive_audio(dg_socket1, dg_socket2, soniox_socket, speechmatics_socket1): nonlocal websocket_active nonlocal websocket_close_code nonlocal timer_start @@ -246,12 +264,17 @@ async def receive_audio(dg_socket1, dg_socket2, soniox_socket): # audio_file = open(path, "a") try: while websocket_active: - data = await websocket.receive_bytes() + raw_data = await websocket.receive_bytes() + data = raw_data[:] # audio_buffer.extend(data) + if codec == 'opus' and sample_rate == 16000: + data = decoder.decode(bytes(data), frame_size=160) if soniox_socket is not None: - decoded_opus = decoder.decode(bytes(data), frame_size=160) - await soniox_socket.send(decoded_opus) + await soniox_socket.send(data) + + if speechmatics_socket1 is not None: + await speechmatics_socket1.send(data) if deepgram_socket is not None: elapsed_seconds = time.time() - timer_start @@ -265,7 +288,7 @@ async def receive_audio(dg_socket1, dg_socket2, soniox_socket): dg_socket2.send(data) # stream - stream_audio(data) + stream_audio(raw_data) # audio_buffer = bytearray() @@ -273,6 +296,8 @@ async def receive_audio(dg_socket1, dg_socket2, soniox_socket): print("WebSocket disconnected") except Exception as e: print(f'Could not process audio: error {e}') + print(f'deepgram-obns3: receive_audio > dg_socket1 {dg_socket1}') + print(f'deepgram-obns3: receive_audio > dg_socket2 {dg_socket2}') websocket_close_code = 1011 finally: websocket_active = False @@ -282,6 +307,8 @@ async def receive_audio(dg_socket1, dg_socket2, soniox_socket): dg_socket2.finish() if soniox_socket: await soniox_socket.close() + if speechmatics_socket: + await speechmatics_socket.close() # heart beat async def send_heartbeat(): @@ -332,10 +359,11 @@ async def _create_processing_memory(): last_processing_memory_data = processing_memories_db.get_last(uid) if last_processing_memory_data: last_processing_memory = ProcessingMemory(**last_processing_memory_data) - segment_end = 0 + last_segment_end = 0 for segment in last_processing_memory.transcript_segments: - segment_end = max(segment_end, segment.end) - if last_processing_memory.timer_start + segment_end + min_seconds_limit > time.time(): + last_segment_end = max(last_segment_end, segment.end) + timer_segment_start = last_processing_memory.timer_segment_start if last_processing_memory.timer_segment_start else last_processing_memory.timer_start + if timer_segment_start + last_segment_end + min_seconds_limit > time.time(): processing_memory = last_processing_memory # Or create new @@ -344,6 +372,7 @@ async def _create_processing_memory(): id=str(uuid.uuid4()), created_at=datetime.now(timezone.utc), timer_start=timer_start, + timer_segment_start=timer_start + segment_start, 
language=language, ) @@ -355,7 +384,7 @@ async def _create_processing_memory(): processing_memory.timer_starts.append(timer_start) # Transcript with delta - memory_transcript_segements = _combine_segments( + memory_transcript_segements = TranscriptSegment.combine_segments( processing_memory.transcript_segments, memory_transcript_segements, timer_start - processing_memory.timer_start ) @@ -376,11 +405,22 @@ async def _post_process_memory(memory: Memory): nonlocal processing_memory nonlocal processing_audio_frames nonlocal processing_audio_frame_synced + nonlocal segment_start + nonlocal segment_end # Create wav processing_audio_frame_synced = len(processing_audio_frames) + + # Remove audio frames [start, end] + left = 0 + if segment_start: + left = max(left, math.floor(segment_start) * audio_frames_per_sec) + right = processing_audio_frame_synced + if segment_end: + right = min(math.ceil(segment_end) * audio_frames_per_sec, right) + file_path = f"_temp/{memory.id}_{uuid.uuid4()}_be" - create_wav_from_bytes(file_path=file_path, frames=processing_audio_frames[:processing_audio_frame_synced], + create_wav_from_bytes(file_path=file_path, frames=processing_audio_frames[left:right], frame_rate=sample_rate, channels=channels, codec=codec, ) # Try merge new audio with the previous @@ -394,7 +434,9 @@ async def _post_process_memory(memory: Memory): # merge merge_file_path = f"_temp/{memory.id}_{uuid.uuid4()}_be" - merge_wav_files(merge_file_path, [previous_file_path, file_path]) + nearest_timer_start = processing_memory.timer_starts[-2] + merge_wav_files(merge_file_path, [previous_file_path, file_path], + [math.ceil(timer_start - nearest_timer_start), 0]) # clean os.remove(previous_file_path) @@ -411,8 +453,7 @@ async def _post_process_memory(memory: Memory): # Process emotional_feedback = processing_memory.emotional_feedback status, new_memory = postprocess_memory_util( - memory.id, file_path, uid, emotional_feedback, - 'deepgram_streaming' if stt_service == STTService.deepgram else 'soniox_streaming' + memory.id, file_path, uid, emotional_feedback, STTService.get_model_name(stt_service) ) if status == 200: memory = new_memory @@ -442,11 +483,11 @@ async def _create_memory(): await _create_processing_memory() else: # or ensure synced processing transcript - processing_memories = processing_memories_db.get_processing_memories_by_id(uid, [processing_memory.id]) - if len(processing_memories) == 0: + processing_memory_data = processing_memories_db.get_processing_memory_by_id(uid, processing_memory.id) + if not processing_memory_data: print("processing memory is not found") return - processing_memory = ProcessingMemory(**processing_memories[0]) + processing_memory = ProcessingMemory(**processing_memory_data) processing_memory_synced = len(memory_transcript_segements) processing_memory.transcript_segments = memory_transcript_segements[:processing_memory_synced] @@ -464,8 +505,8 @@ async def _create_memory(): memory = None messages = [] if not processing_memory.memory_id: - (new_memory, new_messages, updated_processing_memory) = await create_memory_by_processing_memory(uid, - processing_memory.id) + new_memory, new_messages, updated_processing_memory = await create_memory_by_processing_memory( + uid, processing_memory.id) if not new_memory: print("Can not create new memory") @@ -491,6 +532,11 @@ async def _create_memory(): memories_db.update_memory_segments(uid, memory.id, [segment.dict() for segment in memory.transcript_segments]) + # Update finished at + memory.finished_at = datetime.fromtimestamp( + 
memory.started_at.timestamp() + processing_memory.transcript_segments[-1].end, timezone.utc) + memories_db.update_memory_finished_at(uid, memory.id, memory.finished_at) + # Process memory = process_memory(uid, memory.language, memory, force_process=True) @@ -527,6 +573,8 @@ async def _try_flush_new_memory_with_lock(time_validate: bool = True): async def _try_flush_new_memory(time_validate: bool = True): nonlocal memory_transcript_segements nonlocal timer_start + nonlocal segment_start + nonlocal segment_end nonlocal processing_memory nonlocal processing_memory_synced nonlocal processing_audio_frames @@ -537,13 +585,8 @@ async def _try_flush_new_memory(time_validate: bool = True): return # Validate last segment - last_segment = None - if len(memory_transcript_segements) > 0: - last_segment = memory_transcript_segements[-1] - if not last_segment: + if not segment_end: print("Not last segment or last segment invalid") - if last_segment: - print(f"{last_segment.dict()}") return # First chunk, create processing memory @@ -554,7 +597,6 @@ async def _try_flush_new_memory(time_validate: bool = True): # Validate transcript # Longer 120s - segment_end = last_segment.end now = time.time() should_create_memory_time = True if time_validate: @@ -588,7 +630,8 @@ async def _try_flush_new_memory(time_validate: bool = True): processing_memory = None try: - receive_task = asyncio.create_task(receive_audio(deepgram_socket, deepgram_socket2, soniox_socket)) + receive_task = asyncio.create_task( + receive_audio(deepgram_socket, deepgram_socket2, soniox_socket, speechmatics_socket)) heartbeat_task = asyncio.create_task(send_heartbeat()) # Run task diff --git a/backend/scripts/stt/k_compare_transcripts_performance.py b/backend/scripts/stt/k_compare_transcripts_performance.py index 7f26a6914..61c06888a 100644 --- a/backend/scripts/stt/k_compare_transcripts_performance.py +++ b/backend/scripts/stt/k_compare_transcripts_performance.py @@ -1,17 +1,581 @@ # STEPS -# - get all users -# - get all memories non discarded -# - filter memories with audio recording available -# - filter again by ones that have whisperx + deepgram segments or soniox segments -# - store local json with data - -# - P2 +# - Download all files. 
+# - get each of those memories # - read local json with each memory audio file # - call whisper groq (whisper-largev3) # - Create a table df, well printed, with each transcript result side by side # - prompt for computing WER using groq whisper as baseline (if better, but most likely) # - Run for deepgram vs soniox, and generate comparison result - +import asyncio # - P3 # - Include speechmatics to the game +import json +import os +import re +from collections import defaultdict +from itertools import islice +from typing import Dict, List + +import firebase_admin +import requests +from dotenv import load_dotenv +from pydub import AudioSegment +from tabulate import tabulate + +load_dotenv('../../.dev.env') +os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '../../' + os.getenv('GOOGLE_APPLICATION_CREDENTIALS') +firebase_admin.initialize_app() + +from models.transcript_segment import TranscriptSegment +from utils.stt.streaming import process_audio_dg, process_audio_soniox, process_audio_speechmatics +from groq import Groq + +from utils.other.storage import upload_postprocessing_audio +from utils.stt.pre_recorded import fal_whisperx, fal_postprocessing + + +def add_model_result_segments(model: str, new_segments: List[Dict], result: Dict): + segments = [TranscriptSegment(**s) for s in result[model]] + new_segments = [TranscriptSegment(**s) for s in new_segments] + segments = TranscriptSegment.combine_segments(segments, new_segments) + result[model] = [s.dict() for s in segments] + + +def execute_groq(file_path: str): + file_size = os.path.getsize(file_path) + print('execute_groq file_size', file_size / 1024 / 1024, 'MB') + split_files = [] + if file_size / 1024 / 1024 > 25: + # split file + aseg = AudioSegment.from_wav(file_path) + # split every 10 minutes + split_duration = 10 * 60 * 1000 + for i in range(0, len(aseg), split_duration): + split_file_path = f'{file_path}_{i}.wav' + split_files.append(split_file_path) + aseg[i:i + split_duration].export(split_file_path, format="wav") + else: + split_files.append(file_path) + + client = Groq(api_key=os.getenv('GROQ_API_KEY')) + result = '' + for file_path in split_files: + with open(file_path, "rb") as file: + transcription = client.audio.transcriptions.create( + file=(file_path, file.read()), + model="whisper-large-v3", + response_format="text", + language="en", + temperature=0.0 + ) + result += ' ' + str(transcription) + return result.strip().lower().replace(' ', ' ') + + +async def _execute_single(file_path: str): + aseg = AudioSegment.from_wav(file_path) + duration = aseg.duration_seconds + memory_id = file_path.split('/')[-1].split('.')[0] + + if os.path.exists(f'results/{memory_id}.json'): + print('Already processed', memory_id) + return + if aseg.duration_seconds < 5: + print('Skipping', memory_id, 'duration', aseg.duration_seconds) + return + + print('Started processing', memory_id, 'duration', aseg.duration_seconds) + result = { + 'deepgram': [], + 'soniox': [], + 'speechmatics': [] + } + + def stream_transcript_deepgram(new_segments, _): + print('stream_transcript_deepgram', new_segments) + add_model_result_segments('deepgram', new_segments, result) + + def stream_transcript_soniox(new_segments, _): + print('stream_transcript_soniox', new_segments) + add_model_result_segments('soniox', new_segments, result) + + def stream_transcript_speechmatics(new_segments, _): + print('stream_transcript_speechmatics', new_segments) + add_model_result_segments('speechmatics', new_segments, result) + + # streaming models + socket = await 
process_audio_dg(stream_transcript_deepgram, '1', 'en', 16000, 1, 0) + socket_soniox = await process_audio_soniox(stream_transcript_soniox, '1', 16000, 'en', None) + socket_speechmatics = await process_audio_speechmatics(stream_transcript_speechmatics, '1', 16000, 'en', 0) + print('duration', duration) + with open(file_path, "rb") as file: + while True: + chunk = file.read(320) + if not chunk: + break + socket.send(bytes(chunk)) + await socket_soniox.send(bytes(chunk)) + await socket_speechmatics.send(bytes(chunk)) + await asyncio.sleep(0.005) + + print('Finished sending audio') + groq_result: str = execute_groq(file_path) # source of truth + result['whisper-large-v3'] = groq_result + + # whisperx + try: + signed_url = upload_postprocessing_audio(file_path) + words = fal_whisperx(signed_url) + fal_segments = fal_postprocessing(words, duration) + result['fal_whisperx'] = [s.dict() for s in fal_segments] + except Exception as e: + print('fal_whisperx', e) + result['fal_whisperx'] = [] + + print('Waiting for sockets to finish', min(30, duration), 'seconds') + await asyncio.sleep(min(30, duration)) + + os.makedirs('results', exist_ok=True) + with open(f'results/{memory_id}.json', 'w') as f: + json.dump(result, f, indent=2) + + socket.finish() + await socket_soniox.close() + await socket_speechmatics.close() + + +def batched(iterable, n): + """ + Generator that yields lists of size 'n' from 'iterable'. + """ + it = iter(iterable) + while True: + batch = list(islice(it, n)) + if not batch: + break + yield batch + + +async def process_memories_audio_files(): + uids = os.listdir('_temp2') + for uid in uids: + memories = os.listdir(f'_temp2/{uid}') + memories = [f'_temp2/{uid}/{memory}' for memory in memories] + # batch_size = 5 + for memory in memories: + await _execute_single(memory) + # for batch_num, batch in enumerate(batched(memories, batch_size), start=1): + # tasks = [asyncio.create_task(_execute_single(file_path)) for file_path in batch] + # await asyncio.gather(*tasks) + # print(f'Batch {batch_num} processed') + + +from jiwer import wer + + +def compute_wer(): + """ + Computes the Word Error Rate (WER) for each transcription model against a reference model + across all JSON files in the specified directory. Outputs detailed results and overall rankings.
+ """ + dir_path = 'results/' # Directory containing JSON files + reference_model = 'whisper-large-v3' # Reference model key + table_data = [] # List to hold detailed table rows + wer_accumulator = defaultdict(list) # To accumulate WERs per model + points_counter = defaultdict(int) # To count points per model based on WER rankings + + # Define detailed table headers + detailed_headers = [ + "File", + "Model", + "WER", + "Source Words", + "Model Words", + "Source Characters", + "Model Characters", + "Transcript" + ] + + # Check if the directory exists + if not os.path.isdir(dir_path): + print(f"Directory '{dir_path}' does not exist.") + return + + # Iterate through all files in the specified directory + for file in os.listdir(dir_path): + if not file.endswith('.json'): + continue # Skip non-JSON files + + file_path = os.path.join(dir_path, file) + with open(file_path, 'r', encoding='utf-8') as f: + try: + result = json.load(f) + except json.JSONDecodeError: + print(f"Error decoding JSON in file: {file}") + continue # Skip files with invalid JSON + + # Check if the reference model exists in the JSON + if reference_model not in result: + print(f"Reference model '{reference_model}' not found in file: {file}") + continue # Skip files without the reference model + + # Assemble the reference transcript + reference_text = regex_fix(result.get(reference_model, '')) + if isinstance(reference_text, list): + # If reference_text is a list of segments + reference_text = ' '.join([segment.get('text', '') for segment in reference_text]).strip().lower() + else: + # If reference_text is a single string + reference_text = str(reference_text).strip().lower() + reference_text = ' '.join(reference_text.split()) # Normalize whitespace + + # Calculate source words and characters + source_words = len(reference_text.split()) + source_characters = len(reference_text) + + print(f"Processing file: {file}") + + # Temporary storage for current file's model WERs to determine ranking points + current_file_wer = {} + + # Iterate through each model in the JSON + for model, segments in result.items(): + if model == reference_model: + model_text = reference_text # Reference model's transcript + else: + if isinstance(segments, list): + # Assemble the model's transcript from segments + model_text = ' '.join([segment.get('text', '') for segment in segments]).strip().lower() + else: + # If segments is a single string + model_text = str(segments).strip().lower() + model_text = ' '.join(model_text.split()) # Normalize whitespace + + # Calculate model words and characters + model_words = len(model_text.split()) + model_characters = len(model_text) + + # Compute WER + current_wer = wer(reference_text, model_text) + + # Accumulate WER for overall statistics (exclude reference model) + if model != reference_model: + wer_accumulator[model].append(current_wer) + + # Store WER for current file's ranking + if model != reference_model: + current_file_wer[model] = current_wer + + # Append the data to the detailed table + table_data.append([ + file, + model, + f"{current_wer:.2%}", + source_words, + model_words, + source_characters, + model_characters, + model_text + ]) + + # Determine which model(s) had the lowest WER in the current file + if current_file_wer: + min_wer = min(current_file_wer.values()) + best_models = [model for model, w in current_file_wer.items() if w == min_wer] + for model in best_models: + points_counter[model] += 1 # Assign 1 point to each best model + + print('-----------------------------------------') + + # Generate 
the detailed WER table using tabulate + if table_data: + print("\nDetailed WER Results:") + detailed_table = tabulate(table_data, headers=detailed_headers, tablefmt="grid", stralign="left") + with open('results/detailed_wer.txt', 'w') as f: + f.write(detailed_table) + else: + print("No data to display.") + + # Compute overall WER per model (average) + overall_wer = {} + for model, wer_list in wer_accumulator.items(): + if wer_list: + overall_wer[model] = sum(wer_list) / len(wer_list) + + # Create a list for overall WER table + overall_wer_table = [] + for model, avg_wer in overall_wer.items(): + overall_wer_table.append([ + model, + f"{avg_wer:.2%}" + ]) + + # Sort the overall WER table by average WER ascending (lower is better); + # parse the formatted percentage back to a float so the sort is numeric, not lexicographic + overall_wer_table_sorted = sorted(overall_wer_table, key=lambda x: float(x[1].strip('%'))) + + # Define overall WER table headers + overall_wer_headers = ["Model", "Average WER"] + + # Generate the overall WER table + if overall_wer_table_sorted: + print("\nOverall WER per Model:") + overall_wer_formatted = tabulate(overall_wer_table_sorted, headers=overall_wer_headers, tablefmt="grid", + stralign="left") + print(overall_wer_formatted) + with open('results/wer.txt', 'w') as f: + f.write(overall_wer_formatted) + else: + print("No overall WER data to display.") + + # Create a ranking table based on points + ranking_table = [] + for model, points in points_counter.items(): + ranking_table.append([ + model, + points + ]) + + # Sort the ranking table by points descending (more points are better) + ranking_table_sorted = sorted(ranking_table, key=lambda x: x[1], reverse=True) + + # Assign rankings + ranking_table_with_rank = [] + current_rank = 1 + previous_points = None + for idx, (model, points) in enumerate(ranking_table_sorted): + if points != previous_points: + rank = current_rank + else: + rank = current_rank - 1 # Same rank as previous + ranking_table_with_rank.append([ + rank, + model, + points + ]) + previous_points = points + current_rank += 1 + + # Define ranking table headers + ranking_headers = ["Rank", "Model", "Points"] + + # Generate the ranking table + if ranking_table_with_rank: + print("\nModel Rankings Based on WER Performance:") + ranking_table_formatted = tabulate(ranking_table_with_rank, headers=ranking_headers, tablefmt="grid", + stralign="left") + print(ranking_table_formatted) + with open('results/ranking.txt', 'w') as f: + f.write(ranking_table_formatted) + else: + print("No ranking data to display.") + + +def regex_fix(text: str): + """Fix some of the stored JSON in results/$id.json from the Groq API.""" + if not isinstance(text, str): + return text # segment lists don't need the Groq string fix + pattern = r'(?<=transcription\(text=["\'])(.*?)(?=["\'],\s*task=)' + match = re.search(pattern, text) + if match: + extracted_text = match.group(0) + return extracted_text + else: + print("No match found.") + return text + + +def pyannote_diarize(file_path: str): + memory_id = file_path.split('/')[-1].split('.')[0] + with open('diarization.json', 'r') as f: + results = json.loads(f.read()) + + if memory_id in results: + print('Already diarized', memory_id) + return + + url = "https://api.pyannote.ai/v1/diarize" + headers = {"Authorization": f"Bearer {os.getenv('PYANNOTE_API_KEY')}"} + webhook = 'https://camel-lucky-reliably.ngrok-free.app/webhook' + signed_url = upload_postprocessing_audio(file_path) + data = {'webhook': webhook, 'url': signed_url} + response = requests.post(url, headers=headers, json=data) + print(memory_id, response.json()['jobId']) + # update diarization.json, and set jobId=memoryId + with open('diarization.json', 'r') as f:
diarization = json.loads(f.read()) + + diarization[response.json()['jobId']] = memory_id + with open('diarization.json', 'w') as f: + json.dump(diarization, f, indent=2) + + +def generate_diarizations(): + uids = os.listdir('_temp2') + for uid in uids: + memories = os.listdir(f'_temp2/{uid}') + memories = [f'_temp2/{uid}/{memory}' for memory in memories] + for memory in memories: + memory_id = memory.split('/')[-1].split('.')[0] + if os.path.exists(f'results/{memory_id}.json'): + pyannote_diarize(memory) + else: + print('Skipping', memory_id) + + +from pyannote.metrics.diarization import DiarizationErrorRate +from pyannote.core import Annotation, Segment + +der_metric = DiarizationErrorRate() + + +def compute_der(): + """ + Computes the Diarization Error Rate (DER) for each model across all JSON files in the 'results/' directory. + Outputs a summary table and rankings to 'der_report.txt'. + """ + dir_path = 'results/' # Directory containing result JSON files and 'diarization.json' + output_file = os.path.join(dir_path, 'der_report.txt') # Output report file + excluded_model = 'whisper-large-v3' # Model to exclude from analysis + + # Initialize DER metric + der_metric = DiarizationErrorRate() + + # Check if the directory exists + if not os.path.isdir(dir_path): + print(f"Directory '{dir_path}' does not exist.") + return + + # Path to 'diarization.json' + diarization_path = 'diarization.json' + + # Load reference diarization data + with open(diarization_path, 'r', encoding='utf-8') as f: + try: + diarization = json.load(f) + except json.JSONDecodeError: + print(f"Error decoding JSON in 'diarization.json'.") + return + + # Prepare to collect DER results + der_results = [] # List to store [Memory ID, Model, DER] + model_der_accumulator = defaultdict(list) # To calculate average DER per model + + # Iterate through all JSON files in 'results/' directory + for file in os.listdir(dir_path): + if not file.endswith('.json') or file == 'diarization.json': + continue # Skip non-JSON files and 'diarization.json' itself + + memory_id = file.split('.')[0] # Extract memory ID from filename + + # Check if memory_id exists in 'diarization.json' + if memory_id not in diarization: + print(f"Memory ID '{memory_id}' not found in 'diarization.json'. Skipping file: {file}") + continue + + # Load reference segments for the current memory_id + ref_segments = diarization[memory_id] + ref_annotation = Annotation() + for seg in ref_segments: + speaker, start, end = seg['speaker'], seg['start'], seg['end'] + ref_annotation[Segment(start, end)] = speaker + + # Load hypothesis segments from the result JSON file + file_path = os.path.join(dir_path, file) + with open(file_path, 'r', encoding='utf-8') as f: + try: + data = json.load(f) + except json.JSONDecodeError: + print(f"Error decoding JSON in file: {file}. 
Skipping.") + continue + + # Iterate through each model's segments in the result + for model, segments in data.items(): + if model == excluded_model: + continue # Skip the excluded model + + hyp_annotation = Annotation() + for seg in segments: + speaker, start, end = seg['speaker'], seg['start'], seg['end'] + # Optional: Normalize speaker labels if necessary + if speaker == 'SPEAKER_0': + speaker = 'SPEAKER_00' + elif speaker == 'SPEAKER_1': + speaker = 'SPEAKER_01' + elif speaker == 'SPEAKER_2': + speaker = 'SPEAKER_02' + elif speaker == 'SPEAKER_3': + speaker = 'SPEAKER_03' + hyp_annotation[Segment(start, end)] = speaker + + # Compute DER between reference and hypothesis + der = der_metric(ref_annotation, hyp_annotation) + + # Store the result + der_results.append([memory_id, model, f"{der:.2%}"]) + model_der_accumulator[model].append(der) + + # Generate the detailed DER table + der_table = tabulate(der_results, headers=["Memory ID", "Model", "DER"], tablefmt="grid", stralign="left") + + # Calculate average DER per model + average_der = [] + for model, ders in model_der_accumulator.items(): + avg = sum(ders) / len(ders) + average_der.append([model, f"{avg:.2%}"]) + + # Sort models by average DER ascending (lower is better) + average_der_sorted = sorted(average_der, key=lambda x: float(x[1].strip('%'))) + + # Determine the winner (model with the lowest average DER) + winner = average_der_sorted[0][0] if average_der_sorted else "N/A" + + # Prepare rankings (1st, 2nd, etc.) + rankings = [] + rank = 1 + previous_der = None + for model, avg in average_der_sorted: + current_der = float(avg.strip('%')) + if previous_der is None or current_der < previous_der: + current_rank = rank + else: + current_rank = rank - 1 # Same rank as previous if DER is equal + rankings.append([current_rank, model, avg]) + previous_der = current_der + rank += 1 + + # Generate the rankings table + ranking_table = tabulate(rankings, headers=["Rank", "Model", "Average DER"], tablefmt="grid", stralign="left") + + # Write all results to the output file + with open(output_file, 'w', encoding='utf-8') as out_f: + out_f.write("Diarization Error Rate (DER) Analysis Report\n") + out_f.write("=" * 50 + "\n\n") + out_f.write("Detailed DER Results:\n") + out_f.write(der_table + "\n\n") + out_f.write("Average DER per Model:\n") + out_f.write( + tabulate(average_der_sorted, headers=["Model", "Average DER"], tablefmt="grid", stralign="left") + "\n\n") + out_f.write("Model Rankings Based on Average DER:\n") + out_f.write(ranking_table + "\n\n") + out_f.write(f"Winner: {winner}\n") + + # Print a confirmation message + print(f"Diarization Error Rate (DER) analysis completed. 
Report saved to '{output_file}'.") + + # Optionally, print the tables to the console as well + if der_results: + print("\nDetailed DER Results:") + print(der_table) + if average_der_sorted: + print("\nAverage DER per Model:") + print(tabulate(average_der_sorted, headers=["Model", "Average DER"], tablefmt="grid", stralign="left")) + if rankings: + print("\nModel Rankings Based on Average DER:") + print(ranking_table) + print(f"\nWinner: {winner}") + + +if __name__ == '__main__': + # asyncio.run(process_memories_audio_files()) + # generate_diarizations() + # compute_wer() + compute_der() diff --git a/backend/utils/audio.py b/backend/utils/audio.py index 069a2ebdd..bf4663c73 100644 --- a/backend/utils/audio.py +++ b/backend/utils/audio.py @@ -3,14 +3,16 @@ from pyogg import OpusDecoder from pydub import AudioSegment -def merge_wav_files(dest_file_path: str, source_files: [str]): +def merge_wav_files(dest_file_path: str, source_files: [str], silent_seconds: [int]): if len(source_files) == 0 or not dest_file_path: return combined_sounds = AudioSegment.empty() - for file_path in source_files: + for i in range(len(source_files)): + file_path = source_files[i] sound = AudioSegment.from_wav(file_path) - combined_sounds = combined_sounds + sound + silent_sec = silent_seconds[i] + combined_sounds = combined_sounds + sound + AudioSegment.silent(duration=silent_sec) combined_sounds.export(dest_file_path, format="wav") diff --git a/backend/utils/chat/chat.py b/backend/utils/chat/chat.py new file mode 100644 index 000000000..8be3036e7 --- /dev/null +++ b/backend/utils/chat/chat.py @@ -0,0 +1,6 @@ +import database.chat as chat_db + + +def clear_user_chat_message(uid: str): + err = chat_db.clear_chat(uid) + return err diff --git a/backend/utils/memories/postprocess_memory.py b/backend/utils/memories/postprocess_memory.py new file mode 100644 index 000000000..be9762f8a --- /dev/null +++ b/backend/utils/memories/postprocess_memory.py @@ -0,0 +1,143 @@ +import asyncio +import os +import threading +import time + +from pydub import AudioSegment + +import database.memories as memories_db +from database.users import get_user_store_recording_permission +from models.memory import * +from utils.memories.process_memory import process_memory, process_user_emotion +from utils.other.storage import upload_postprocessing_audio, \ + delete_postprocessing_audio, upload_memory_recording +from utils.stt.pre_recorded import fal_whisperx, fal_postprocessing +from utils.stt.speech_profile import get_speech_profile_matching_predictions +from utils.stt.vad import vad_is_empty + + +def postprocess_memory(memory_id: str, file_path: str, uid: str, emotional_feedback: bool, streaming_model: str): + memory_data = _get_memory_by_id(uid, memory_id) + if not memory_data: + return 404, "Memory not found" + + memory = Memory(**memory_data) + if memory.discarded: + print('postprocess_memory: Memory is discarded') + return 400, "Memory is discarded" + + if memory.postprocessing is not None and memory.postprocessing.status != PostProcessingStatus.not_started: + print(f'postprocess_memory: Memory can\'t be post-processed again {memory.postprocessing.status}') + return 400, "Memory can't be post-processed again" + + aseg = AudioSegment.from_wav(file_path) + if aseg.duration_seconds < 10: # TODO: validate duration more accurately, segment.last.end - segment.first.start - 10 + # TODO: fix app, sometimes audio uploaded is wrong, is too short. 
+ print('postprocess_memory: Audio duration is too short, seems wrong.') + memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.canceled) + return 500, "Audio duration is too short, seems wrong." + + memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.in_progress) + + try: + # Calling VAD to avoid processing empty parts and getting hallucinations from whisper. + # TODO: use these logs to determine if whisperx is failing because of the VAD results. + print('previous to vad_is_empty (segments duration):', + memory.transcript_segments[-1].end - memory.transcript_segments[0].start) + vad_segments = vad_is_empty(file_path, return_segments=True) + if vad_segments: + start = vad_segments[0]['start'] + end = vad_segments[-1]['end'] + print('vad_is_empty file result segments:', start, end) + aseg = AudioSegment.from_wav(file_path) + aseg = aseg[max(0, (start - 1) * 1000):min((end + 1) * 1000, aseg.duration_seconds * 1000)] + aseg.export(file_path, format="wav") + except Exception as e: + print(e) + + try: + aseg = AudioSegment.from_wav(file_path) + signed_url = upload_postprocessing_audio(file_path) + threading.Thread(target=_delete_postprocessing_audio, args=(file_path,)).start() + + if aseg.frame_rate == 16000 and get_user_store_recording_permission(uid): + upload_memory_recording(file_path, uid, memory_id) + + speakers_count = len(set([segment.speaker for segment in memory.transcript_segments])) + words = fal_whisperx(signed_url, speakers_count) + fal_segments = fal_postprocessing(words, aseg.duration_seconds) + + # if the new transcript is under 85% of the original length, cancel post-processing; something is wrong with the audio or FAL + count = len(''.join([segment.text.strip() for segment in memory.transcript_segments])) + new_count = len(''.join([segment.text.strip() for segment in fal_segments])) + print('Prev characters count:', count, 'New characters count:', new_count) + + fal_failed = not fal_segments or new_count < (count * 0.85) + + if fal_failed: + _handle_segment_embedding_matching(uid, file_path, memory.transcript_segments, aseg) + else: + _handle_segment_embedding_matching(uid, file_path, fal_segments, aseg) + + # Store both models' results. + memories_db.store_model_segments_result(uid, memory.id, streaming_model, memory.transcript_segments) + memories_db.store_model_segments_result(uid, memory.id, 'fal_whisperx', fal_segments) + + if not fal_failed: + memory.transcript_segments = fal_segments + + memories_db.upsert_memory(uid, memory.dict()) # At least store the transcript segments in case something fails later + if fal_failed: + # TODO: FAL fails too much and is fucking expensive. Remove it.
+ fail_reason = 'FAL empty segments' if not fal_segments else f'FAL transcript too short ({new_count} vs {count})' + memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.failed, fail_reason=fail_reason) + memory.postprocessing = MemoryPostProcessing( + status=PostProcessingStatus.failed, model=PostProcessingModel.fal_whisperx) + # TODO: consider doing process_memory, if any segment still matched to user or people + return 200, memory + + # Reprocess memory with improved transcription + result: Memory = process_memory(uid, memory.language, memory, force_process=True) + + # Process users emotion, async + if emotional_feedback: + asyncio.run(_process_user_emotion(uid, memory.language, memory, [signed_url])) + except Exception as e: + print(e) + memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.failed, fail_reason=str(e)) + return 500, str(e) + + memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.completed) + result.postprocessing = MemoryPostProcessing( + status=PostProcessingStatus.completed, model=PostProcessingModel.fal_whisperx) + + return 200, result + + +def _get_memory_by_id(uid: str, memory_id: str) -> dict: + memory = memories_db.get_memory(uid, memory_id) + if memory is None or memory.get('deleted', False): + return None + return memory + + +def _delete_postprocessing_audio(file_path): + time.sleep(300) # 5 min + delete_postprocessing_audio(file_path) + os.remove(file_path) + + +async def _process_user_emotion(uid: str, language_code: str, memory: Memory, urls: [str]): + if not any(segment.is_user for segment in memory.transcript_segments): + print(f"_process_user_emotion skipped for {memory.id}") + return + + process_user_emotion(uid, language_code, memory, urls) + + +def _handle_segment_embedding_matching(uid: str, file_path: str, segments: List[TranscriptSegment], aseg: AudioSegment): + if aseg.frame_rate == 16000: + matches = get_speech_profile_matching_predictions(uid, file_path, [s.dict() for s in segments]) + for i, segment in enumerate(segments): + segment.is_user = matches[i]['is_user'] + segment.person_id = matches[i].get('person_id') diff --git a/backend/utils/other/storage.py b/backend/utils/other/storage.py index 7eb78b2c9..b0e13b915 100644 --- a/backend/utils/other/storage.py +++ b/backend/utils/other/storage.py @@ -149,6 +149,7 @@ def create_signed_postprocessing_audio_url(file_path: str): blob = bucket.blob(file_path) return _get_signed_url(blob, 15) + def upload_postprocessing_audio_bytes(file_path: str, audio_buffer: bytes): bucket = storage_client.bucket(postprocessing_audio_bucket) blob = bucket.blob(file_path) @@ -162,13 +163,13 @@ def upload_sdcard_audio(file_path: str): blob.upload_from_filename(file_path) return f'https://storage.googleapis.com/{postprocessing_audio_bucket}/sdcard/{file_path}' + def download_postprocessing_audio(file_path: str, destination_file_path: str): bucket = storage_client.bucket(postprocessing_audio_bucket) blob = bucket.blob(file_path) blob.download_to_filename(destination_file_path) - # ************************************************ # ************* MEMORIES RECORDINGS ************** # ************************************************ diff --git a/backend/utils/processing_memories.py b/backend/utils/processing_memories.py index d31874c41..057cbd0c6 100644 --- a/backend/utils/processing_memories.py +++ b/backend/utils/processing_memories.py @@ -1,7 +1,7 @@ import uuid from datetime import datetime, timezone -from models.processing_memory import ProcessingMemory, 
UpdateProcessingMemory +from models.processing_memory import ProcessingMemory, UpdateProcessingMemory, BasicProcessingMemory from models.memory import CreateMemory, PostProcessingModel, PostProcessingStatus, MemoryPostProcessing, TranscriptSegment from utils.memories.process_memory import process_memory from utils.memories.location import get_google_maps_location @@ -23,11 +23,11 @@ async def create_memory_by_processing_memory(uid: str, processing_memory_id: str if not transcript_segments or len(transcript_segments) == 0: print("Transcript segments is invalid") return - timer_start = processing_memory.timer_start + timer_segment_start = processing_memory.timer_segment_start if processing_memory.timer_segment_start else processing_memory.timer_start segment_end = transcript_segments[-1].end new_memory = CreateMemory( - started_at=datetime.fromtimestamp(timer_start, timezone.utc), - finished_at=datetime.fromtimestamp(timer_start + segment_end, timezone.utc), + started_at=datetime.fromtimestamp(timer_segment_start, timezone.utc), + finished_at=datetime.fromtimestamp(timer_segment_start + segment_end, timezone.utc), language=processing_memory.language, transcript_segments=transcript_segments, ) @@ -55,13 +55,13 @@ async def create_memory_by_processing_memory(uid: str, processing_memory_id: str return (memory, messages, processing_memory) -def update_basic_processing_memory(uid: str, update_processing_memory: UpdateProcessingMemory,) -> ProcessingMemory: +def update_basic_processing_memory(uid: str, update_processing_memory: UpdateProcessingMemory,) -> BasicProcessingMemory: # Fetch new - processing_memories = processing_memories_db.get_processing_memories_by_id(uid, [update_processing_memory.id]) - if len(processing_memories) == 0: + processing_memory = processing_memories_db.get_processing_memory_by_id(uid, update_processing_memory.id) + if not processing_memory: print("processing memory is not found") return - processing_memory = ProcessingMemory(**processing_memories[0]) + processing_memory = BasicProcessingMemory(**processing_memory) # geolocation if update_processing_memory.geolocation: diff --git a/backend/utils/stt/streaming.py b/backend/utils/stt/streaming.py index 4593a1554..979cb2f61 100644 --- a/backend/utils/stt/streaming.py +++ b/backend/utils/stt/streaming.py @@ -4,7 +4,7 @@ from typing import List import websockets -from deepgram import DeepgramClient, DeepgramClientOptions, LiveTranscriptionEvents +from deepgram import DeepgramClient, DeepgramClientOptions, LiveTranscriptionEvents, ListenWebSocketClient from deepgram.clients.live.v1 import LiveOptions import database.notifications as notification_db @@ -61,6 +61,22 @@ # return segments +async def send_initial_file_path(file_path: str, transcript_socket_async_send): + print('send_initial_file_path') + start = time.time() + # Reading and sending in chunks + with open(file_path, "rb") as file: + while True: + chunk = file.read(320) + if not chunk: + break + # print('Uploading', len(chunk)) + await transcript_socket_async_send(bytes(chunk)) + await asyncio.sleep(0.0001) # if it takes too long to transcribe + + print('send_initial_file_path', time.time() - start) + + async def send_initial_file(data: List[List[int]], transcript_socket): print('send_initial_file2') start = time.time() @@ -78,10 +94,10 @@ async def send_initial_file(data: List[List[int]], transcript_socket): async def process_audio_dg( - stream_transcript, stream_id: int, language: str, sample_rate: int, codec: str, channels: int, + stream_transcript, stream_id: int, 
language: str, sample_rate: int, channels: int, preseconds: int = 0, ): - print('process_audio_dg', language, sample_rate, codec, channels, preseconds) + print('process_audio_dg', language, sample_rate, channels, preseconds) def on_message(self, result, **kwargs): # print(f"Received message from Deepgram") # Log when message is received @@ -127,7 +143,7 @@ def on_error(self, error, **kwargs): print(f"Error: {error}") print("Connecting to Deepgram") # Log before connection attempt - return connect_to_deepgram(on_message, on_error, language, sample_rate, codec, channels) + return connect_to_deepgram(on_message, on_error, language, sample_rate, channels) def process_segments(uid: str, segments: list[dict]): @@ -135,12 +151,42 @@ def process_segments(uid: str, segments: list[dict]): trigger_realtime_integrations(uid, token, segments) -def connect_to_deepgram(on_message, on_error, language: str, sample_rate: int, codec: str, channels: int): +def connect_to_deepgram(on_message, on_error, language: str, sample_rate: int, channels: int): # 'wss://api.deepgram.com/v1/listen?encoding=linear16&sample_rate=8000&language=$recordingsLanguage&model=nova-2-general&no_delay=true&endpointing=100&interim_results=false&smart_format=true&diarize=true' try: - dg_connection = deepgram.listen.live.v("1") + dg_connection = deepgram.listen.websocket.v("1") dg_connection.on(LiveTranscriptionEvents.Transcript, on_message) dg_connection.on(LiveTranscriptionEvents.Error, on_error) + + def on_open(self, open, **kwargs): + print("Connection Open") + + def on_metadata(self, metadata, **kwargs): + print(f"Metadata: {metadata}") + + def on_speech_started(self, speech_started, **kwargs): + print("Speech Started") + + def on_utterance_end(self, utterance_end, **kwargs): + print("Utterance End") + global is_finals + if len(is_finals) > 0: + utterance = " ".join(is_finals) + print(f"Utterance End: {utterance}") + is_finals = [] + + def on_close(self, close, **kwargs): + print("Connection Closed") + + def on_unhandled(self, unhandled, **kwargs): + print(f"Unhandled Websocket Message: {unhandled}") + + dg_connection.on(LiveTranscriptionEvents.Open, on_open) + dg_connection.on(LiveTranscriptionEvents.Metadata, on_metadata) + dg_connection.on(LiveTranscriptionEvents.SpeechStarted, on_speech_started) + dg_connection.on(LiveTranscriptionEvents.UtteranceEnd, on_utterance_end) + dg_connection.on(LiveTranscriptionEvents.Close, on_close) + dg_connection.on(LiveTranscriptionEvents.Unhandled, on_unhandled) options = LiveOptions( punctuate=True, no_delay=True, @@ -155,7 +201,7 @@ def connect_to_deepgram(on_message, on_error, language: str, sample_rate: int, c multichannel=channels > 1, model='nova-2-general', sample_rate=sample_rate, - encoding='linear16' if codec == 'pcm8' or codec == 'pcm16' else 'opus' + encoding='linear16' ) result = dg_connection.start(options) print('Deepgram connection started:', result) @@ -167,10 +213,7 @@ def connect_to_deepgram(on_message, on_error, language: str, sample_rate: int, c soniox_valid_languages = ['en'] -# soniox_valid_languages = ['en', 'es', 'fr', 'ko', 'zh', 'it', 'pt', 'de'] - - -async def process_audio_soniox(stream_transcript, stream_id: int, language: str, uid: str): +async def process_audio_soniox(stream_transcript, stream_id: int, sample_rate: int, language: str, uid: str): # Fuck, soniox doesn't even support diarization in languages != english api_key = os.getenv('SONIOX_API_KEY') if not api_key: @@ -182,12 +225,12 @@ async def process_audio_soniox(stream_transcript, stream_id: int, 
language: str, if language not in soniox_valid_languages: raise ValueError(f"Unsupported language '{language}'. Supported languages are: {soniox_valid_languages}") - has_speech_profile = create_user_speech_profile(uid) # only english too + has_speech_profile = create_user_speech_profile(uid) if uid and sample_rate == 16000 else False # only english too # Construct the initial request with all required and optional parameters request = { 'api_key': api_key, - 'sample_rate_hertz': 16000, + 'sample_rate_hertz': sample_rate, 'include_nonfinal': True, 'enable_endpoint_detection': True, 'enable_streaming_speaker_diarization': True, @@ -216,6 +259,7 @@ async def on_message(): try: async for message in soniox_socket: response = json.loads(message) + # print(response) fw = response['fw'] if not fw: continue @@ -257,7 +301,8 @@ async def on_message(): segments[i]['text'] = segments[i]['text'].strip().replace(' ', '') # print('Soniox:', transcript.replace('', '')) - stream_transcript(segments, stream_id) + if segments: + stream_transcript(segments, stream_id) except websockets.exceptions.ConnectionClosedOK: print("Soniox connection closed normally.") except Exception as e: @@ -276,3 +321,119 @@ async def on_message(): except Exception as e: print(f"Exception in process_audio_soniox: {e}") raise # Re-raise the exception to be handled by the caller + + +LANGUAGE = "en" +CONNECTION_URL = f"wss://eu2.rt.speechmatics.com/v2" + + +async def process_audio_speechmatics(stream_transcript, stream_id: int, sample_rate: int, language: str, preseconds: int = 0): + api_key = os.getenv('SPEECHMATICS_API_KEY') + uri = 'wss://eu2.rt.speechmatics.com/v2' + + request = { + "message": "StartRecognition", + "transcription_config": { + "language": language, + "diarization": "speaker", + "operating_point": "enhanced", + "max_delay_mode": "flexible", + "max_delay": 3, + "enable_partials": False, + "enable_entities": True, + "speaker_diarization_config": {"max_speakers": 4} + }, + "audio_format": {"type": "raw", "encoding": "pcm_s16le", "sample_rate": sample_rate}, + # "audio_events_config": { + # "types": [ + # "laughter", + # "music", + # "applause" + # ] + # } + } + try: + print("Connecting to Speechmatics WebSocket...") + socket = await websockets.connect(uri, extra_headers={"Authorization": f"Bearer {api_key}"}) + print("Connected to Speechmatics WebSocket.") + + await socket.send(json.dumps(request)) + print(f"Sent initial request: {request}") + + async def on_message(): + try: + async for message in socket: + response = json.loads(message) + if response['message'] == 'AudioAdded': + continue + if response['message'] == 'AddTranscript': + results = response['results'] + if not results: + continue + segments = [] + for r in results: + # print(r) + if not r['alternatives']: + continue + + r_data = r['alternatives'][0] + r_type = r['type'] # word | punctuation + r_start = r['start_time'] + r_end = r['end_time'] + + r_content = r_data['content'] + r_confidence = r_data['confidence'] + if r_confidence < 0.4: + print('Low confidence:', r) + continue + r_speaker = r_data['speaker'][1:] if r_data['speaker'] != 'UU' else '1' + speaker = f"SPEAKER_0{r_speaker}" + + is_user = True if r_speaker == '1' and preseconds > 0 else False + if r_start < preseconds: + # print('Skipping word', r_start, r_content) + continue + # print(r_content, r_speaker, [r_start, r_end]) + if not segments: + segments.append({ + 'speaker': speaker, + 'start': r_start, + 'end': r_end, + 'text': r_content, + 'is_user': is_user, + 'person_id': None, + }) + 
else: + last_segment = segments[-1] + if last_segment['speaker'] == speaker: + last_segment['text'] += f' {r_content}' + last_segment['end'] = r_end + else: + segments.append({ + 'speaker': speaker, + 'start': r_start, + 'end': r_end, + 'text': r_content, + 'is_user': is_user, + 'person_id': None, + }) + + if segments: + stream_transcript(segments, stream_id) + # print('---') + else: + print(response) + except websockets.exceptions.ConnectionClosedOK: + print("Speechmatics connection closed normally.") + except Exception as e: + print(f"Error receiving from Speechmatics: {e}") + finally: + if not socket.closed: + await socket.close() + print("Speechmatics WebSocket closed in on_message.") + + asyncio.create_task(on_message()) + return socket + except Exception as e: + print(f"Exception in process_audio_speechmatics: {e}") + raise diff --git a/backend/utils/stt/vad.py b/backend/utils/stt/vad.py index a5ca374d8..9870cbacf 100644 --- a/backend/utils/stt/vad.py +++ b/backend/utils/stt/vad.py @@ -38,6 +38,8 @@ def is_audio_empty(file_path, sample_rate=8000): def vad_is_empty(file_path, return_segments: bool = False): """Uses vad_modal/vad.py deployment (Best quality)""" try: + file_duration = AudioSegment.from_wav(file_path).duration_seconds + print('vad_is_empty file duration:', file_duration) with open(file_path, 'rb') as file: files = {'file': (file_path.split('/')[-1], file, 'audio/wav')} response = requests.post(os.getenv('HOSTED_VAD_API_URL'), files=files) diff --git a/docs/_assembly/Compile_firmware.md b/docs/_assembly/Compile_firmware.md index bd6d7278b..3677e208f 100644 --- a/docs/_assembly/Compile_firmware.md +++ b/docs/_assembly/Compile_firmware.md @@ -11,11 +11,10 @@ Important: If you purchased an assembled device please skip this step If you purchased an unassembled Friend device or built it yourself using our hardware guide, follow the steps below to flash the firmware: -### Official firmware: +### Want to install pre-built firmware? Navigate [here](https://docs.omi.me/get_started/Flash_device/) -Go to [Releases](https://github.com/BasedHardware/Omi/releases) in Github, and find the latest official firmware release. To use this firmware, simply download it and skip to step 6. If you would like to build the firmware yourself please follow all the steps below. -### Build your own firmware: +## Build your own firmware: 1. Set up nRF Connect by following the tutorial in this video: [https://youtu.be/EAJdOqsL9m8](https://youtu.be/EAJdOqsL9m8) diff --git a/docs/_assembly/Install_firmware.md b/docs/_assembly/Install_firmware.md new file mode 100644 index 000000000..3ac6a266b --- /dev/null +++ b/docs/_assembly/Install_firmware.md @@ -0,0 +1,7 @@ +--- +layout: default +title: Install firmware (old) +nav_order: 3 +--- + +We've moved! Please navigate [here](https://docs.omi.me/get_started/Flash_device/) diff --git a/docs/_developer/Plugins.md b/docs/_developer/Plugins.md deleted file mode 100644 index 11412bf9c..000000000 --- a/docs/_developer/Plugins.md +++ /dev/null @@ -1,7 +0,0 @@ ---- -layout: default -title: Plugins -nav_order: 5 ---- - -We've moved!
Go to https://basedhardware.com/plugins diff --git a/docs/_get_started/Flash_device.md b/docs/_get_started/Flash_device.md index 2f9bc4ba9..40f328c85 100644 --- a/docs/_get_started/Flash_device.md +++ b/docs/_get_started/Flash_device.md @@ -1,10 +1,10 @@ --- layout: default -title: Flashing FRIEND Firmware +title: Update FRIEND Firmware nav_order: 3 --- # Video Tutorial -For a visual walkthrough of the flashing process, watch the [Updating Your FRIEND](https://github.com/ebowwa/omi/blob/firmware-flashing-readme/docs/images/updating_your_friend.mov) video. +For a visual walkthrough of the flashing process, watch the [Updating Your FRIEND](https://github.com/BasedHardware/omi/blob/main/docs/images/updating_your_friend.mov) video. # Flashing FRIEND Firmware` @@ -13,9 +13,13 @@ This guide will walk you through the process of flashing the latest firmware ont ## Downloading the Firmware -1. Go to the [FRIEND GitHub repository](https://github.com/BasedHardware/Omi) and navigate to the "Devices > FRIEND > firmware" section. +1. Go to the [FRIEND GitHub repository](https://github.com/BasedHardware/Omi) and navigate to the "FRIEND > firmware" section. 2. Find the latest firmware release and bootloader, then download the corresponding `.uf2` files. +Or download these files directly: + - **Bootloader:** [bootloader0.9.0.uf2](https://github.com/BasedHardware/omi/releases/download/v1.0.3-firmware/update-xiao_nrf52840_ble_sense_bootloader-0.9.0_nosd.uf2) + - **Firmware:** [firmware1.0.4.uf2](https://github.com/BasedHardware/omi/releases/download/v1.0.4-firmware/friend-xiao_nrf52840_ble_sense-1.0.4.uf2) + ## Putting FRIEND into DFU Mode 1. **Locate the DFU Button:** Find the small pin-sized button on the FRIEND device's circuit board (refer to the image below if needed). @@ -30,11 +34,11 @@ This guide will walk you through the process of flashing the latest firmware ont 1. Locate the `.uf2` files you downloaded earlier. 2. Drag and drop the bootloader `.uf2` file onto the `/Volumes/XIAO-SENSE` drive: - - **Bootloader:** [bootloader0.9.0.uf2](https://github.com/ebowwa/omi/blob/firmware-flashing-readme/devices/Friend/firmware/bootloader/bootloader0.9.0.uf2) + - **Bootloader:** [bootloader0.9.0.uf2](https://github.com/BasedHardware/omi/releases/download/v1.0.3-firmware/update-xiao_nrf52840_ble_sense_bootloader-0.9.0_nosd.uf2) 3. The device will automatically eject itself once the bootloader flashing process is complete. 4. After the device forcibly ejects, set the FRIEND device back into DFU mode by double-tapping the reset button. 5. Drag and drop the FRIEND firmware file onto the `/Volumes/XIAO-SENSE` drive: - - **Firmware:** [firmware1.0.4.uf2](https://github.com/ebowwa/omi/blob/firmware-flashing-readme/devices/Friend/firmware/firmware1.0.4.uf2) + - **Firmware:** [firmware1.0.4.uf2](https://github.com/BasedHardware/omi/releases/download/v1.0.4-firmware/friend-xiao_nrf52840_ble_sense-1.0.4.uf2) ## Congratulations! You have successfully flashed the latest firmware onto your FRIEND device. You c @@ -45,4 +49,4 @@ Once you've installed the app, follow the in-app instructions to connect your FRIEND device and start exploring its features.
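The two drag-and-drop steps above can also be scripted. A minimal sketch, assuming a macOS mount point of `/Volumes/XIAO-SENSE` and the two downloaded `.uf2` files in the working directory (file names as in the guide); re-enter DFU mode between the two copies:

```python
# Hypothetical helper for the drag-and-drop flashing steps above.
# Assumes macOS and the UF2 volume name from the guide; adjust as needed.
import os
import shutil
import time

VOLUME = "/Volumes/XIAO-SENSE"  # assumed DFU mount point
FILES = [
    "update-xiao_nrf52840_ble_sense_bootloader-0.9.0_nosd.uf2",  # bootloader first
    "friend-xiao_nrf52840_ble_sense-1.0.4.uf2",                  # then firmware
]

for uf2 in FILES:
    # Wait for the device to appear in DFU mode (double-tap the reset button).
    while not os.path.exists(VOLUME):
        time.sleep(1)
    shutil.copy(uf2, VOLUME)  # the board ejects itself once flashing completes
    print(f"Copied {uf2}; wait for the eject, then re-enter DFU mode if needed")
    time.sleep(5)  # crude settle time before polling for the next mount
```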
-i just added this video to the repo docs/images/updating_your_friend.mov add it to this \ No newline at end of file +i just added this video to the repo docs/images/updating_your_friend.mov add it to this diff --git a/plugins/example/README.md b/plugins/example/README.md index dbff0f747..1379e28b9 100644 --- a/plugins/example/README.md +++ b/plugins/example/README.md @@ -37,7 +37,7 @@ Fill in the values for each variable as needed. 2. Install the requirements.txt by running: `pip install -r requirements.txt` -3. Run the project: `fastapi run dev` This will start the FastAPI application. +3. Run the project: `fastapi run` or `fastapi dev`. This will start the FastAPI application. ## Project Structure @@ -64,4 +64,4 @@ Fill in the values for each variable as needed. - `/notion-crm`: Store memory in Notion database - `/news-checker`: Check news based on conversation transcript -For more details on how to use these endpoints, refer to the code documentation or contact the project maintainer. \ No newline at end of file +For more details on how to use these endpoints, refer to the code documentation or contact the project maintainer.
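For context on the comparison pipeline added in `backend/scripts/stt/k_compare_transcripts_performance.py`, the per-file scoring reduces to one `jiwer` call per model transcript, with `whisper-large-v3` as the reference. A minimal sketch, with invented example sentences:

```python
# Minimal WER sketch mirroring compute_wer(): jiwer scores each model's
# transcript against the whisper-large-v3 baseline. Sentences are invented.
from jiwer import wer

reference = "the quick brown fox jumps over the lazy dog"   # baseline transcript
hypothesis = "the quick brown fox jumped over a lazy dog"   # streaming model output

error = wer(reference, hypothesis)  # (subs + inserts + dels) / reference words
print(f"WER: {error:.2%}")  # -> WER: 22.22% (2 substitutions out of 9 words)
```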