From 25a52c72a42fd5cfd90d8031ea7ff79e9a221c4b Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Wed, 18 Sep 2024 11:35:06 +0530 Subject: [PATCH 01/88] minor speech profile improvement and tab controller fix --- app/lib/main.dart | 3 +++ app/lib/pages/home/page.dart | 1 - app/lib/pages/onboarding/wrapper.dart | 14 ++++++++++++-- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/app/lib/main.dart b/app/lib/main.dart index e55d2bd69..47daaf1a5 100644 --- a/app/lib/main.dart +++ b/app/lib/main.dart @@ -275,6 +275,9 @@ class _DeciderWidgetState extends State { if (context.read().isConnected) { NotificationService.instance.saveNotificationToken(); } + if (context.read().user != null) { + context.read().setupHasSpeakerProfile(); + } }); super.initState(); } diff --git a/app/lib/pages/home/page.dart b/app/lib/pages/home/page.dart index e5111a508..c2abf5b72 100644 --- a/app/lib/pages/home/page.dart +++ b/app/lib/pages/home/page.dart @@ -53,7 +53,6 @@ class _HomePageWrapperState extends State { context.read().periodicConnect('coming from HomePageWrapper'); await context.read().getInitialMemories(); context.read().setSelectedChatPluginId(null); - await context.read().setupHasSpeakerProfile(); }); super.initState(); } diff --git a/app/lib/pages/onboarding/wrapper.dart b/app/lib/pages/onboarding/wrapper.dart index 239e2eb76..b47b9c0f1 100644 --- a/app/lib/pages/onboarding/wrapper.dart +++ b/app/lib/pages/onboarding/wrapper.dart @@ -12,6 +12,7 @@ import 'package:friend_private/pages/onboarding/name/name_widget.dart'; import 'package:friend_private/pages/onboarding/permissions/permissions_widget.dart'; import 'package:friend_private/pages/onboarding/speech_profile_widget.dart'; import 'package:friend_private/pages/onboarding/welcome/page.dart'; +import 'package:friend_private/providers/home_provider.dart'; import 'package:friend_private/providers/onboarding_provider.dart'; import 'package:friend_private/providers/speech_profile_provider.dart'; import 'package:friend_private/services/services.dart'; @@ -39,6 +40,7 @@ class _OnboardingWrapperState extends State with TickerProvid WidgetsBinding.instance.addPostFrameCallback((_) async { if (isSignedIn()) { // && !SharedPreferencesUtil().onboardingCompleted + context.read().setupHasSpeakerProfile(); _goNext(); } }); @@ -51,7 +53,14 @@ class _OnboardingWrapperState extends State with TickerProvid super.dispose(); } - _goNext() => _controller!.animateTo(_controller!.index + 1); + _goNext() { + if (_controller!.index < _controller!.length - 1) { + _controller!.animateTo(_controller!.index + 1); + } else { + routeToPage(context, const HomePageWrapper(), replace: true); + } + // _controller!.animateTo(_controller!.index + 1); + } // TODO: use connection directly Future _getAudioCodec(String deviceId) async { @@ -69,6 +78,7 @@ class _OnboardingWrapperState extends State with TickerProvid AuthComponent( onSignIn: () { MixpanelManager().onboardingStepCompleted('Auth'); + context.read().setupHasSpeakerProfile(); if (SharedPreferencesUtil().onboardingCompleted) { // previous users // Not needed anymore, because AuthProvider already does this @@ -101,7 +111,7 @@ class _OnboardingWrapperState extends State with TickerProvid }, goNext: () async { var provider = context.read(); - if (hasSpeechProfile) { + if (context.read().hasSpeakerProfile) { // previous users routeToPage(context, const HomePageWrapper(), replace: true); } else { From e0ee73ff244bb3a29dec9e9c1921a72023966c8d Mon Sep 17 00:00:00 2001 From: 
Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Wed, 18 Sep 2024 11:42:15 +0530 Subject: [PATCH 02/88] move memory heavy lifting to microtask during onboarding --- .../onboarding/memory_created_widget.dart | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/app/lib/pages/onboarding/memory_created_widget.dart b/app/lib/pages/onboarding/memory_created_widget.dart index cb6e21c90..fd61a6711 100644 --- a/app/lib/pages/onboarding/memory_created_widget.dart +++ b/app/lib/pages/onboarding/memory_created_widget.dart @@ -1,4 +1,5 @@ import 'package:flutter/material.dart'; +import 'package:friend_private/backend/schema/memory.dart'; import 'package:friend_private/pages/memories/widgets/memory_list_item.dart'; import 'package:friend_private/pages/memory_detail/memory_detail_provider.dart'; import 'package:friend_private/pages/memory_detail/page.dart'; @@ -9,11 +10,23 @@ import 'package:friend_private/utils/other/temp.dart'; import 'package:gradient_borders/box_borders/gradient_box_border.dart'; import 'package:provider/provider.dart'; -class MemoryCreatedWidget extends StatelessWidget { +Future updateMemoryDetailProvider(BuildContext context, ServerMemory memory) { + return Future.microtask(() { + context.read().addMemory(memory); + context.read().updateMemory(0); + }); +} + +class MemoryCreatedWidget extends StatefulWidget { final VoidCallback goNext; const MemoryCreatedWidget({super.key, required this.goNext}); + @override + State createState() => _MemoryCreatedWidgetState(); +} + +class _MemoryCreatedWidgetState extends State { @override Widget build(BuildContext context) { return Padding( @@ -54,10 +67,8 @@ class MemoryCreatedWidget extends StatelessWidget { ), child: MaterialButton( padding: const EdgeInsets.symmetric(horizontal: 32, vertical: 16), - onPressed: () async { - // goNext(); - context.read().addMemory(provider.memory!); - context.read().updateMemory(0); + onPressed: () { + updateMemoryDetailProvider(context, provider.memory!); MixpanelManager().memoryListItemClicked(provider.memory!, 0); routeToPage(context, MemoryDetailPage(memory: provider.memory!, isFromOnboarding: true)); }, From d67e39f3316278d73472bd57742e865647950d49 Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Wed, 18 Sep 2024 12:04:20 +0530 Subject: [PATCH 03/88] minor improvement --- app/lib/pages/onboarding/memory_created_widget.dart | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/app/lib/pages/onboarding/memory_created_widget.dart b/app/lib/pages/onboarding/memory_created_widget.dart index fd61a6711..f6e3fc02d 100644 --- a/app/lib/pages/onboarding/memory_created_widget.dart +++ b/app/lib/pages/onboarding/memory_created_widget.dart @@ -27,6 +27,14 @@ class MemoryCreatedWidget extends StatefulWidget { } class _MemoryCreatedWidgetState extends State { + @override + void initState() { + WidgetsBinding.instance.addPostFrameCallback((_) async { + await updateMemoryDetailProvider(context, context.read().memory!); + }); + super.initState(); + } + @override Widget build(BuildContext context) { return Padding( @@ -68,7 +76,7 @@ class _MemoryCreatedWidgetState extends State { child: MaterialButton( padding: const EdgeInsets.symmetric(horizontal: 32, vertical: 16), onPressed: () { - updateMemoryDetailProvider(context, provider.memory!); + // updateMemoryDetailProvider(context, provider.memory!); MixpanelManager().memoryListItemClicked(provider.memory!, 0); routeToPage(context, 
MemoryDetailPage(memory: provider.memory!, isFromOnboarding: true)); }, From 511c06474a5cf10c849a6b772ada58f03315fad7 Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Wed, 18 Sep 2024 12:04:53 +0530 Subject: [PATCH 04/88] use selectors instead of consumers --- app/lib/pages/capture/widgets/widgets.dart | 2 +- app/lib/pages/memory_detail/page.dart | 37 +++++------ app/lib/widgets/photos_grid.dart | 77 ++++++++++++---------- 3 files changed, 59 insertions(+), 57 deletions(-) diff --git a/app/lib/pages/capture/widgets/widgets.dart b/app/lib/pages/capture/widgets/widgets.dart index ce16dbd54..0cdf0da9f 100644 --- a/app/lib/pages/capture/widgets/widgets.dart +++ b/app/lib/pages/capture/widgets/widgets.dart @@ -290,7 +290,7 @@ getTranscriptWidget( return Column( children: [ - if (photos.isNotEmpty) PhotosGridComponent(photos: photos), + if (photos.isNotEmpty) PhotosGridComponent(), if (segments.isNotEmpty) TranscriptWidget(segments: segments), ], ); diff --git a/app/lib/pages/memory_detail/page.dart b/app/lib/pages/memory_detail/page.dart index d538aaa0b..f3375b546 100644 --- a/app/lib/pages/memory_detail/page.dart +++ b/app/lib/pages/memory_detail/page.dart @@ -209,17 +209,13 @@ class _MemoryDetailPageState extends State with TickerProvider return TabBarView( physics: const NeverScrollableScrollPhysics(), children: [ - Consumer( - builder: (context, provider, child) { + Selector( + selector: (context, provider) => provider.memory.source, + builder: (context, source, child) { return ListView( shrinkWrap: true, - children: provider.memory.source == MemorySource.openglass - ? [ - PhotosGridComponent( - photos: provider.photosData, - ), - const SizedBox(height: 32) - ] + children: source == MemorySource.openglass + ? [const PhotosGridComponent(), const SizedBox(height: 32)] : [const TranscriptWidgets()], ); }, @@ -245,17 +241,18 @@ class SummaryTab extends StatelessWidget { @override Widget build(BuildContext context) { return Selector( - selector: (context, provider) => provider.memory.discarded, - builder: (context, isDiscaarded, child) { - return ListView( - shrinkWrap: true, - children: [ - const GetSummaryWidgets(), - isDiscaarded ? const ReprocessDiscardedWidget() : const GetPluginsWidgets(), - const GetGeolocationWidgets(), - ], - ); - }); + selector: (context, provider) => provider.memory.discarded, + builder: (context, isDiscaarded, child) { + return ListView( + shrinkWrap: true, + children: [ + const GetSummaryWidgets(), + isDiscaarded ? 
const ReprocessDiscardedWidget() : const GetPluginsWidgets(), + const GetGeolocationWidgets(), + ], + ); + }, + ); } } diff --git a/app/lib/widgets/photos_grid.dart b/app/lib/widgets/photos_grid.dart index 4cb620a97..276f5e9e0 100644 --- a/app/lib/widgets/photos_grid.dart +++ b/app/lib/widgets/photos_grid.dart @@ -1,48 +1,53 @@ import 'dart:convert'; import 'package:flutter/material.dart'; +import 'package:friend_private/pages/memory_detail/memory_detail_provider.dart'; import 'package:friend_private/widgets/dialog.dart'; +import 'package:provider/provider.dart'; import 'package:tuple/tuple.dart'; class PhotosGridComponent extends StatelessWidget { - final List> photos; - const PhotosGridComponent({super.key, required this.photos}); + const PhotosGridComponent({super.key}); @override Widget build(BuildContext context) { - return GridView.builder( - padding: EdgeInsets.zero, - shrinkWrap: true, - scrollDirection: Axis.vertical, - itemCount: photos.length, - physics: const NeverScrollableScrollPhysics(), - itemBuilder: (context, idx) { - return GestureDetector( - onTap: () { - showDialog( - context: context, - builder: (c) { - return getDialog( - context, - () => Navigator.pop(context), - () => Navigator.pop(context), - 'Description', - photos[idx].item2, - singleButton: true, - ); - }); - }, - child: Container( - decoration: BoxDecoration(border: Border.all(color: Colors.grey.shade600, width: 0.5)), - child: Image.memory(base64Decode(photos[idx].item1), fit: BoxFit.cover), - ), - ); - }, - gridDelegate: const SliverGridDelegateWithFixedCrossAxisCount( - crossAxisCount: 3, - crossAxisSpacing: 8, - mainAxisSpacing: 8, - ), - ); + return Selector>>( + selector: (context, provider) => provider.photosData, + builder: (context, photos, child) { + return GridView.builder( + padding: EdgeInsets.zero, + shrinkWrap: true, + scrollDirection: Axis.vertical, + itemCount: photos.length, + physics: const NeverScrollableScrollPhysics(), + itemBuilder: (context, idx) { + return GestureDetector( + onTap: () { + showDialog( + context: context, + builder: (c) { + return getDialog( + context, + () => Navigator.pop(context), + () => Navigator.pop(context), + 'Description', + photos[idx].item2, + singleButton: true, + ); + }); + }, + child: Container( + decoration: BoxDecoration(border: Border.all(color: Colors.grey.shade600, width: 0.5)), + child: Image.memory(base64Decode(photos[idx].item1), fit: BoxFit.cover), + ), + ); + }, + gridDelegate: const SliverGridDelegateWithFixedCrossAxisCount( + crossAxisCount: 3, + crossAxisSpacing: 8, + mainAxisSpacing: 8, + ), + ); + }); } } From 0dbdadf5ab638490effe6327a14b6b040a4c2b72 Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Wed, 18 Sep 2024 12:52:07 +0530 Subject: [PATCH 05/88] do not show events that are older than 6 hrs (from start time) and have ended before start time + mins to hrs and days conversion --- app/lib/pages/memory_detail/widgets.dart | 210 +++++++++++------------ 1 file changed, 99 insertions(+), 111 deletions(-) diff --git a/app/lib/pages/memory_detail/widgets.dart b/app/lib/pages/memory_detail/widgets.dart index 7dedb53c3..ab7a281da 100644 --- a/app/lib/pages/memory_detail/widgets.dart +++ b/app/lib/pages/memory_detail/widgets.dart @@ -132,67 +132,19 @@ class GetSummaryWidgets extends StatelessWidget { ); }), memory.structured.actionItems.isNotEmpty ? const SizedBox(height: 40) : const SizedBox.shrink(), - memory.structured.events.isNotEmpty - ? 
Row( - children: [ - Icon(Icons.event, color: Colors.grey.shade300), - const SizedBox(width: 8), - Text( - 'Events', - style: Theme.of(context).textTheme.titleLarge!.copyWith(fontSize: 26), - ) - ], - ) - : const SizedBox.shrink(), + // memory.structured.events.isNotEmpty && memory.structured.events.where((e) => e.startsAt.isBefore(memory.startedAt!)).isNotEmpty + // ? Row( + // children: [ + // Icon(Icons.event, color: Colors.grey.shade300), + // const SizedBox(width: 8), + // Text( + // 'Events', + // style: Theme.of(context).textTheme.titleLarge!.copyWith(fontSize: 26), + // ) + // ], + // ) + // : const SizedBox.shrink(), const EventsListWidget(), - // ...memory.structured.events.mapIndexed((idx, event) { - // print(event.toJson()); - // return ListTile( - // contentPadding: EdgeInsets.zero, - // title: Text( - // event.title, - // style: const TextStyle(color: Colors.white, fontSize: 16, fontWeight: FontWeight.w600), - // ), - // subtitle: Padding( - // padding: const EdgeInsets.only(top: 4.0), - // child: Text( - // '${dateTimeFormat('MMM d, yyyy', event.startsAt)} at ${dateTimeFormat('h:mm a', event.startsAt)} ~ ${event.duration} minutes.', - // style: const TextStyle(color: Colors.grey, fontSize: 15), - // ), - // ), - // trailing: IconButton( - // onPressed: event.created - // ? null - // : () { - // var calEnabled = SharedPreferencesUtil().calendarEnabled; - // var calSelected = SharedPreferencesUtil().calendarId.isNotEmpty; - // if (!calEnabled || !calSelected) { - // routeToPage(context, const CalendarPage()); - // ScaffoldMessenger.of(context).showSnackBar(SnackBar( - // content: Text(!calEnabled - // ? 'Enable calendar integration to add events' - // : 'Select a calendar to add events to'), - // )); - // return; - // } - // context.read().updateEventState(true, idx); - // setMemoryEventsState(memory.id, [idx], [true]); - // CalendarUtil().createEvent( - // event.title, - // event.startsAt, - // event.duration, - // description: event.description, - // ); - // ScaffoldMessenger.of(context).showSnackBar( - // const SnackBar( - // content: Text('Event added to calendar'), - // ), - // ); - // }, - // icon: Icon(event.created ? Icons.check : Icons.add, color: Colors.white), - // ), - // ); - // }), memory.structured.events.isNotEmpty ? const SizedBox(height: 40) : const SizedBox.shrink(), ], ); @@ -208,63 +160,99 @@ class EventsListWidget extends StatelessWidget { Widget build(BuildContext context) { return Consumer( builder: (context, provider, child) { - return ListView.builder( - itemCount: provider.memory.structured.events.length, - shrinkWrap: true, - itemBuilder: (context, idx) { - var event = provider.memory.structured.events[idx]; - return ListTile( - contentPadding: EdgeInsets.zero, - title: Text( - event.title, - style: const TextStyle(color: Colors.white, fontSize: 16, fontWeight: FontWeight.w600), - ), - subtitle: Padding( - padding: const EdgeInsets.only(top: 4.0), - child: Text( - '${dateTimeFormat('MMM d, yyyy', event.startsAt)} at ${dateTimeFormat('h:mm a', event.startsAt)} ~ ${event.duration} minutes.', - style: const TextStyle(color: Colors.grey, fontSize: 15), - ), - ), - trailing: IconButton( - onPressed: event.created - ? null - : () { - var calEnabled = SharedPreferencesUtil().calendarEnabled; - var calSelected = SharedPreferencesUtil().calendarId.isNotEmpty; - if (!calEnabled || !calSelected) { - routeToPage(context, const CalendarPage()); - ScaffoldMessenger.of(context).showSnackBar(SnackBar( - content: Text(!calEnabled - ? 
'Enable calendar integration to add events' - : 'Select a calendar to add events to'), - )); - return; - } - context.read().updateEventState(true, idx); - setMemoryEventsState(provider.memory.id, [idx], [true]); - CalendarUtil().createEvent( - event.title, - event.startsAt, - event.duration, - description: event.description, - ); - ScaffoldMessenger.of(context).showSnackBar( - const SnackBar( - content: Text('Event added to calendar'), - ), - ); - }, - icon: Icon(event.created ? Icons.check : Icons.add, color: Colors.white), - ), - ); - }, + return Column( + mainAxisSize: MainAxisSize.min, + children: [ + provider.memory.structured.events.isNotEmpty && + !(provider.memory.structured.events + .where((e) => + e.startsAt.isBefore(provider.memory.startedAt!.add(const Duration(hours: 6))) && + e.startsAt.add(Duration(minutes: e.duration)).isBefore(provider.memory.startedAt!)) + .isNotEmpty) + ? Row( + children: [ + Icon(Icons.event, color: Colors.grey.shade300), + const SizedBox(width: 8), + Text( + 'Events', + style: Theme.of(context).textTheme.titleLarge!.copyWith(fontSize: 26), + ) + ], + ) + : const SizedBox.shrink(), + ListView.builder( + itemCount: provider.memory.structured.events.length, + shrinkWrap: true, + itemBuilder: (context, idx) { + var event = provider.memory.structured.events[idx]; + if (event.startsAt.isBefore(provider.memory.startedAt!.add(const Duration(hours: 6))) && + event.startsAt.add(Duration(minutes: event.duration)).isBefore(provider.memory.startedAt!)) { + return const SizedBox.shrink(); + } + return ListTile( + contentPadding: EdgeInsets.zero, + title: Text( + event.title, + style: const TextStyle(color: Colors.white, fontSize: 16, fontWeight: FontWeight.w600), + ), + subtitle: Padding( + padding: const EdgeInsets.only(top: 4.0), + child: Text( + '${dateTimeFormat('MMM d, yyyy', event.startsAt)} at ${dateTimeFormat('h:mm a', event.startsAt)} ~ ${minutesConversion(event.duration)}.', + style: const TextStyle(color: Colors.grey, fontSize: 15), + ), + ), + trailing: IconButton( + onPressed: event.created + ? null + : () { + var calEnabled = SharedPreferencesUtil().calendarEnabled; + var calSelected = SharedPreferencesUtil().calendarId.isNotEmpty; + if (!calEnabled || !calSelected) { + routeToPage(context, const CalendarPage()); + ScaffoldMessenger.of(context).showSnackBar(SnackBar( + content: Text(!calEnabled + ? 'Enable calendar integration to add events' + : 'Select a calendar to add events to'), + )); + return; + } + context.read().updateEventState(true, idx); + setMemoryEventsState(provider.memory.id, [idx], [true]); + CalendarUtil().createEvent( + event.title, + event.startsAt, + event.duration, + description: event.description, + ); + ScaffoldMessenger.of(context).showSnackBar( + const SnackBar( + content: Text('Event added to calendar'), + ), + ); + }, + icon: Icon(event.created ? 
Icons.check : Icons.add, color: Colors.white), + ), + ); + }, + ), + ], ); }, ); } } +String minutesConversion(int minutes) { + if (minutes < 60) { + return '$minutes minutes'; + } else if (minutes < 1440) { + return '${minutes / 60} hours'; + } else { + return '${minutes / 1440} days'; + } +} + class GetEditTextField extends StatefulWidget { final bool enabled; final String overview; From b567dee5f51e1d947b65ab77248e1e82c52a0daa Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Wed, 18 Sep 2024 13:03:49 +0530 Subject: [PATCH 06/88] geolocation alignment fix --- app/lib/pages/memory_detail/widgets.dart | 1 + 1 file changed, 1 insertion(+) diff --git a/app/lib/pages/memory_detail/widgets.dart b/app/lib/pages/memory_detail/widgets.dart index ab7a281da..f0e31c1f6 100644 --- a/app/lib/pages/memory_detail/widgets.dart +++ b/app/lib/pages/memory_detail/widgets.dart @@ -542,6 +542,7 @@ class GetGeolocationWidgets extends StatelessWidget { return provider.memory.geolocation; }, builder: (context, geolocation, child) { return Column( + crossAxisAlignment: CrossAxisAlignment.start, children: geolocation == null ? [] : [ From a8dddfc8c518d8f798187d7d0547aa58b5444528 Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Wed, 18 Sep 2024 13:06:33 +0530 Subject: [PATCH 07/88] fix no plugin triggered container scrolling --- app/lib/pages/memory_detail/widgets.dart | 1 + 1 file changed, 1 insertion(+) diff --git a/app/lib/pages/memory_detail/widgets.dart b/app/lib/pages/memory_detail/widgets.dart index f0e31c1f6..76eaf3a50 100644 --- a/app/lib/pages/memory_detail/widgets.dart +++ b/app/lib/pages/memory_detail/widgets.dart @@ -488,6 +488,7 @@ class GetPluginsWidgets extends StatelessWidget { }, child: ListView( shrinkWrap: true, + physics: const NeverScrollableScrollPhysics(), children: [ const SizedBox(height: 32), Text( From 9fcd176370234bccc9f135acb0093067357e73ff Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Wed, 18 Sep 2024 18:15:07 +0530 Subject: [PATCH 08/88] use dialog for change name widget --- .../pages/settings/change_name_widget.dart | 117 +++++++++++++++ app/lib/pages/settings/personal_details.dart | 141 ------------------ app/lib/pages/settings/profile.dart | 16 +- 3 files changed, 125 insertions(+), 149 deletions(-) create mode 100644 app/lib/pages/settings/change_name_widget.dart delete mode 100644 app/lib/pages/settings/personal_details.dart diff --git a/app/lib/pages/settings/change_name_widget.dart b/app/lib/pages/settings/change_name_widget.dart new file mode 100644 index 000000000..02aa856f3 --- /dev/null +++ b/app/lib/pages/settings/change_name_widget.dart @@ -0,0 +1,117 @@ +import 'dart:io'; + +import 'package:firebase_auth/firebase_auth.dart'; +import 'package:flutter/cupertino.dart'; +import 'package:flutter/material.dart'; +import 'package:friend_private/backend/auth.dart'; +import 'package:friend_private/backend/preferences.dart'; +import 'package:friend_private/utils/alerts/app_snackbar.dart'; + +class ChangeNameWidget extends StatefulWidget { + const ChangeNameWidget({super.key}); + + @override + State createState() => _ChangeNameWidgetState(); +} + +class _ChangeNameWidgetState extends State { + late TextEditingController nameController; + User? user; + bool isSaving = false; + + @override + void initState() { + user = getFirebaseUser(); + nameController = TextEditingController(text: user?.displayName ?? 
''); + super.initState(); + } + + @override + Widget build(BuildContext context) { + if (Platform.isIOS) { + return CupertinoAlertDialog( + content: Padding( + padding: const EdgeInsets.all(8.0), + child: Column( + children: [ + const Text('Enter your given name'), + const SizedBox(height: 8), + CupertinoTextField( + controller: nameController, + placeholderStyle: const TextStyle(color: Colors.white54), + style: const TextStyle(color: Colors.white), + ), + ], + ), + ), + actions: [ + CupertinoDialogAction( + textStyle: const TextStyle(color: Colors.white), + onPressed: () { + Navigator.of(context).pop(); + }, + child: const Text('Cancel'), + ), + CupertinoDialogAction( + textStyle: const TextStyle(color: Colors.white), + onPressed: () { + if (nameController.text.isEmpty || nameController.text.trim().isEmpty) { + AppSnackbar.showSnackbarError('Name cannot be empty'); + return; + } + SharedPreferencesUtil().givenName = nameController.text; + updateGivenName(nameController.text); + AppSnackbar.showSnackbar('Name updated successfully!'); + Navigator.of(context).pop(); + }, + child: const Text('Save'), + ), + ], + ); + } else { + return AlertDialog( + content: Padding( + padding: const EdgeInsets.all(8.0), + child: Column( + mainAxisSize: MainAxisSize.min, + children: [ + const Text('Enter your given name'), + const SizedBox(height: 8), + TextField( + controller: nameController, + style: const TextStyle(color: Colors.white), + ), + ], + ), + ), + actions: [ + TextButton( + onPressed: () { + Navigator.of(context).pop(); + }, + child: const Text( + 'Cancel', + style: TextStyle(color: Colors.white), + ), + ), + TextButton( + onPressed: () { + if (nameController.text.isEmpty || nameController.text.trim().isEmpty) { + AppSnackbar.showSnackbarError('Name cannot be empty'); + return; + } + SharedPreferencesUtil().givenName = nameController.text; + updateGivenName(nameController.text); + AppSnackbar.showSnackbar('Name updated successfully!'); + Navigator.of(context).pop(); + }, + child: const Text( + 'Save', + style: TextStyle(color: Colors.white), + ), + ), + ], + ); + } + } +} diff --git a/app/lib/pages/settings/personal_details.dart b/app/lib/pages/settings/personal_details.dart deleted file mode 100644 index 7594afe76..000000000 --- a/app/lib/pages/settings/personal_details.dart +++ /dev/null @@ -1,141 +0,0 @@ -import 'package:firebase_auth/firebase_auth.dart'; -import 'package:flutter/material.dart'; -import 'package:friend_private/backend/auth.dart'; -import 'package:friend_private/utils/alerts/app_snackbar.dart'; -import 'package:gradient_borders/gradient_borders.dart'; - -class PersonalDetails extends StatefulWidget { - const PersonalDetails({super.key}); - - @override - State createState() => _PersonalDetailsState(); -} - -class _PersonalDetailsState extends State { - late TextEditingController nameController; - User? user; - bool isSaving = false; - - @override - void initState() { - user = getFirebaseUser(); - nameController = TextEditingController(text: user?.displayName ?? 
''); - super.initState(); - } - - @override - Widget build(BuildContext context) { - return Scaffold( - backgroundColor: Theme.of(context).colorScheme.primary, - appBar: AppBar( - backgroundColor: Theme.of(context).colorScheme.primary, - actions: [ - MaterialButton( - onPressed: () async { - if (nameController.text.isEmpty || nameController.text.trim().isEmpty) { - AppSnackbar.showSnackbarError('Name cannot be empty'); - return; - } - setState(() => isSaving = true); - await updateGivenName(nameController.text); - setState(() => isSaving = false); - AppSnackbar.showSnackbar('Name updated successfully!'); - Navigator.of(context).pop(); - }, - color: Colors.transparent, - elevation: 0, - child: isSaving - ? const Center( - child: Padding( - padding: EdgeInsets.all(8.0), - child: CircularProgressIndicator( - color: Colors.white, - ), - ), - ) - : const Padding( - padding: EdgeInsets.symmetric(horizontal: 4.0), - child: Text( - 'Save', - style: TextStyle(color: Colors.deepPurple, fontWeight: FontWeight.w600, fontSize: 16), - ), - ), - ) - ], - ), - body: Padding( - padding: const EdgeInsets.only(left: 18, right: 18), - child: Column( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - const SizedBox( - height: 30, - ), - TextFormField( - enabled: true, - controller: nameController, - // textCapitalization: TextCapitalization.sentences, - obscureText: false, - // canRequestFocus: true, - textAlign: TextAlign.start, - textAlignVertical: TextAlignVertical.center, - decoration: InputDecoration( - hintText: 'Enter your full name', - hintStyle: const TextStyle(fontSize: 14.0, color: Colors.grey), - floatingLabelBehavior: FloatingLabelBehavior.always, - label: Text( - 'Given Name', - style: TextStyle( - color: Colors.grey.shade200, - fontSize: 16, - ), - ), - border: GradientOutlineInputBorder( - borderRadius: BorderRadius.circular(8), - gradient: const LinearGradient( - colors: [ - Color.fromARGB(255, 202, 201, 201), - Color.fromARGB(255, 159, 158, 158), - ], - ), - ), - ), - style: TextStyle(fontSize: 14.0, color: Colors.grey.shade200), - ), - const SizedBox( - height: 30, - ), - // TextFormField( - // enabled: true, - // obscureText: false, - // textAlign: TextAlign.start, - // textAlignVertical: TextAlignVertical.center, - // readOnly: true, - // initialValue: user?.email, - // decoration: InputDecoration( - // floatingLabelBehavior: FloatingLabelBehavior.always, - // label: Text( - // 'Email Address', - // style: TextStyle( - // color: Colors.grey.shade200, - // fontSize: 16, - // ), - // ), - // border: GradientOutlineInputBorder( - // borderRadius: BorderRadius.circular(8), - // gradient: const LinearGradient( - // colors: [ - // Color.fromARGB(255, 202, 201, 201), - // Color.fromARGB(255, 159, 158, 158), - // ], - // ), - // ), - // ), - // style: TextStyle(fontSize: 14.0, color: Colors.grey.shade200), - // ), - ], - ), - ), - ); - } -} diff --git a/app/lib/pages/settings/profile.dart b/app/lib/pages/settings/profile.dart index 8c776eba6..631325ec6 100644 --- a/app/lib/pages/settings/profile.dart +++ b/app/lib/pages/settings/profile.dart @@ -2,7 +2,7 @@ import 'package:flutter/material.dart'; import 'package:flutter/services.dart'; import 'package:friend_private/backend/preferences.dart'; import 'package:friend_private/pages/facts/page.dart'; -import 'package:friend_private/pages/settings/personal_details.dart'; +import 'package:friend_private/pages/settings/change_name_widget.dart'; import 'package:friend_private/pages/settings/privacy.dart'; import 
'package:friend_private/pages/settings/recordings_storage_permission.dart'; import 'package:friend_private/pages/speech_profile/page.dart'; @@ -66,14 +66,14 @@ class _ProfilePageState extends State { ), subtitle: Text(SharedPreferencesUtil().givenName.isEmpty ? 'Not set' : SharedPreferencesUtil().givenName), trailing: const Icon(Icons.person, size: 20), - onTap: () { - // TODO: change to a dialog + onTap: () async { MixpanelManager().pageOpened('Profile Change Name'); - showModalBottomSheet( - context: context, - builder: (c) { - return const PersonalDetails(); - }); + await showDialog( + context: context, + builder: (BuildContext context) { + return const ChangeNameWidget(); + }, + ).whenComplete(() => setState(() {})); }, ), const SizedBox(height: 24), From a8e4ed8c8afba182a94930d8ae956d21af48c3e6 Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Wed, 18 Sep 2024 22:58:32 +0530 Subject: [PATCH 09/88] show initial messages --- app/lib/pages/chat/widgets/ai_message.dart | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/app/lib/pages/chat/widgets/ai_message.dart b/app/lib/pages/chat/widgets/ai_message.dart index 3afac1cdf..393b89bc8 100644 --- a/app/lib/pages/chat/widgets/ai_message.dart +++ b/app/lib/pages/chat/widgets/ai_message.dart @@ -116,8 +116,8 @@ class _AIMessageState extends State { style: TextStyle(fontSize: 15.0, fontWeight: FontWeight.w500, color: Colors.grey.shade300), )), if (widget.message.id != 1) _getCopyButton(context), // RESTORE ME - // if (message.id == 1 && displayOptions) const SizedBox(height: 8), - // if (message.id == 1 && displayOptions) ..._getInitialOptions(context), + if (widget.displayOptions) const SizedBox(height: 8), + if (widget.displayOptions) ..._getInitialOptions(context), if (messageMemories.isNotEmpty) ...[ const SizedBox(height: 16), for (var data in messageMemories.indexed) ...[ @@ -256,7 +256,7 @@ class _AIMessageState extends State { _getInitialOption(BuildContext context, String optionText) { return GestureDetector( child: Container( - padding: const EdgeInsets.symmetric(horizontal: 12.0, vertical: 8), + padding: const EdgeInsets.symmetric(horizontal: 12.0, vertical: 10), width: double.maxFinite, decoration: BoxDecoration( color: Colors.grey.shade900, @@ -273,11 +273,11 @@ class _AIMessageState extends State { _getInitialOptions(BuildContext context) { return [ const SizedBox(height: 8), - _getInitialOption(context, 'What tasks do I have from yesterday?'), + _getInitialOption(context, 'What\'s been on my mind a lot?'), const SizedBox(height: 8), - _getInitialOption(context, 'What conversations did I have with John?'), + _getInitialOption(context, 'Did I forget to follow up on something?'), const SizedBox(height: 8), - _getInitialOption(context, 'What advise have I received about entrepreneurship?'), + _getInitialOption(context, 'What\'s the funniest thing I\'ve said lately?'), ]; } } From 39947de2b942648fc7916a075acda9ef560628b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Thu, 19 Sep 2024 05:31:00 +0700 Subject: [PATCH 10/88] Centralize logic of postprocess memory to util --- backend/routers/postprocessing.py | 231 +----------------------------- backend/routers/transcribe.py | 2 +- backend/utils/postprocessing.py | 129 +++++++++++++++++ 3 files changed, 136 insertions(+), 226 deletions(-) create mode 100644 backend/utils/postprocessing.py diff --git a/backend/routers/postprocessing.py b/backend/routers/postprocessing.py index bc147ec18..f039dd4e9 
100644 --- a/backend/routers/postprocessing.py +++ b/backend/routers/postprocessing.py @@ -1,22 +1,9 @@ -import asyncio -import os -import threading -import time from fastapi import APIRouter, Depends, HTTPException, UploadFile -from pydub import AudioSegment -import database.memories as memories_db -from database.users import get_user_store_recording_permission from models.memory import * -from routers.memories import _get_memory_by_id -from utils.memories.process_memory import process_memory, process_user_emotion +from utils.postprocessing import postprocess_memory_util from utils.other import endpoints as auth -from utils.other.storage import upload_postprocessing_audio, \ - delete_postprocessing_audio, upload_memory_recording -from utils.stt.pre_recorded import fal_whisperx, fal_postprocessing -from utils.stt.speech_profile import get_speech_profile_matching_predictions -from utils.stt.vad import vad_is_empty router = APIRouter() @@ -39,221 +26,15 @@ def postprocess_memory( TODO: should consider storing non beautified segments, and beautify on read? TODO: post llm process here would be great, sometimes whisper x outputs without punctuation """ - memory_data = _get_memory_by_id(uid, memory_id) - memory = Memory(**memory_data) - if memory.discarded: - print('postprocess_memory: Memory is discarded') - raise HTTPException(status_code=400, detail="Memory is discarded") - - if memory.postprocessing is not None and memory.postprocessing.status != PostProcessingStatus.not_started: - print(f'postprocess_memory: Memory can\'t be post-processed again {memory.postprocessing.status}') - raise HTTPException(status_code=400, detail="Memory can't be post-processed again") + # Save file file_path = f"_temp/{memory_id}_{file.filename}" with open(file_path, 'wb') as f: f.write(file.file.read()) - aseg = AudioSegment.from_wav(file_path) - if aseg.duration_seconds < 10: # TODO: validate duration more accurately, segment.last.end - segment.first.start - 10 - # TODO: fix app, sometimes audio uploaded is wrong, is too short. - print('postprocess_memory: Audio duration is too short, seems wrong.') - memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.canceled) - raise HTTPException(status_code=500, detail="Audio duration is too short, seems wrong.") - - memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.in_progress) - - try: - # Calling VAD to avoid processing empty parts and getting hallucinations from whisper. 
- vad_segments = vad_is_empty(file_path, return_segments=True) - if vad_segments: - start = vad_segments[0]['start'] - end = vad_segments[-1]['end'] - aseg = AudioSegment.from_wav(file_path) - aseg = aseg[max(0, (start - 1) * 1000):min((end + 1) * 1000, aseg.duration_seconds * 1000)] - aseg.export(file_path, format="wav") - except Exception as e: - print(e) - - try: - aseg = AudioSegment.from_wav(file_path) - signed_url = upload_postprocessing_audio(file_path) - threading.Thread(target=_delete_postprocessing_audio, args=(file_path,)).start() - - if aseg.frame_rate == 16000 and get_user_store_recording_permission(uid): - upload_memory_recording(file_path, uid, memory_id) - - speakers_count = len(set([segment.speaker for segment in memory.transcript_segments])) - words = fal_whisperx(signed_url, speakers_count) - fal_segments = fal_postprocessing(words, aseg.duration_seconds) - - # if new transcript is 90% shorter than the original, cancel post-processing, smth wrong with audio or FAL - count = len(''.join([segment.text.strip() for segment in memory.transcript_segments])) - new_count = len(''.join([segment.text.strip() for segment in fal_segments])) - print('Prev characters count:', count, 'New characters count:', new_count) - - fal_failed = not fal_segments or new_count < (count * 0.85) - - if fal_failed: - _handle_segment_embedding_matching(uid, file_path, memory.transcript_segments, aseg) - else: - _handle_segment_embedding_matching(uid, file_path, fal_segments, aseg) - - # Store both models results. - memories_db.store_model_segments_result(uid, memory.id, 'deepgram_streaming', memory.transcript_segments) - memories_db.store_model_segments_result(uid, memory.id, 'fal_whisperx', fal_segments) - - if not fal_failed: - memory.transcript_segments = fal_segments - - memories_db.upsert_memory(uid, memory.dict()) # Store transcript segments at least if smth fails later - if fal_failed: - # TODO: FAL fails too much and is fucking expensive. Remove it. 
- fail_reason = 'FAL empty segments' if not fal_segments else f'FAL transcript too short ({new_count} vs {count})' - memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.failed, fail_reason=fail_reason) - memory.postprocessing = MemoryPostProcessing( - status=PostProcessingStatus.failed, model=PostProcessingModel.fal_whisperx) - # TODO: consider doing process_memory, if any segment still matched to user or people - return memory - - # Reprocess memory with improved transcription - result: Memory = process_memory(uid, memory.language, memory, force_process=True) - - # Process users emotion, async - if emotional_feedback: - asyncio.run(_process_user_emotion(uid, memory.language, memory, [signed_url])) - except Exception as e: - print(e) - memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.failed, fail_reason=str(e)) - raise HTTPException(status_code=500, detail=str(e)) + # Process + status_code, result = postprocess_memory_util(memory_id=memory_id, uid=uid, file_path=file_path, emotional_feedback=emotional_feedback, streaming_model="deepgram_streaming") + if status_code != 200: + raise HTTPException(status_code=status_code, detail=result) - memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.completed) - result.postprocessing = MemoryPostProcessing( - status=PostProcessingStatus.completed, model=PostProcessingModel.fal_whisperx) return result - - -# TODO: Move to util -def postprocess_memory_util(memory_id: str, file_path: str, uid: str, emotional_feedback: bool, streaming_model: str): - """ - The objective of this endpoint, is to get the best possible transcript from the audio file. - Instead of storing the initial deepgram result, doing a full post-processing with whisper-x. - This increases the quality of transcript by at least 20%. - Which also includes a better summarization. - Which helps us create better vectors for the memory. - And improves the overall experience of the user. - - TODO: Try Nvidia Nemo ASR as suggested by @jhonnycombs https://huggingface.co/spaces/hf-audio/open_asr_leaderboard - That + pyannote diarization 3.1, is as good as it gets. Then is only hardware improvements. - TODO: should consider storing non beautified segments, and beautify on read? - TODO: post llm process here would be great, sometimes whisper x outputs without punctuation - """ - memory_data = _get_memory_by_id(uid, memory_id) - memory = Memory(**memory_data) - if memory.discarded: - print('postprocess_memory: Memory is discarded') - return 400, "Memory is discarded" - - if memory.postprocessing is not None and memory.postprocessing.status != PostProcessingStatus.not_started: - print(f'postprocess_memory: Memory can\'t be post-processed again {memory.postprocessing.status}') - return 400, "Memory can't be post-processed again" - - aseg = AudioSegment.from_wav(file_path) - if aseg.duration_seconds < 10: # TODO: validate duration more accurately, segment.last.end - segment.first.start - 10 - # TODO: fix app, sometimes audio uploaded is wrong, is too short. - print('postprocess_memory: Audio duration is too short, seems wrong.') - memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.canceled) - return (500, "Audio duration is too short, seems wrong.") - - memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.in_progress) - - try: - # Calling VAD to avoid processing empty parts and getting hallucinations from whisper. 
- vad_segments = vad_is_empty(file_path, return_segments=True) - if vad_segments: - start = vad_segments[0]['start'] - end = vad_segments[-1]['end'] - aseg = AudioSegment.from_wav(file_path) - aseg = aseg[max(0, (start - 1) * 1000):min((end + 1) * 1000, aseg.duration_seconds * 1000)] - aseg.export(file_path, format="wav") - except Exception as e: - print(e) - - try: - aseg = AudioSegment.from_wav(file_path) - signed_url = upload_postprocessing_audio(file_path) - threading.Thread(target=_delete_postprocessing_audio, args=(file_path,)).start() - - if aseg.frame_rate == 16000 and get_user_store_recording_permission(uid): - upload_memory_recording(file_path, uid, memory_id) - - speakers_count = len(set([segment.speaker for segment in memory.transcript_segments])) - words = fal_whisperx(signed_url, speakers_count) - fal_segments = fal_postprocessing(words, aseg.duration_seconds) - - # if new transcript is 90% shorter than the original, cancel post-processing, smth wrong with audio or FAL - count = len(''.join([segment.text.strip() for segment in memory.transcript_segments])) - new_count = len(''.join([segment.text.strip() for segment in fal_segments])) - print('Prev characters count:', count, 'New characters count:', new_count) - - fal_failed = not fal_segments or new_count < (count * 0.85) - - if fal_failed: - _handle_segment_embedding_matching(uid, file_path, memory.transcript_segments, aseg) - else: - _handle_segment_embedding_matching(uid, file_path, fal_segments, aseg) - - # Store both models results. - memories_db.store_model_segments_result(uid, memory.id, streaming_model, memory.transcript_segments) - memories_db.store_model_segments_result(uid, memory.id, 'fal_whisperx', fal_segments) - - if not fal_failed: - memory.transcript_segments = fal_segments - - memories_db.upsert_memory(uid, memory.dict()) # Store transcript segments at least if smth fails later - if fal_failed: - # TODO: FAL fails too much and is fucking expensive. Remove it. 
- fail_reason = 'FAL empty segments' if not fal_segments else f'FAL transcript too short ({new_count} vs {count})' - memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.failed, fail_reason=fail_reason) - memory.postprocessing = MemoryPostProcessing( - status=PostProcessingStatus.failed, model=PostProcessingModel.fal_whisperx) - # TODO: consider doing process_memory, if any segment still matched to user or people - return (200, memory) - - # Reprocess memory with improved transcription - result: Memory = process_memory(uid, memory.language, memory, force_process=True) - - # Process users emotion, async - if emotional_feedback: - asyncio.run(_process_user_emotion(uid, memory.language, memory, [signed_url])) - except Exception as e: - print(e) - memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.failed, fail_reason=str(e)) - return 500, str(e) - - memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.completed) - result.postprocessing = MemoryPostProcessing( - status=PostProcessingStatus.completed, model=PostProcessingModel.fal_whisperx) - - return 200, result - - -def _delete_postprocessing_audio(file_path): - time.sleep(300) # 5 min - delete_postprocessing_audio(file_path) - os.remove(file_path) - - -async def _process_user_emotion(uid: str, language_code: str, memory: Memory, urls: [str]): - if not any(segment.is_user for segment in memory.transcript_segments): - print(f"_process_user_emotion skipped for {memory.id}") - return - - process_user_emotion(uid, language_code, memory, urls) - - -def _handle_segment_embedding_matching(uid: str, file_path: str, segments: List[TranscriptSegment], aseg: AudioSegment): - if aseg.frame_rate == 16000: - matches = get_speech_profile_matching_predictions(uid, file_path, [s.dict() for s in segments]) - for i, segment in enumerate(segments): - segment.is_user = matches[i]['is_user'] - segment.person_id = matches[i].get('person_id') diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 3d582c5b4..11bdda4a3 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -15,7 +15,7 @@ from models.memory import Memory, TranscriptSegment from models.message_event import NewMemoryCreated, MessageEvent, NewProcessingMemoryCreated from models.processing_memory import ProcessingMemory -from routers.postprocessing import postprocess_memory_util +from utils.postprocessing import postprocess_memory_util from utils.audio import create_wav_from_bytes, merge_wav_files from utils.memories.process_memory import process_memory from utils.other.storage import upload_postprocessing_audio diff --git a/backend/utils/postprocessing.py b/backend/utils/postprocessing.py new file mode 100644 index 000000000..a8a46dade --- /dev/null +++ b/backend/utils/postprocessing.py @@ -0,0 +1,129 @@ +import asyncio +import os +import threading +import time + +from pydub import AudioSegment + +import database.memories as memories_db +from database.users import get_user_store_recording_permission +from models.memory import * +from routers.memories import _get_memory_by_id +from utils.memories.process_memory import process_memory, process_user_emotion +from utils.other.storage import upload_postprocessing_audio, \ + delete_postprocessing_audio, upload_memory_recording +from utils.stt.pre_recorded import fal_whisperx, fal_postprocessing +from utils.stt.speech_profile import get_speech_profile_matching_predictions +from utils.stt.vad import vad_is_empty + +def 
postprocess_memory_util(memory_id: str, file_path: str, uid: str, emotional_feedback: bool, streaming_model: str): + memory_data = _get_memory_by_id(uid, memory_id) + memory = Memory(**memory_data) + if memory.discarded: + print('postprocess_memory: Memory is discarded') + return 400, "Memory is discarded" + + if memory.postprocessing is not None and memory.postprocessing.status != PostProcessingStatus.not_started: + print(f'postprocess_memory: Memory can\'t be post-processed again {memory.postprocessing.status}') + return 400, "Memory can't be post-processed again" + + aseg = AudioSegment.from_wav(file_path) + if aseg.duration_seconds < 10: # TODO: validate duration more accurately, segment.last.end - segment.first.start - 10 + # TODO: fix app, sometimes audio uploaded is wrong, is too short. + print('postprocess_memory: Audio duration is too short, seems wrong.') + memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.canceled) + return (500, "Audio duration is too short, seems wrong.") + + memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.in_progress) + + try: + # Calling VAD to avoid processing empty parts and getting hallucinations from whisper. + vad_segments = vad_is_empty(file_path, return_segments=True) + if vad_segments: + start = vad_segments[0]['start'] + end = vad_segments[-1]['end'] + aseg = AudioSegment.from_wav(file_path) + aseg = aseg[max(0, (start - 1) * 1000):min((end + 1) * 1000, aseg.duration_seconds * 1000)] + aseg.export(file_path, format="wav") + except Exception as e: + print(e) + + try: + aseg = AudioSegment.from_wav(file_path) + signed_url = upload_postprocessing_audio(file_path) + threading.Thread(target=_delete_postprocessing_audio, args=(file_path,)).start() + + if aseg.frame_rate == 16000 and get_user_store_recording_permission(uid): + upload_memory_recording(file_path, uid, memory_id) + + speakers_count = len(set([segment.speaker for segment in memory.transcript_segments])) + words = fal_whisperx(signed_url, speakers_count) + fal_segments = fal_postprocessing(words, aseg.duration_seconds) + + # if new transcript is 90% shorter than the original, cancel post-processing, smth wrong with audio or FAL + count = len(''.join([segment.text.strip() for segment in memory.transcript_segments])) + new_count = len(''.join([segment.text.strip() for segment in fal_segments])) + print('Prev characters count:', count, 'New characters count:', new_count) + + fal_failed = not fal_segments or new_count < (count * 0.85) + + if fal_failed: + _handle_segment_embedding_matching(uid, file_path, memory.transcript_segments, aseg) + else: + _handle_segment_embedding_matching(uid, file_path, fal_segments, aseg) + + # Store both models results. + memories_db.store_model_segments_result(uid, memory.id, streaming_model, memory.transcript_segments) + memories_db.store_model_segments_result(uid, memory.id, 'fal_whisperx', fal_segments) + + if not fal_failed: + memory.transcript_segments = fal_segments + + memories_db.upsert_memory(uid, memory.dict()) # Store transcript segments at least if smth fails later + if fal_failed: + # TODO: FAL fails too much and is fucking expensive. Remove it. 
+ fail_reason = 'FAL empty segments' if not fal_segments else f'FAL transcript too short ({new_count} vs {count})' + memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.failed, fail_reason=fail_reason) + memory.postprocessing = MemoryPostProcessing( + status=PostProcessingStatus.failed, model=PostProcessingModel.fal_whisperx) + # TODO: consider doing process_memory, if any segment still matched to user or people + return (200, memory) + + # Reprocess memory with improved transcription + result: Memory = process_memory(uid, memory.language, memory, force_process=True) + + # Process users emotion, async + if emotional_feedback: + asyncio.run(_process_user_emotion(uid, memory.language, memory, [signed_url])) + except Exception as e: + print(e) + memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.failed, fail_reason=str(e)) + return 500, str(e) + + memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.completed) + result.postprocessing = MemoryPostProcessing( + status=PostProcessingStatus.completed, model=PostProcessingModel.fal_whisperx) + + return 200, result + + +def _delete_postprocessing_audio(file_path): + time.sleep(300) # 5 min + delete_postprocessing_audio(file_path) + os.remove(file_path) + + +async def _process_user_emotion(uid: str, language_code: str, memory: Memory, urls: [str]): + if not any(segment.is_user for segment in memory.transcript_segments): + print(f"_process_user_emotion skipped for {memory.id}") + return + + process_user_emotion(uid, language_code, memory, urls) + + +def _handle_segment_embedding_matching(uid: str, file_path: str, segments: List[TranscriptSegment], aseg: AudioSegment): + if aseg.frame_rate == 16000: + matches = get_speech_profile_matching_predictions(uid, file_path, [s.dict() for s in segments]) + for i, segment in enumerate(segments): + segment.is_user = matches[i]['is_user'] + segment.person_id = matches[i].get('person_id') From 6ac2d5db774107be632c736ab0c3b0b73cd4b0dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Thu, 19 Sep 2024 05:34:41 +0700 Subject: [PATCH 11/88] Clarify returns with tuples on postprocessing memory util --- backend/utils/postprocessing.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/backend/utils/postprocessing.py b/backend/utils/postprocessing.py index a8a46dade..a5ee9b9f9 100644 --- a/backend/utils/postprocessing.py +++ b/backend/utils/postprocessing.py @@ -16,16 +16,17 @@ from utils.stt.speech_profile import get_speech_profile_matching_predictions from utils.stt.vad import vad_is_empty + def postprocess_memory_util(memory_id: str, file_path: str, uid: str, emotional_feedback: bool, streaming_model: str): memory_data = _get_memory_by_id(uid, memory_id) memory = Memory(**memory_data) if memory.discarded: print('postprocess_memory: Memory is discarded') - return 400, "Memory is discarded" + return (400, "Memory is discarded") if memory.postprocessing is not None and memory.postprocessing.status != PostProcessingStatus.not_started: print(f'postprocess_memory: Memory can\'t be post-processed again {memory.postprocessing.status}') - return 400, "Memory can't be post-processed again" + return (400, "Memory can't be post-processed again") aseg = AudioSegment.from_wav(file_path) if aseg.duration_seconds < 10: # TODO: validate duration more accurately, segment.last.end - segment.first.start - 10 @@ -98,13 +99,13 @@ def postprocess_memory_util(memory_id: str, file_path: str, uid: str, emotional_ except Exception as e: 
print(e) memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.failed, fail_reason=str(e)) - return 500, str(e) + return (500, str(e)) memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.completed) result.postprocessing = MemoryPostProcessing( status=PostProcessingStatus.completed, model=PostProcessingModel.fal_whisperx) - return 200, result + return (200, result) def _delete_postprocessing_audio(file_path): From 1c35300354bb3c5103a5315910cd370278405c1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Thu, 19 Sep 2024 05:47:28 +0700 Subject: [PATCH 12/88] Move logic postprocess memory to utils > memories as a sub func --- backend/requirements.txt | 1 + backend/routers/postprocessing.py | 2 +- backend/routers/transcribe.py | 3 +-- .../postprocess_memory.py} | 13 +++++++++++-- 4 files changed, 14 insertions(+), 5 deletions(-) rename backend/utils/{postprocessing.py => memories/postprocess_memory.py} (94%) diff --git a/backend/requirements.txt b/backend/requirements.txt index 91ea5daa1..b985bcdbe 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -226,3 +226,4 @@ wcwidth==0.2.13 websockets==12.0 yarl==1.9.4 pyogg @ git+https://github.com/TeamPyOgg/PyOgg@6871a4f +opuslib==3.0.1 diff --git a/backend/routers/postprocessing.py b/backend/routers/postprocessing.py index f039dd4e9..ef236fc07 100644 --- a/backend/routers/postprocessing.py +++ b/backend/routers/postprocessing.py @@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, UploadFile from models.memory import * -from utils.postprocessing import postprocess_memory_util +from utils.memories.postprocess_memory import postprocess_memory as postprocess_memory_util from utils.other import endpoints as auth router = APIRouter() diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 11bdda4a3..43fa826e1 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -15,7 +15,7 @@ from models.memory import Memory, TranscriptSegment from models.message_event import NewMemoryCreated, MessageEvent, NewProcessingMemoryCreated from models.processing_memory import ProcessingMemory -from utils.postprocessing import postprocess_memory_util +from utils.memories.postprocess_memory import postprocess_memory as postprocess_memory_util from utils.audio import create_wav_from_bytes, merge_wav_files from utils.memories.process_memory import process_memory from utils.other.storage import upload_postprocessing_audio @@ -217,7 +217,6 @@ def stream_audio(audio_buffer): stream_transcript, speech_profile_stream_id, language, uid if include_speech_profile else None ) - except Exception as e: print(f"Initial processing error: {e}") websocket_close_code = 1011 diff --git a/backend/utils/postprocessing.py b/backend/utils/memories/postprocess_memory.py similarity index 94% rename from backend/utils/postprocessing.py rename to backend/utils/memories/postprocess_memory.py index a5ee9b9f9..4cb1670cd 100644 --- a/backend/utils/postprocessing.py +++ b/backend/utils/memories/postprocess_memory.py @@ -8,7 +8,6 @@ import database.memories as memories_db from database.users import get_user_store_recording_permission from models.memory import * -from routers.memories import _get_memory_by_id from utils.memories.process_memory import process_memory, process_user_emotion from utils.other.storage import upload_postprocessing_audio, \ delete_postprocessing_audio, upload_memory_recording @@ -17,8 +16,11 @@ from utils.stt.vad import vad_is_empty -def 
postprocess_memory_util(memory_id: str, file_path: str, uid: str, emotional_feedback: bool, streaming_model: str): +def postprocess_memory(memory_id: str, file_path: str, uid: str, emotional_feedback: bool, streaming_model: str): memory_data = _get_memory_by_id(uid, memory_id) + if not memory_data: + return (404, "Memory not found") + memory = Memory(**memory_data) if memory.discarded: print('postprocess_memory: Memory is discarded') @@ -108,6 +110,13 @@ def postprocess_memory_util(memory_id: str, file_path: str, uid: str, emotional_ return (200, result) +def _get_memory_by_id(uid: str, memory_id: str) -> dict: + memory = memories_db.get_memory(uid, memory_id) + if memory is None or memory.get('deleted', False): + return None + return memory + + def _delete_postprocessing_audio(file_path): time.sleep(300) # 5 min delete_postprocessing_audio(file_path) From 700ebab54898d3379de110f2e722620d3af5f814 Mon Sep 17 00:00:00 2001 From: Joan Cabezas Date: Wed, 18 Sep 2024 16:09:33 -0700 Subject: [PATCH 13/88] fix api url --- app/lib/env/env.dart | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/app/lib/env/env.dart b/app/lib/env/env.dart index d68f2fee4..8a14a8046 100644 --- a/app/lib/env/env.dart +++ b/app/lib/env/env.dart @@ -13,10 +13,10 @@ abstract class Env { static String? get mixpanelProjectToken => _instance.mixpanelProjectToken; - // static String? get apiBaseUrl => _instance.apiBaseUrl; + static String? get apiBaseUrl => _instance.apiBaseUrl; // static String? get apiBaseUrl => 'https://based-hardware-development--backened-dev-api.modal.run/'; - static String? get apiBaseUrl => 'https://camel-lucky-reliably.ngrok-free.app/'; + // static String? get apiBaseUrl => 'https://camel-lucky-reliably.ngrok-free.app/'; // static String? get apiBaseUrl => 'https://mutual-fun-boar.ngrok-free.app/'; static String? get growthbookApiKey => _instance.growthbookApiKey; From 96c53bff1ad02bbcde131db8d032dedbc472833d Mon Sep 17 00:00:00 2001 From: Joan Cabezas Date: Wed, 18 Sep 2024 17:16:16 -0700 Subject: [PATCH 14/88] minor fix rag doesn't use deleted vectors --- app/lib/pages/memory_capturing/page.dart | 2 +- backend/database/memories.py | 7 ++++++- backend/database/vector_db.py | 6 ++++-- backend/routers/memories.py | 1 + 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/app/lib/pages/memory_capturing/page.dart b/app/lib/pages/memory_capturing/page.dart index f9a1968e2..bb5ca9d19 100644 --- a/app/lib/pages/memory_capturing/page.dart +++ b/app/lib/pages/memory_capturing/page.dart @@ -77,7 +77,7 @@ class _MemoryCapturingPageState extends State with TickerPr const SizedBox(width: 4), const Text("🎙️"), const SizedBox(width: 4), - const Expanded(child: Text("In Progress")), + const Expanded(child: Text("In progress")), ], ), ), diff --git a/backend/database/memories.py b/backend/database/memories.py index f12ca4257..d5cf60e25 100644 --- a/backend/database/memories.py +++ b/backend/database/memories.py @@ -56,11 +56,13 @@ def delete_memory(uid, memory_id): def filter_memories_by_date(uid, start_date, end_date): + # TODO: check utc comparison or not? 
user_ref = db.collection('users').document(uid) query = ( user_ref.collection('memories') .where(filter=FieldFilter('created_at', '>=', start_date)) .where(filter=FieldFilter('created_at', '<=', end_date)) + .where(filter=FieldFilter('deleted', '==', False)) .where(filter=FieldFilter('discarded', '==', False)) .order_by('created_at', direction=firestore.Query.DESCENDING) ) @@ -85,7 +87,10 @@ def get_memories_by_id(uid, memory_ids): memories = [] for doc in docs: if doc.exists: - memories.append(doc.to_dict()) + data = doc.to_dict() + if data.get('deleted') or data.get('discarded'): + continue + memories.append(data) return memories diff --git a/backend/database/vector_db.py b/backend/database/vector_db.py index 26bf92157..eb18ae053 100644 --- a/backend/database/vector_db.py +++ b/backend/database/vector_db.py @@ -42,7 +42,7 @@ def upsert_vectors( print('upsert_vectors', res) -def query_vectors(query: str, uid: str, starts_at: int = None, ends_at: int = None, k:int = 5) -> List[str]: +def query_vectors(query: str, uid: str, starts_at: int = None, ends_at: int = None, k: int = 5) -> List[str]: filter_data = {'uid': uid} if starts_at is not None: filter_data['created_at'] = {'$gte': starts_at, '$lte': ends_at} @@ -55,4 +55,6 @@ def query_vectors(query: str, uid: str, starts_at: int = None, ends_at: int = No def delete_vector(memory_id: str): - index.delete(ids=[memory_id], namespace="ns1") + # TODO: does this work? + result = index.delete(ids=[memory_id], namespace="ns1") + print('delete_vector', result) diff --git a/backend/routers/memories.py b/backend/routers/memories.py index 8599ac194..5d7913819 100644 --- a/backend/routers/memories.py +++ b/backend/routers/memories.py @@ -113,6 +113,7 @@ def get_memory_transcripts_by_models(memory_id: str, uid: str = Depends(auth.get @router.delete("/v1/memories/{memory_id}", status_code=204, tags=['memories']) def delete_memory(memory_id: str, uid: str = Depends(auth.get_current_user_uid)): + print('delete_memory', memory_id, uid) memories_db.delete_memory(uid, memory_id) delete_vector(memory_id) return {"status": "Ok"} From 59072b157b97620fd925e5dda710348d7529d3a8 Mon Sep 17 00:00:00 2001 From: Joan Cabezas Date: Wed, 18 Sep 2024 17:19:38 -0700 Subject: [PATCH 15/88] minor copy change --- app/lib/pages/onboarding/name/name_widget.dart | 2 +- app/lib/pages/settings/change_name_widget.dart | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/app/lib/pages/onboarding/name/name_widget.dart b/app/lib/pages/onboarding/name/name_widget.dart index f8b3b7db8..73f6691ac 100644 --- a/app/lib/pages/onboarding/name/name_widget.dart +++ b/app/lib/pages/onboarding/name/name_widget.dart @@ -47,7 +47,7 @@ class _NameWidgetState extends State { textAlign: TextAlign.center, textAlignVertical: TextAlignVertical.center, decoration: InputDecoration( - hintText: 'Enter your given name', + hintText: 'How Omi should call you?', // label: const Text('What should Omi call you?'), hintStyle: const TextStyle(fontSize: 14, color: Colors.grey), // border: UnderlineInputBorder( diff --git a/app/lib/pages/settings/change_name_widget.dart b/app/lib/pages/settings/change_name_widget.dart index 02aa856f3..41060c291 100644 --- a/app/lib/pages/settings/change_name_widget.dart +++ b/app/lib/pages/settings/change_name_widget.dart @@ -34,7 +34,7 @@ class _ChangeNameWidgetState extends State { padding: const EdgeInsets.all(8.0), child: Column( children: [ - const Text('Enter your given name'), + const Text('How Omi should call you?'), const SizedBox(height: 8), 
CupertinoTextField( controller: nameController, @@ -75,7 +75,7 @@ class _ChangeNameWidgetState extends State { child: Column( mainAxisSize: MainAxisSize.min, children: [ - const Text('Enter your given name'), + const Text('How Omi should call you?'), const SizedBox(height: 8), TextField( controller: nameController, From 2eff627f824b231faf932694c810c8478e493f5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Wed, 18 Sep 2024 16:49:43 +0700 Subject: [PATCH 16/88] Recall pure socket --- app/lib/utils/pure_socket.dart | 387 +++++++++++++++++++++++++++++++++ 1 file changed, 387 insertions(+) create mode 100644 app/lib/utils/pure_socket.dart diff --git a/app/lib/utils/pure_socket.dart b/app/lib/utils/pure_socket.dart new file mode 100644 index 000000000..4f6536afc --- /dev/null +++ b/app/lib/utils/pure_socket.dart @@ -0,0 +1,387 @@ +import 'dart:async'; +import 'dart:convert'; +import 'dart:math'; + +import 'package:flutter/material.dart'; +import 'package:friend_private/backend/preferences.dart'; +import 'package:friend_private/backend/schema/message_event.dart'; +import 'package:friend_private/backend/schema/transcript_segment.dart'; +import 'package:friend_private/env/env.dart'; +import 'package:friend_private/services/notification_service.dart'; +import 'package:instabug_flutter/instabug_flutter.dart'; +import 'package:internet_connection_checker_plus/internet_connection_checker_plus.dart'; +import 'package:web_socket_channel/io.dart'; +import 'package:web_socket_channel/status.dart' as status; +import 'package:web_socket_channel/web_socket_channel.dart'; + +enum PureSocketStatus { notConnected, connecting, connected, disconnected } + +abstract class IPureSocketListener { + void onMessage(dynamic message); + void onClosed(); + void onError(Object err, StackTrace trace); + + void onInternetConnectionFailed() {} + + void onMaxRetriesReach() {} +} + +abstract class IPureSocket { + Future connect(); + void disconnect(); + void send(dynamic message); + + void onInternetSatusChanged(InternetStatus status); + + void onMessage(dynamic message); + void onClosed(); + void onError(Object err, StackTrace trace); +} + +class PureSocketMessage { + String? raw; +} + +class PureCore { + late InternetConnection internetConnection; + + factory PureCore() => _instance; + + /// The singleton instance of [PureCore]. + static final _instance = PureCore.createInstance(); + + PureCore.createInstance() { + internetConnection = InternetConnection.createInstance( + customCheckOptions: [ + InternetCheckOption( + uri: Uri.parse(Env.apiBaseUrl!), + timeout: const Duration( + seconds: 30, + )), + ], + ); + } +} + +class PureSocket implements IPureSocket { + StreamSubscription? _internetStatusListener; + InternetStatus? _internetStatus; + Timer? _internetLostDelayTimer; + + WebSocketChannel? _channel; + PureSocketStatus _status = PureSocketStatus.notConnected; + IPureSocketListener? 
_listener; + + int _retries = 0; + + String url; + + PureSocket(this.url) { + _internetStatusListener = PureCore().internetConnection.onStatusChange.listen((InternetStatus status) { + onInternetSatusChanged(status); + }); + } + + WebSocketChannel get channel { + if (_channel == null) { + throw Exception('Socket is not connected'); + } + return _channel!; + } + + void setListener(IPureSocketListener listener) { + _listener = listener; + } + + @override + Future connect() async { + return await _connect(); + } + + Future _connect() async { + if (_status == PureSocketStatus.connecting || _status == PureSocketStatus.connected) { + return false; + } + + _channel = IOWebSocketChannel.connect( + url, + pingInterval: const Duration(seconds: 10), + connectTimeout: const Duration(seconds: 30), + ); + if (_channel?.ready == null) { + return false; + } + + _status = PureSocketStatus.connecting; + await _channel?.ready; + _status = PureSocketStatus.connected; + _retries = 0; + + final that = this; + + _channel?.stream.listen( + (message) { + that.onMessage(message); + }, + onError: (err, trace) { + that.onError(err, trace); + }, + onDone: () { + that.onClosed(); + }, + cancelOnError: true, + ); + + return true; + } + + @override + void disconnect() { + _status = PureSocketStatus.disconnected; + _cleanUp(); + } + + Future _cleanUp() async { + _internetLostDelayTimer?.cancel(); + _internetStatusListener?.cancel(); + await _channel?.sink.close(status.goingAway); + } + + @override + void onClosed() { + _status = PureSocketStatus.disconnected; + debugPrint("Socket closed"); + _listener?.onClosed(); + } + + @override + void onError(Object err, StackTrace trace) { + _status = PureSocketStatus.disconnected; + print("Error: ${err}"); + debugPrintStack(stackTrace: trace); + + _listener?.onError(err, trace); + + CrashReporting.reportHandledCrash(err, trace, level: NonFatalExceptionLevel.error); + } + + @override + void onMessage(dynamic message) { + debugPrint("[Socket] Message $message"); + _listener?.onMessage(message); + } + + @override + void send(message) { + _channel?.sink.add(message); + } + + void _reconnect() async { + const int initialBackoffTimeMs = 1000; // 1 second + const double multiplier = 1.5; + const int maxRetries = 7; + + if (_status == PureSocketStatus.connecting || _status == PureSocketStatus.connected) { + debugPrint("[Socket] Can not reconnect, because socket is $_status"); + return; + } + + await _cleanUp(); + + var ok = await _connect(); + if (ok) { + return; + } + + // retry + int waitInMilliseconds = pow(multiplier, _retries).toInt() * initialBackoffTimeMs; + await Future.delayed(Duration(milliseconds: waitInMilliseconds)); + _retries++; + if (_retries >= maxRetries) { + debugPrint("[Socket] Reach max retries $maxRetries"); + _listener?.onMaxRetriesReach(); + return; + } + _reconnect(); + } + + @override + void onInternetSatusChanged(InternetStatus status) { + _internetStatus = status; + switch (status) { + case InternetStatus.connected: + if (_status == PureSocketStatus.connected || _status == PureSocketStatus.connecting) { + return; + } + _reconnect(); + break; + case InternetStatus.disconnected: + var that = this; + _internetLostDelayTimer?.cancel(); + _internetLostDelayTimer = Timer(const Duration(seconds: 60), () { + if (_internetStatus != InternetStatus.disconnected) { + return; + } + + that.disconnect(); + _listener?.onInternetConnectionFailed(); + }); + + break; + } + } +} + +abstract interface class ITransctipSegmentSocketServiceListener { + void 
onMessageEventReceived(ServerMessageEvent event); + void onSegmentReceived(List segments); + void onError(Object err); + void onClosed(); +} + +class TranscripSegmentSocketService implements IPureSocketListener { + late PureSocket _socket; + final Map _listeners = {}; + + int sampleRate; + String codec; + bool includeSpeechProfile; + + factory TranscripSegmentSocketService() { + if (_instance == null) { + throw Exception("TranscripSegmentSocketService is not initiated"); + } + + return _instance!; + } + + /// The singleton instance of [TranscripSegmentSocketService]. + static TranscripSegmentSocketService? _instance; + + TranscripSegmentSocketService.create( + this.sampleRate, + this.codec, + this.includeSpeechProfile, + ) { + final recordingsLanguage = SharedPreferencesUtil().recordingsLanguage; + var params = + '?language=$recordingsLanguage&sample_rate=$sampleRate&codec=$codec&uid=${SharedPreferencesUtil().uid}&include_speech_profile=$includeSpeechProfile'; + String url = '${Env.apiBaseUrl!.replaceAll('https', 'wss')}listen$params'; + + _socket = PureSocket(url); + _socket.setListener(this); + } + + TranscripSegmentSocketService.createInstance( + this.sampleRate, + this.codec, + this.includeSpeechProfile, + ) { + _instance = TranscripSegmentSocketService.createInstance(sampleRate, codec, includeSpeechProfile); + } + + void subscribe(Object context, ITransctipSegmentSocketServiceListener listener) { + if (_listeners.containsKey(context.hashCode)) { + _listeners.remove(context.hashCode); + } + _listeners.putIfAbsent(context.hashCode, () => listener); + } + + void unsubscribe(Object context) { + if (_listeners.containsKey(context.hashCode)) { + _listeners.remove(context.hashCode); + } + } + + void start() { + _socket.connect(); + } + + void stop() { + _socket.disconnect(); + _listeners.clear(); + } + + @override + void onClosed() { + _listeners.forEach((k, v) { + v.onClosed(); + }); + } + + @override + void onError(Object err, StackTrace trace) { + _listeners.forEach((k, v) { + v.onError(err); + }); + } + + @override + void onMessage(event) { + if (event == 'ping') return; + + // Decode json + dynamic jsonEvent; + try { + jsonEvent = jsonDecode(event); + } on FormatException catch (e) { + debugPrint(e.toString()); + } + if (jsonEvent == null) { + debugPrint("Can not decode message event json $event"); + return; + } + + // Transcript segments + if (jsonEvent is List) { + var segments = jsonEvent; + if (segments.isNotEmpty) { + return; + } + _listeners.forEach((k, v) { + v.onSegmentReceived(segments.map((e) => TranscriptSegment.fromJson(e)).toList()); + }); + return; + } + + debugPrint(event); + + // Message event + if (jsonEvent.containsKey("type")) { + var event = ServerMessageEvent.fromJson(jsonEvent); + _listeners.forEach((k, v) { + v.onMessageEventReceived(event); + }); + return; + } + + debugPrint(event.toString()); + } + + @override + void onInternetConnectionFailed() { + debugPrint("onInternetConnectionFailed"); + + // Send notification + NotificationService.instance.clearNotification(3); + NotificationService.instance.createNotification( + notificationId: 3, + title: 'Internet Connection Lost', + body: 'Your device is offline. 
Transcription is paused until connection is restored.', + ); + } + + @override + void onMaxRetriesReach() { + debugPrint("onMaxRetriesReach"); + + // Send notification + NotificationService.instance.clearNotification(2); + NotificationService.instance.createNotification( + notificationId: 2, + title: 'Connection Issue 🚨', + body: 'Unable to connect to the transcript service.' + ' Please restart the app or contact support if the problem persists.', + ); + } +} From e5ea7a1b23ab27df3fad810a70366949d7267750 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Thu, 19 Sep 2024 05:09:45 +0700 Subject: [PATCH 17/88] Integrate capture provider with pure socket --- app/lib/backend/schema/bt_device.dart | 12 +- app/lib/main.dart | 2 +- app/lib/providers/capture_provider.dart | 262 ++++++++++++------------ app/lib/services/services.dart | 5 + app/lib/services/sockets.dart | 62 ++++++ app/lib/utils/pure_socket.dart | 98 +++++---- 6 files changed, 267 insertions(+), 174 deletions(-) create mode 100644 app/lib/services/sockets.dart diff --git a/app/lib/backend/schema/bt_device.dart b/app/lib/backend/schema/bt_device.dart index b14fb7f7e..999518f0a 100644 --- a/app/lib/backend/schema/bt_device.dart +++ b/app/lib/backend/schema/bt_device.dart @@ -3,7 +3,17 @@ import 'package:friend_private/services/device_connections.dart'; import 'package:friend_private/services/frame_connection.dart'; import 'package:friend_private/utils/ble/gatt_utils.dart'; -enum BleAudioCodec { pcm16, pcm8, mulaw16, mulaw8, opus, unknown } +enum BleAudioCodec { + pcm16, + pcm8, + mulaw16, + mulaw8, + opus, + unknown; + + @override + String toString() => mapCodecToName(this); +} String mapCodecToName(BleAudioCodec codec) { switch (codec) { diff --git a/app/lib/main.dart b/app/lib/main.dart index 47daaf1a5..3aedcc952 100644 --- a/app/lib/main.dart +++ b/app/lib/main.dart @@ -168,7 +168,7 @@ class _MyAppState extends State with WidgetsBindingObserver { ChangeNotifierProxyProvider3( create: (context) => CaptureProvider(), update: (BuildContext context, memory, message, wsProvider, CaptureProvider? previous) => - (previous?..updateProviderInstances(memory, message, wsProvider)) ?? CaptureProvider(), + (previous?..updateProviderInstances(memory, message)) ?? 
CaptureProvider(), ), ChangeNotifierProxyProvider2( create: (context) => DeviceProvider(), diff --git a/app/lib/providers/capture_provider.dart b/app/lib/providers/capture_provider.dart index 8d286d29f..4a2bca967 100644 --- a/app/lib/providers/capture_provider.dart +++ b/app/lib/providers/capture_provider.dart @@ -20,7 +20,6 @@ import 'package:friend_private/backend/schema/transcript_segment.dart'; import 'package:friend_private/pages/capture/logic/openglass_mixin.dart'; import 'package:friend_private/providers/memory_provider.dart'; import 'package:friend_private/providers/message_provider.dart'; -import 'package:friend_private/providers/websocket_provider.dart'; import 'package:friend_private/services/services.dart'; import 'package:friend_private/utils/analytics/growthbook.dart'; import 'package:friend_private/utils/analytics/mixpanel.dart'; @@ -31,19 +30,20 @@ import 'package:friend_private/utils/features/calendar.dart'; import 'package:friend_private/utils/logger.dart'; import 'package:friend_private/utils/memories/integrations.dart'; import 'package:friend_private/utils/memories/process.dart'; -import 'package:friend_private/utils/websockets.dart'; +import 'package:friend_private/utils/pure_socket.dart'; import 'package:permission_handler/permission_handler.dart'; import 'package:uuid/uuid.dart'; -class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifierMixin { +class CaptureProvider extends ChangeNotifier + with OpenGlassMixin, MessageNotifierMixin + implements ITransctipSegmentSocketServiceListener { MemoryProvider? memoryProvider; MessageProvider? messageProvider; - WebSocketProvider? webSocketProvider; + TranscripSegmentSocketService? _socket; - void updateProviderInstances(MemoryProvider? mp, MessageProvider? p, WebSocketProvider? wsProvider) { + void updateProviderInstances(MemoryProvider? mp, MessageProvider? p) { memoryProvider = mp; messageProvider = p; - webSocketProvider = wsProvider; notifyListeners(); } @@ -99,7 +99,7 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie } void setMemoryCreating(bool value) { - print('set memory creating ${value}'); + debugPrint('set memory creating ${value}'); memoryCreating = value; notifyListeners(); } @@ -143,7 +143,7 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie emotionalFeedback: GrowthbookUtil().isOmiFeedbackEnabled(), ); if (result?.result == null) { - print("Can not update processing memory, result null"); + debugPrint("Can not update processing memory, result null"); } } @@ -154,7 +154,7 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie Future _onMemoryCreated(ServerMessageEvent event) async { if (event.memory == null) { - print("Memory is not found, processing memory ${event.processingMemoryId}"); + debugPrint("Memory is not found, processing memory ${event.processingMemoryId}"); return; } _processOnMemoryCreated(event.memory, event.messages ?? 
[]); @@ -167,7 +167,7 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie Future _onMemoryPostProcessSuccess(String memoryId) async { var memory = await getMemoryById(memoryId); if (memory == null) { - print("Memory is not found $memoryId"); + debugPrint("Memory is not found $memoryId"); return; } @@ -177,7 +177,7 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie Future _onMemoryPostProcessFailed(String memoryId) async { var memory = await getMemoryById(memoryId); if (memory == null) { - print("Memory is not found $memoryId"); + debugPrint("Memory is not found $memoryId"); return; } @@ -310,7 +310,7 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie // Create new socket session // Warn: should have a better solution to keep the socket alived - await webSocketProvider?.closeWebSocketWithoutReconnect('reset new memory session'); + await _socket?.stop(reason: 'reset new memory session'); await initiateWebsocket(); } @@ -338,124 +338,30 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie BleAudioCodec? audioCodec, int? sampleRate, ]) async { - // setWebSocketConnecting(true); - print('initiateWebsocket in capture_provider'); + debugPrint('initiateWebsocket in capture_provider'); + BleAudioCodec codec = audioCodec ?? SharedPreferencesUtil().deviceCodec; sampleRate ??= (codec == BleAudioCodec.opus ? 16000 : 8000); - print('is ws null: ${webSocketProvider == null}'); - await webSocketProvider?.initWebSocket( - codec: codec, - sampleRate: sampleRate, - includeSpeechProfile: true, - newMemoryWatch: true, - // Warn: need clarify about initiateWebsocket - onConnectionSuccess: () { - print('inside onConnectionSuccess'); - if (segments.isNotEmpty) { - // means that it was a reconnection, so we need to reset - streamStartedAtSecond = null; - secondsMissedOnReconnect = (DateTime.now().difference(firstStreamReceivedAt!).inSeconds); - } - print('bottom in onConnectionSuccess'); - notifyListeners(); - }, - onConnectionFailed: (err) { - print('inside onConnectionFailed'); - print('err: $err'); - notifyListeners(); - }, - onConnectionClosed: (int? closeCode, String? closeReason) { - print('inside onConnectionClosed'); - print('closeCode: $closeCode'); - // connection was closed, either on resetState, or by backend, or by some other reason. - // setState(() {}); - }, - onConnectionError: (err) { - print('inside onConnectionError'); - print('err: $err'); - // connection was okay, but then failed. 
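// Editor's note (worked example, not part of this patch): the callback wiring being
// removed here is replaced by PureSocket in pure_socket.dart above, whose _reconnect()
// waits pow(1.5, retries).toInt() * 1000 ms between attempts, capped at 7 retries.
// Because toInt() truncates before multiplying, the schedule is roughly
// 1s, 1s, 2s, 3s, 5s, 7s, 11s (about 30s total) rather than a smooth 1.5x curve.
import 'dart:math';

void main() {
  const initialBackoffTimeMs = 1000;
  const multiplier = 1.5;
  const maxRetries = 7;
  for (var retries = 0; retries < maxRetries; retries++) {
    // Same formula as _reconnect() in pure_socket.dart.
    final waitMs = pow(multiplier, retries).toInt() * initialBackoffTimeMs;
    print('attempt ${retries + 1}: wait ${waitMs}ms');
  }
}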
- notifyListeners(); - }, - onMessageEventReceived: (ServerMessageEvent event) { - if (event.type == MessageEventType.newMemoryCreating) { - _onMemoryCreating(); - return; - } - if (event.type == MessageEventType.newMemoryCreated) { - _onMemoryCreated(event); - return; - } + debugPrint('is ws null: ${_socket == null}'); - if (event.type == MessageEventType.newMemoryCreateFailed) { - _onMemoryCreateFailed(); - return; - } - - if (event.type == MessageEventType.newProcessingMemoryCreated) { - if (event.processingMemoryId == null) { - print("New processing memory created message event is invalid"); - return; - } - _onProcessingMemoryCreated(event.processingMemoryId!); - return; - } - - if (event.type == MessageEventType.memoryPostProcessingSuccess) { - if (event.memoryId == null) { - print("Post proccess message event is invalid"); - return; - } - _onMemoryPostProcessSuccess(event.memoryId!); - return; - } - - if (event.type == MessageEventType.memoryPostProcessingFailed) { - if (event.memoryId == null) { - print("Post proccess message event is invalid"); - return; - } - _onMemoryPostProcessFailed(event.memoryId!); - return; - } - }, - onMessageReceived: (List newSegments) { - if (newSegments.isEmpty) return; - - if (segments.isEmpty) { - debugPrint('newSegments: ${newSegments.last}'); - // TODO: small bug -> when memory A creates, and memory B starts, memory B will clean a lot more seconds than available, - // losing from the audio the first part of the recording. All other parts are fine. - FlutterForegroundTask.sendDataToTask(jsonEncode({'location': true})); - var currentSeconds = (audioStorage?.frames.length ?? 0) ~/ 100; - var removeUpToSecond = newSegments[0].start.toInt(); - audioStorage?.removeFramesRange(fromSecond: 0, toSecond: min(max(currentSeconds - 5, 0), removeUpToSecond)); - firstStreamReceivedAt = DateTime.now(); - } + // TODO: thinh, socket + _socket = await ServiceManager.instance().socket.memory(codec: codec, sampleRate: sampleRate); + if (_socket == null) { + throw Exception("Can not create new memory socket"); + } - streamStartedAtSecond ??= newSegments[0].start; - TranscriptSegment.combineSegments( - segments, - newSegments, - toRemoveSeconds: streamStartedAtSecond ?? 0, - toAddSeconds: secondsMissedOnReconnect ?? 0, - ); - triggerTranscriptSegmentReceivedEvents(newSegments, conversationId, sendMessageToChat: (v) { - messageProvider?.addMessage(v); - }); - - debugPrint('Memory creation timer restarted'); - _memoryCreationTimer?.cancel(); - _memoryCreationTimer = - Timer(const Duration(seconds: quietSecondsForMemoryCreation), () => _createPhotoCharacteristicMemory()); - setHasTranscripts(true); - notifyListeners(); - }, - ); + // Ok + _socket?.subscribe(this, this); + if (segments.isNotEmpty) { + // means that it was a reconnection, so we need to reset + streamStartedAtSecond = null; + secondsMissedOnReconnect = (DateTime.now().difference(firstStreamReceivedAt!).inSeconds); + } } Future streamAudioToWs(String id, BleAudioCodec codec) async { - print('streamAudioToWs in capture_provider'); + debugPrint('streamAudioToWs in capture_provider'); audioStorage = WavBytesUtil(codec: codec); if (_bleBytesStream != null) { _bleBytesStream?.cancel(); @@ -468,8 +374,8 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie final trimmedValue = value.sublist(3); // TODO: if this (0,3) is not removed, deepgram can't seem to be able to detect the audio. 
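// Editor's note: the first three bytes of each BLE packet appear to be a
// firmware-added header (packet counter / frame index rather than audio data),
// which is presumably why Deepgram cannot parse the stream unless they are
// stripped before forwarding, as done with sublist(3) above.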
// https://developers.deepgram.com/docs/determining-your-audio-format-for-live-streaming-audio - if (webSocketProvider?.wsConnectionState == WebsocketConnectionStatus.connected) { - webSocketProvider?.websocketChannel?.sink.add(trimmedValue); + if (_socket?.state == SocketServiceState.connected) { + _socket?.send(trimmedValue); } }, ); @@ -544,7 +450,8 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie Future resetForSpeechProfile() async { closeBleStream(); - await webSocketProvider?.closeWebSocketWithoutReconnect('reset for speech profile'); + // TODO: thinh, socket, check why we need reset for speech profile here + await _socket?.stop(reason: 'reset for speech profile'); setAudioBytesConnected(false); notifyListeners(); } @@ -667,8 +574,8 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie } Future _manageWebSocketConnection(bool codecChanged, bool isFromSpeechProfile) async { - if (codecChanged || webSocketProvider?.wsConnectionState != WebsocketConnectionStatus.connected) { - await webSocketProvider?.closeWebSocketWithoutReconnect('reset state $isFromSpeechProfile'); + if (codecChanged || _socket?.state != SocketServiceState.connected) { + await _socket?.stop(reason: 'reset state $isFromSpeechProfile'); // if (!isFromSpeechProfile) { await initiateWebsocket(); // } @@ -676,7 +583,7 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie } Future initiateFriendAudioStreaming(bool isFromSpeechProfile) async { - print('connectedDevice: $connectedDevice in initiateFriendAudioStreaming'); + debugPrint('connectedDevice: $connectedDevice in initiateFriendAudioStreaming'); if (connectedDevice == null) return; BleAudioCodec codec = await _getAudioCodec(connectedDevice!.id); @@ -715,7 +622,7 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie isGlasses = await _hasPhotoStreamingCharacteristic(connectedDevice!.id); if (!isGlasses) return; await openGlassProcessing(connectedDevice!, (p) {}, setHasTranscripts); - webSocketProvider?.closeWebSocketWithoutReconnect('reset state open glass'); + _socket?.stop(reason: 'reset state open glass'); notifyListeners(); } @@ -733,6 +640,7 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie void dispose() { _bleBytesStream?.cancel(); _memoryCreationTimer?.cancel(); + _socket?.unsubscribe(this); super.dispose(); } @@ -746,8 +654,8 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie // record await ServiceManager.instance().mic.start(onByteReceived: (bytes) { - if (webSocketProvider?.wsConnectionState == WebsocketConnectionStatus.connected) { - webSocketProvider?.websocketChannel?.sink.add(bytes); + if (_socket?.state == SocketServiceState.connected) { + _socket?.send(bytes); } }, onRecording: () { updateRecordingState(RecordingState.record); @@ -761,4 +669,96 @@ class CaptureProvider extends ChangeNotifier with OpenGlassMixin, MessageNotifie stopStreamRecording() { ServiceManager.instance().mic.stop(); } + + // Socket handling + + @override + void onClosed() { + debugPrint('socket is closed'); + } + + @override + void onError(Object err) { + debugPrint('err: $err'); + notifyListeners(); + } + + @override + void onMessageEventReceived(ServerMessageEvent event) { + if (event.type == MessageEventType.newMemoryCreating) { + _onMemoryCreating(); + return; + } + + if (event.type == MessageEventType.newMemoryCreated) { + _onMemoryCreated(event); + return; + } + + if (event.type == 
MessageEventType.newMemoryCreateFailed) { + _onMemoryCreateFailed(); + return; + } + + if (event.type == MessageEventType.newProcessingMemoryCreated) { + if (event.processingMemoryId == null) { + debugPrint("New processing memory created message event is invalid"); + return; + } + _onProcessingMemoryCreated(event.processingMemoryId!); + return; + } + + if (event.type == MessageEventType.memoryPostProcessingSuccess) { + if (event.memoryId == null) { + debugPrint("Post proccess message event is invalid"); + return; + } + _onMemoryPostProcessSuccess(event.memoryId!); + return; + } + + if (event.type == MessageEventType.memoryPostProcessingFailed) { + if (event.memoryId == null) { + debugPrint("Post proccess message event is invalid"); + return; + } + _onMemoryPostProcessFailed(event.memoryId!); + return; + } + } + + @override + void onSegmentReceived(List newSegments) { + if (newSegments.isEmpty) return; + + if (segments.isEmpty) { + debugPrint('newSegments: ${newSegments.last}'); + // TODO: small bug -> when memory A creates, and memory B starts, memory B will clean a lot more seconds than available, + // losing from the audio the first part of the recording. All other parts are fine. + FlutterForegroundTask.sendDataToTask(jsonEncode({'location': true})); + var currentSeconds = (audioStorage?.frames.length ?? 0) ~/ 100; + var removeUpToSecond = newSegments[0].start.toInt(); + audioStorage?.removeFramesRange(fromSecond: 0, toSecond: min(max(currentSeconds - 5, 0), removeUpToSecond)); + firstStreamReceivedAt = DateTime.now(); + } + + streamStartedAtSecond ??= newSegments[0].start; + TranscriptSegment.combineSegments( + segments, + newSegments, + toRemoveSeconds: streamStartedAtSecond ?? 0, + toAddSeconds: secondsMissedOnReconnect ?? 0, + ); + triggerTranscriptSegmentReceivedEvents(newSegments, conversationId, sendMessageToChat: (v) { + messageProvider?.addMessage(v); + }); + + debugPrint('Memory creation timer restarted'); + _memoryCreationTimer?.cancel(); + _memoryCreationTimer = + Timer(const Duration(seconds: quietSecondsForMemoryCreation), () => _createPhotoCharacteristicMemory()); + setHasTranscripts(true); + notifyListeners(); + } } diff --git a/app/lib/services/services.dart b/app/lib/services/services.dart index dc489b15e..127fbeb3d 100644 --- a/app/lib/services/services.dart +++ b/app/lib/services/services.dart @@ -6,10 +6,12 @@ import 'package:flutter/material.dart'; import 'package:flutter_background_service/flutter_background_service.dart'; import 'package:flutter_sound/flutter_sound.dart'; import 'package:friend_private/services/devices.dart'; +import 'package:friend_private/services/sockets.dart'; class ServiceManager { late IMicRecorderService _mic; late IDeviceService _device; + late ISocketService _socket; static ServiceManager? 
_instance; @@ -19,6 +21,7 @@ class ServiceManager { runner: BackgroundService(), ); sm._device = DeviceService(); + sm._socket = SocketServicePool(); return sm; } @@ -35,6 +38,8 @@ class ServiceManager { IDeviceService get device => _device; + ISocketService get socket => _socket; + static void init() { if (_instance != null) { throw Exception("Service manager is initiated"); diff --git a/app/lib/services/sockets.dart b/app/lib/services/sockets.dart new file mode 100644 index 000000000..0fffecf57 --- /dev/null +++ b/app/lib/services/sockets.dart @@ -0,0 +1,62 @@ +import 'package:friend_private/backend/schema/bt_device.dart'; +import 'package:friend_private/utils/pure_socket.dart'; + +abstract class ISocketService { + void start(); + void stop(); + + Future memory({required BleAudioCodec codec, required int sampleRate, bool force = false}); + TranscripSegmentSocketService speechProfile(); +} + +abstract interface class ISocketServiceSubsciption {} + +class SocketServicePool extends ISocketService { + TranscripSegmentSocketService? _memory; + TranscripSegmentSocketService? _speechProfile; + + @override + void start() { + // TODO: implement start + } + + @override + void stop() async { + await _memory?.stop(); + await _speechProfile?.stop(); + } + + // Warn: Should use a better solution to prevent race conditions + bool memoryMutex = false; + @override + Future memory( + {required BleAudioCodec codec, required int sampleRate, bool force = false}) async { + while (memoryMutex) { + await Future.delayed(const Duration(milliseconds: 50)); + } + memoryMutex = true; + + if (!force && + _memory?.codec == codec && + _memory?.sampleRate == sampleRate && + _memory?.state == SocketServiceState.connected) { + return _memory; + } + + // new socket + await _memory?.stop(); + + _memory = MemoryTranscripSegmentSocketService.create(sampleRate, codec); + await _memory?.start(); + + memoryMutex = false; + + return _memory; + } + + @override + TranscripSegmentSocketService speechProfile() { + // TODO: implement speechProfile + throw UnimplementedError(); + } +} diff --git a/app/lib/utils/pure_socket.dart b/app/lib/utils/pure_socket.dart index 4f6536afc..70dbe3a8e 100644 --- a/app/lib/utils/pure_socket.dart +++ b/app/lib/utils/pure_socket.dart @@ -4,6 +4,7 @@ import 'dart:math'; import 'package:flutter/material.dart'; import 'package:friend_private/backend/preferences.dart'; +import 'package:friend_private/backend/schema/bt_device.dart'; import 'package:friend_private/backend/schema/message_event.dart'; import 'package:friend_private/backend/schema/transcript_segment.dart'; import 'package:friend_private/env/env.dart'; @@ -11,7 +12,7 @@ import 'package:friend_private/services/notification_service.dart'; import 'package:instabug_flutter/instabug_flutter.dart'; import 'package:internet_connection_checker_plus/internet_connection_checker_plus.dart'; import 'package:web_socket_channel/io.dart'; -import 'package:web_socket_channel/status.dart' as status; +import 'package:web_socket_channel/status.dart' as socket_channel_status; import 'package:web_socket_channel/web_socket_channel.dart'; enum PureSocketStatus { notConnected, connecting, connected, disconnected } @@ -28,7 +29,7 @@ abstract class IPureSocketListener { abstract class IPureSocket { Future connect(); - void disconnect(); + Future disconnect(); void send(dynamic message); void onInternetSatusChanged(InternetStatus status); @@ -69,7 +70,16 @@ class PureSocket implements IPureSocket { Timer? _internetLostDelayTimer; WebSocketChannel? 
_channel; + WebSocketChannel get channel { + if (_channel == null) { + throw Exception('Socket is not connected'); + } + return _channel!; + } + PureSocketStatus _status = PureSocketStatus.notConnected; + PureSocketStatus get status => _status; + IPureSocketListener? _listener; int _retries = 0; @@ -82,13 +92,6 @@ class PureSocket implements IPureSocket { }); } - WebSocketChannel get channel { - if (_channel == null) { - throw Exception('Socket is not connected'); - } - return _channel!; - } - void setListener(IPureSocketListener listener) { _listener = listener; } @@ -136,15 +139,15 @@ class PureSocket implements IPureSocket { } @override - void disconnect() { + Future disconnect() async { _status = PureSocketStatus.disconnected; - _cleanUp(); + await _cleanUp(); } Future _cleanUp() async { _internetLostDelayTimer?.cancel(); _internetStatusListener?.cancel(); - await _channel?.sink.close(status.goingAway); + await _channel?.sink.close(socket_channel_status.goingAway); } @override @@ -218,12 +221,12 @@ class PureSocket implements IPureSocket { case InternetStatus.disconnected: var that = this; _internetLostDelayTimer?.cancel(); - _internetLostDelayTimer = Timer(const Duration(seconds: 60), () { + _internetLostDelayTimer = Timer(const Duration(seconds: 60), () async { if (_internetStatus != InternetStatus.disconnected) { return; } - that.disconnect(); + await that.disconnect(); _listener?.onInternetConnectionFailed(); }); @@ -239,30 +242,39 @@ abstract interface class ITransctipSegmentSocketServiceListener { void onClosed(); } +class SpeechProfileTranscripSegmentSocketService extends TranscripSegmentSocketService { + SpeechProfileTranscripSegmentSocketService.create(super.sampleRate, super.codec) + : super.create(includeSpeechProfile: false, newMemoryWatch: false); +} + +class MemoryTranscripSegmentSocketService extends TranscripSegmentSocketService { + MemoryTranscripSegmentSocketService.create(super.sampleRate, super.codec) + : super.create(includeSpeechProfile: true, newMemoryWatch: true); +} + +enum SocketServiceState { + connected, + disconnected, +} + class TranscripSegmentSocketService implements IPureSocketListener { late PureSocket _socket; final Map _listeners = {}; + SocketServiceState get state => + _socket.status == PureSocketStatus.connected ? SocketServiceState.connected : SocketServiceState.disconnected; + int sampleRate; - String codec; + BleAudioCodec codec; bool includeSpeechProfile; - - factory TranscripSegmentSocketService() { - if (_instance == null) { - throw Exception("TranscripSegmentSocketService is not initiated"); - } - - return _instance!; - } - - /// The singleton instance of [TranscripSegmentSocketService]. - static TranscripSegmentSocketService? 
_instance; + bool newMemoryWatch; TranscripSegmentSocketService.create( this.sampleRate, - this.codec, - this.includeSpeechProfile, - ) { + this.codec, { + this.includeSpeechProfile = false, + this.newMemoryWatch = true, + }) { final recordingsLanguage = SharedPreferencesUtil().recordingsLanguage; var params = '?language=$recordingsLanguage&sample_rate=$sampleRate&codec=$codec&uid=${SharedPreferencesUtil().uid}&include_speech_profile=$includeSpeechProfile'; @@ -272,14 +284,6 @@ class TranscripSegmentSocketService implements IPureSocketListener { _socket.setListener(this); } - TranscripSegmentSocketService.createInstance( - this.sampleRate, - this.codec, - this.includeSpeechProfile, - ) { - _instance = TranscripSegmentSocketService.createInstance(sampleRate, codec, includeSpeechProfile); - } - void subscribe(Object context, ITransctipSegmentSocketServiceListener listener) { if (_listeners.containsKey(context.hashCode)) { _listeners.remove(context.hashCode); @@ -293,13 +297,25 @@ class TranscripSegmentSocketService implements IPureSocketListener { } } - void start() { - _socket.connect(); + Future start() async { + bool ok = await _socket.connect(); + if (!ok) { + debugPrint("Can not connect to websocket"); + } } - void stop() { - _socket.disconnect(); + Future stop({String? reason}) async { + await _socket.disconnect(); _listeners.clear(); + + if (reason != null) { + debugPrint(reason); + } + } + + Future send(dynamic message) async { + _socket.send(message); + return; } @override From c0f87f1b575cb4639e8acd075629fdda6304fd9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Thu, 19 Sep 2024 11:53:00 +0700 Subject: [PATCH 18/88] Integrate capture provider with pure socket #2 - no speech profile, no tight chains on providers --- app/lib/main.dart | 2 +- app/lib/pages/capture/_page.dart | 265 +----------------- app/lib/pages/capture/widgets/widgets.dart | 3 +- app/lib/pages/home/page.dart | 3 +- app/lib/pages/memories/widgets/capture.dart | 19 +- .../memories/widgets/processing_capture.dart | 8 +- app/lib/providers/capture_provider.dart | 101 ++++--- app/lib/providers/device_provider.dart | 25 +- .../providers/speech_profile_provider.dart | 5 +- app/lib/services/device_connections.dart | 21 +- app/lib/services/sockets.dart | 37 ++- app/lib/utils/pure_socket.dart | 32 ++- 12 files changed, 142 insertions(+), 379 deletions(-) diff --git a/app/lib/main.dart b/app/lib/main.dart index 3aedcc952..4e0069187 100644 --- a/app/lib/main.dart +++ b/app/lib/main.dart @@ -173,7 +173,7 @@ class _MyAppState extends State with WidgetsBindingObserver { ChangeNotifierProxyProvider2( create: (context) => DeviceProvider(), update: (BuildContext context, captureProvider, wsProvider, DeviceProvider? previous) => - (previous?..setProviders(captureProvider, wsProvider)) ?? DeviceProvider(), + (previous?..setProviders(captureProvider)) ?? 
DeviceProvider(), ), ChangeNotifierProxyProvider( create: (context) => OnboardingProvider(), diff --git a/app/lib/pages/capture/_page.dart b/app/lib/pages/capture/_page.dart index b392425d0..75d10b227 100644 --- a/app/lib/pages/capture/_page.dart +++ b/app/lib/pages/capture/_page.dart @@ -1,22 +1,6 @@ import 'package:flutter/material.dart'; -import 'package:flutter/scheduler.dart'; -import 'package:flutter_foreground_task/flutter_foreground_task.dart'; -import 'package:flutter_provider_utilities/flutter_provider_utilities.dart'; -import 'package:friend_private/backend/schema/bt_device.dart'; -import 'package:friend_private/backend/schema/geolocation.dart'; -import 'package:friend_private/pages/capture/widgets/widgets.dart'; -import 'package:friend_private/providers/capture_provider.dart'; -import 'package:friend_private/providers/connectivity_provider.dart'; -import 'package:friend_private/providers/device_provider.dart'; -import 'package:friend_private/providers/onboarding_provider.dart'; -import 'package:friend_private/utils/audio/wav_bytes.dart'; -import 'package:friend_private/utils/ble/communication.dart'; -import 'package:friend_private/utils/enums.dart'; -import 'package:friend_private/widgets/dialog.dart'; -import 'package:provider/provider.dart'; - -import '../../providers/websocket_provider.dart'; +@Deprecated("Capture page is deprecated, use @pages > memories > widgets > capture instead.") class CapturePage extends StatefulWidget { const CapturePage({ super.key, @@ -26,252 +10,9 @@ class CapturePage extends StatefulWidget { State createState() => CapturePageState(); } -class CapturePageState extends State with AutomaticKeepAliveClientMixin, WidgetsBindingObserver { - @override - bool get wantKeepAlive => true; - - /// ---- - - // List segments = List.filled(100, '') - // .mapIndexed((i, e) => TranscriptSegment( - // text: - // '''[00:00:00 - 00:02:23] Speaker 0: The tech giants already know these techniques. - // My goal is to unlock their secrets for the benefit of businesses who to design and help users develop healthy habits. - // To that end, there's so much I wanted to put in this book that just didn't fit. Before you reading, please take a moment to download these - // supplementary materials included free with the purchase of this audiobook. Please go to nirandfar.com forward slash hooked. - // Near is spelled like my first name, speck, n I r. Andfar.com/hooked. There you will find the hooked model workbook, an ebook of case studies, - // and a free email course about product psychology. Also, if you'd like to connect with me, you can reach me through my blog at nirafar.com. - // You can schedule office hours to discuss your questions. Look forward to hearing from you as you build habits for good. - // - // Introduction. 79% of smartphone owners check their device within 15 minutes of waking up every morning. Perhaps most startling, - // fully 1 third of Americans say they would rather give up sex than lose their cell phones. A 2011 university study suggested people check their - // phones 34 times per day. However, industry insiders believe that number is closer to an astounding 150 daily sessions. We are hooked. - // It's the poll to visit YouTube, Facebook, or Twitter for just a few minutes only to find yourself still capping and scrolling an hour later. - // It's the urge you likely feel throughout your day but hardly notice. Cognitive psychologists define habits as, quote, automatic behaviors triggered - // by situational cues. 
Things we do with little or no conscious thought. The products and services we use habitually alter our everyday behavior. - // Just as their designers intended. Our actions have been engineered. How do companies producing little more than bits of code displayed on a screen - // seemingly control users' minds? What makes some products so habit forming? Forming habit is imperative for the survival of many products. - // - // As infinite distractions compete for our attention, companies are learning to master novel tactics that stay relevant in users' minds. - // Amassing millions of users is no longer good enough. Companies increasingly find that their economic value is a function of the strength of the habits they create. - // - // In order to win the loyalty of their users and create a product that's regularly used, companies must learn not only what compels users to click, - // but also what makes them click. Although some companies are just waking up to this new reality, others are already cashing in. By mastering habit - // forming product design, companies profiles in this book make their goods indispensable. First to mind wins. Companies that form strong user habits enjoy - // several benefits to their bottom line. These companies attach their product to internal triggers. A result, users show up without any external prompting. - // Instead of relying on expensive marketing, how did forming companies link their services to users' daily routines and emotions. - // A habit is at work when users feel a tad bored and instantly open Twitter. Feel a hang of loneliness, and before rational thought occurs, - // they're scrolling through their Facebook feeds.''', - // speaker: 'SPEAKER_0${i % 2}', - // isUser: false, - // start: 0, - // end: 10, - // )) - // .toList(); - - setHasTranscripts(bool hasTranscripts) { - context.read().setHasTranscripts(hasTranscripts); - } - - void _onReceiveTaskData(dynamic data) { - if (data is Map) { - if (data.containsKey('latitude') && data.containsKey('longitude')) { - context.read().setGeolocation(Geolocation( - latitude: data['latitude'], - longitude: data['longitude'], - accuracy: data['accuracy'], - altitude: data['altitude'], - time: DateTime.parse(data['time']), - )); - } else { - if (mounted) { - context.read().setGeolocation(null); - } - } - } - } - - @override - void initState() { - WavBytesUtil.clearTempWavFiles(); - - FlutterForegroundTask.addTaskDataCallback(_onReceiveTaskData); - WidgetsBinding.instance.addObserver(this); - SchedulerBinding.instance.addPostFrameCallback((_) async { - // await context.read().processCachedTranscript(); - if (context.read().connectedDevice != null) { - context.read().stopFindDeviceTimer(); - } - // if (await LocationService().displayPermissionsDialog()) { - // await showDialog( - // context: context, - // builder: (c) => getDialog( - // context, - // () => Navigator.of(context).pop(), - // () async { - // await requestLocationPermission(); - // await LocationService().requestBackgroundPermission(); - // if (mounted) Navigator.of(context).pop(); - // }, - // 'Enable Location? 🌍', - // 'Allow location access to tag your memories. 
Set to "Always Allow" in Settings', - // singleButton: false, - // okButtonText: 'Continue', - // ), - // ); - // } - final connectivityProvider = Provider.of(context, listen: false); - if (!connectivityProvider.isConnected) { - context.read().cancelMemoryCreationTimer(); - } - }); - - super.initState(); - } - - @override - void dispose() { - WidgetsBinding.instance.removeObserver(this); - // context.read().closeWebSocket(); - super.dispose(); - } - - // Future requestLocationPermission() async { - // LocationService locationService = LocationService(); - // bool serviceEnabled = await locationService.enableService(); - // if (!serviceEnabled) { - // debugPrint('Location service not enabled'); - // if (mounted) { - // ScaffoldMessenger.of(context).showSnackBar( - // const SnackBar( - // content: Text( - // 'Location services are disabled. Enable them for a better experience.', - // style: TextStyle(color: Colors.white, fontSize: 14), - // ), - // ), - // ); - // } - // } else { - // PermissionStatus permissionGranted = await locationService.requestPermission(); - // SharedPreferencesUtil().locationEnabled = permissionGranted == PermissionStatus.granted; - // MixpanelManager().setUserProperty('Location Enabled', SharedPreferencesUtil().locationEnabled); - // if (permissionGranted == PermissionStatus.denied) { - // debugPrint('Location permission not granted'); - // } else if (permissionGranted == PermissionStatus.deniedForever) { - // debugPrint('Location permission denied forever'); - // if (mounted) { - // ScaffoldMessenger.of(context).showSnackBar( - // const SnackBar( - // content: Text( - // 'If you change your mind, you can enable location services in your device settings.', - // style: TextStyle(color: Colors.white, fontSize: 14), - // ), - // ), - // ); - // } - // } - // } - // } - +class CapturePageState extends State { @override Widget build(BuildContext context) { - super.build(context); - return Consumer2(builder: (context, provider, deviceProvider, child) { - return MessageListener( - showInfo: (info) { - // This probably will never be called because this has been handled even before we start the audio stream. But it's here just in case. - if (info == 'FIM_CHANGE') { - showDialog( - context: context, - barrierDismissible: false, - builder: (c) => getDialog( - context, - () async { - context.read().closeWebSocketWithoutReconnect('Firmware change detected'); - var connectedDevice = deviceProvider.connectedDevice; - var codec = await getAudioCodec(connectedDevice!.id); - context.read().resetState(restartBytesProcessing: true); - context.read().initiateWebsocket(codec); - if (Navigator.canPop(context)) { - Navigator.pop(context); - } - }, - () => {}, - 'Firmware change detected!', - 'You are currently using a different firmware version than the one you were using before. 
Please restart the app to apply the changes.', - singleButton: true, - okButtonText: 'Restart', - ), - ); - } - }, - showError: (error) { - ScaffoldMessenger.of(context).showSnackBar( - SnackBar( - content: Text( - error, - style: const TextStyle(color: Colors.white, fontSize: 14), - ), - ), - ); - }, - child: Stack( - children: [ - ListView(children: [ - SpeechProfileCardWidget(), - ...getConnectionStateWidgets( - context, - provider.hasTranscripts, - deviceProvider.connectedDevice, - context.read().wsConnectionState, - ), - getTranscriptWidget( - provider.memoryCreating, - provider.segments, - provider.photos, - deviceProvider.connectedDevice, - ), - ...connectionStatusWidgets( - context, - provider.segments, - context.read().wsConnectionState, - ), - const SizedBox(height: 16) - ]), - getPhoneMicRecordingButton(() => _recordingToggled(provider), provider.recordingState), - ], - ), - ); - }); - } - - _recordingToggled(CaptureProvider provider) async { - var recordingState = provider.recordingState; - if (recordingState == RecordingState.record) { - provider.stopStreamRecording(); - provider.updateRecordingState(RecordingState.stop); - context.read().cancelMemoryCreationTimer(); - // await context.read().tryCreateMemoryManually(); - } else if (recordingState == RecordingState.initialising) { - debugPrint('initialising, have to wait'); - } else { - showDialog( - context: context, - builder: (c) => getDialog( - context, - () => Navigator.pop(context), - () async { - provider.updateRecordingState(RecordingState.initialising); - context.read().closeWebSocketWithoutReconnect('Recording with phone mic'); - await provider.initiateWebsocket(BleAudioCodec.pcm16, 16000); - await provider.streamRecording(); - Navigator.pop(context); - }, - 'Limited Capabilities', - 'Recording with your phone microphone has a few limitations, including but not limited to: speaker profiles, background reliability.', - okButtonText: 'Ok, I understand', - ), - ); - } + return const Text("Depreacted"); } } diff --git a/app/lib/pages/capture/widgets/widgets.dart b/app/lib/pages/capture/widgets/widgets.dart index 0cdf0da9f..55f0f162c 100644 --- a/app/lib/pages/capture/widgets/widgets.dart +++ b/app/lib/pages/capture/widgets/widgets.dart @@ -228,7 +228,8 @@ class SpeechProfileCardWidget extends StatelessWidget { if (hasSpeakerProfile != SharedPreferencesUtil().hasSpeakerProfile) { if (context.mounted) { // TODO: is the websocket restarting once the user comes back? 
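// Editor's note (sketch of one possible follow-up for the TODO above, assuming the
// CaptureProvider API introduced later in this series): once the user returns from the
// speech-profile flow, the socket can be rebuilt with the current settings by forcing a
// new session through the provider instead of the removed restartWebSocket(), e.g.:
//
//   if (hasSpeakerProfile != SharedPreferencesUtil().hasSpeakerProfile && context.mounted) {
//     await context.read<CaptureProvider>().changeAudioRecordProfile();
//   }
//
// This is an illustrative suggestion only; the patch itself just comments the call out.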
- context.read().restartWebSocket(); + // TODO: thinh, socket speech profile + // context.read().restartWebSocket(); } } }, diff --git a/app/lib/pages/home/page.dart b/app/lib/pages/home/page.dart index 81655b6d1..f92ff0f28 100644 --- a/app/lib/pages/home/page.dart +++ b/app/lib/pages/home/page.dart @@ -544,7 +544,8 @@ class _HomePageState extends State with WidgetsBindingObserver, Ticker if (language != SharedPreferencesUtil().recordingsLanguage || hasSpeech != SharedPreferencesUtil().hasSpeakerProfile || transcriptModel != SharedPreferencesUtil().transcriptionModel) { - context.read().restartWebSocket(); + // TODO: thinh, socket speech profile + // context.read().restartWebSocket(); } }, ), diff --git a/app/lib/pages/memories/widgets/capture.dart b/app/lib/pages/memories/widgets/capture.dart index d94468dbc..f81497050 100644 --- a/app/lib/pages/memories/widgets/capture.dart +++ b/app/lib/pages/memories/widgets/capture.dart @@ -34,13 +34,15 @@ class LiteCaptureWidgetState extends State void _onReceiveTaskData(dynamic data) { if (data is Map) { if (data.containsKey('latitude') && data.containsKey('longitude')) { - context.read().setGeolocation(Geolocation( - latitude: data['latitude'], - longitude: data['longitude'], - accuracy: data['accuracy'], - altitude: data['altitude'], - time: DateTime.parse(data['time']), - )); + if (mounted) { + context.read().setGeolocation(Geolocation( + latitude: data['latitude'], + longitude: data['longitude'], + accuracy: data['accuracy'], + altitude: data['altitude'], + time: DateTime.parse(data['time']), + )); + } } else { if (mounted) { context.read().setGeolocation(null); @@ -139,8 +141,7 @@ class LiteCaptureWidgetState extends State context.read().closeWebSocketWithoutReconnect('Firmware change detected'); var connectedDevice = deviceProvider.connectedDevice; var codec = await _getAudioCodec(connectedDevice!.id); - context.read().resetState(restartBytesProcessing: true); - context.read().initiateWebsocket(codec); + await context.read().changeAudioRecordProfile(codec); if (Navigator.canPop(context)) { Navigator.pop(context); } diff --git a/app/lib/pages/memories/widgets/processing_capture.dart b/app/lib/pages/memories/widgets/processing_capture.dart index 2654fad4b..3142af71f 100644 --- a/app/lib/pages/memories/widgets/processing_capture.dart +++ b/app/lib/pages/memories/widgets/processing_capture.dart @@ -7,7 +7,6 @@ import 'package:friend_private/pages/memory_capturing/page.dart'; import 'package:friend_private/providers/capture_provider.dart'; import 'package:friend_private/providers/connectivity_provider.dart'; import 'package:friend_private/providers/device_provider.dart'; -import 'package:friend_private/providers/websocket_provider.dart'; import 'package:friend_private/utils/analytics/mixpanel.dart'; import 'package:friend_private/utils/enums.dart'; import 'package:friend_private/utils/other/temp.dart'; @@ -80,7 +79,7 @@ class _MemoryCaptureWidgetState extends State { _toggleRecording(BuildContext context, CaptureProvider provider) async { var recordingState = provider.recordingState; if (recordingState == RecordingState.record) { - provider.stopStreamRecording(); + await provider.stopStreamRecording(); context.read().cancelMemoryCreationTimer(); await context.read().createMemory(); MixpanelManager().phoneMicRecordingStopped(); @@ -95,8 +94,9 @@ class _MemoryCaptureWidgetState extends State { () async { Navigator.pop(context); provider.updateRecordingState(RecordingState.initialising); - context.read().closeWebSocketWithoutReconnect('Recording with 
phone mic'); - await provider.initiateWebsocket(BleAudioCodec.pcm16, 16000); + // TODO: thinh, socket check why we need to close socket provider here, disable temporary + //context.read().closeWebSocketWithoutReconnect('Recording with phone mic'); + await provider.changeAudioRecordProfile(BleAudioCodec.pcm16, 16000); await provider.streamRecording(); MixpanelManager().phoneMicRecordingStarted(); }, diff --git a/app/lib/providers/capture_provider.dart b/app/lib/providers/capture_provider.dart index 4a2bca967..9306e8805 100644 --- a/app/lib/providers/capture_provider.dart +++ b/app/lib/providers/capture_provider.dart @@ -20,6 +20,7 @@ import 'package:friend_private/backend/schema/transcript_segment.dart'; import 'package:friend_private/pages/capture/logic/openglass_mixin.dart'; import 'package:friend_private/providers/memory_provider.dart'; import 'package:friend_private/providers/message_provider.dart'; +import 'package:friend_private/services/devices.dart'; import 'package:friend_private/services/services.dart'; import 'package:friend_private/utils/analytics/growthbook.dart'; import 'package:friend_private/utils/analytics/mixpanel.dart'; @@ -85,14 +86,8 @@ class CaptureProvider extends ChangeNotifier String? processingMemoryId; - bool resetStateAlreadyCalled = false; String dateTimeStorageString = ""; - void setResetStateAlreadyCalled(bool value) { - resetStateAlreadyCalled = value; - notifyListeners(); - } - void setHasTranscripts(bool value) { hasTranscripts = value; notifyListeners(); @@ -120,7 +115,7 @@ class CaptureProvider extends ChangeNotifier notifyListeners(); } - void updateConnectedDevice(BTDeviceStruct? device) { + void _updateConnectedDevice(BTDeviceStruct? device) { debugPrint('connected device changed from ${connectedDevice?.id} to ${device?.id}'); connectedDevice = device; notifyListeners(); @@ -294,7 +289,7 @@ class CaptureProvider extends ChangeNotifier return true; } - void _cleanNew() async { + Future _clean() async { segments = []; audioStorage?.clearAudioBytes(); @@ -307,11 +302,15 @@ class CaptureProvider extends ChangeNotifier photos = []; conversationId = const Uuid().v4(); processingMemoryId = null; + } + + Future _cleanNew() async { + _clean(); // Create new socket session // Warn: should have a better solution to keep the socket alived - await _socket?.stop(reason: 'reset new memory session'); - await initiateWebsocket(); + debugPrint("_cleanNew"); + await _initiateWebsocket(force: true); } _handleCalendarCreation(ServerMemory memory) { @@ -334,10 +333,20 @@ class CaptureProvider extends ChangeNotifier } } - Future initiateWebsocket([ + Future changeAudioRecordProfile([ BleAudioCodec? audioCodec, int? sampleRate, ]) async { + debugPrint("changeAudioRecordProfile"); + await _resetState(restartBytesProcessing: true); + await _initiateWebsocket(audioCodec: audioCodec, sampleRate: sampleRate); + } + + Future _initiateWebsocket({ + BleAudioCodec? audioCodec, + int? sampleRate, + bool force = false, + }) async { debugPrint('initiateWebsocket in capture_provider'); BleAudioCodec codec = audioCodec ?? 
SharedPreferencesUtil().deviceCodec; @@ -346,7 +355,7 @@ class CaptureProvider extends ChangeNotifier debugPrint('is ws null: ${_socket == null}'); // TODO: thinh, socket - _socket = await ServiceManager.instance().socket.memory(codec: codec, sampleRate: sampleRate); + _socket = await ServiceManager.instance().socket.memory(codec: codec, sampleRate: sampleRate, force: force); if (_socket == null) { throw Exception("Can not create new memory socket"); } @@ -456,35 +465,20 @@ class CaptureProvider extends ChangeNotifier notifyListeners(); } - Future resetState({ + Future _resetState({ bool restartBytesProcessing = true, - bool isFromSpeechProfile = false, - BTDeviceStruct? btDevice, }) async { - if (resetStateAlreadyCalled) { - debugPrint('resetState already called'); - return; - } - setResetStateAlreadyCalled(true); - debugPrint('resetState: restartBytesProcessing=$restartBytesProcessing, isFromSpeechProfile=$isFromSpeechProfile'); + debugPrint('resetState: restartBytesProcessing=$restartBytesProcessing'); _cleanupCurrentState(); await startOpenGlass(); - if (!isFromSpeechProfile) { - await _handleMemoryCreation(restartBytesProcessing); - } + await _handleMemoryCreation(restartBytesProcessing); - bool codecChanged = await _checkCodecChange(); - - if (restartBytesProcessing || codecChanged) { - await _manageWebSocketConnection(codecChanged, isFromSpeechProfile); - } + await _ensureSocketConnection(force: true); - await initiateFriendAudioStreaming(isFromSpeechProfile); + await _initiateFriendAudioStreaming(); // TODO: Commenting this for now as DevKit 2 is not yet used in production // await initiateStorageBytesStreaming(); - - setResetStateAlreadyCalled(false); notifyListeners(); } @@ -573,16 +567,16 @@ class CaptureProvider extends ChangeNotifier return false; } - Future _manageWebSocketConnection(bool codecChanged, bool isFromSpeechProfile) async { - if (codecChanged || _socket?.state != SocketServiceState.connected) { - await _socket?.stop(reason: 'reset state $isFromSpeechProfile'); - // if (!isFromSpeechProfile) { - await initiateWebsocket(); - // } + Future _ensureSocketConnection({bool force = false}) async { + debugPrint("_ensureSocketConnection"); + var codec = SharedPreferencesUtil().deviceCodec; + if (force || (codec != _socket?.codec || _socket?.state != SocketServiceState.connected)) { + await _socket?.stop(reason: 'reset state, force $force'); + await _initiateWebsocket(force: force); } } - Future initiateFriendAudioStreaming(bool isFromSpeechProfile) async { + Future _initiateFriendAudioStreaming() async { debugPrint('connectedDevice: $connectedDevice in initiateFriendAudioStreaming'); if (connectedDevice == null) return; @@ -591,7 +585,7 @@ class CaptureProvider extends ChangeNotifier debugPrint('Device codec changed from ${SharedPreferencesUtil().deviceCodec} to $codec'); SharedPreferencesUtil().deviceCodec = codec; notifyInfo('FIM_CHANGE'); - await _manageWebSocketConnection(true, isFromSpeechProfile); + await _ensureSocketConnection(); } // Why is the connectedDevice null at this point? @@ -670,11 +664,36 @@ class CaptureProvider extends ChangeNotifier ServiceManager.instance().mic.stop(); } + Future streamDeviceRecording({ + BTDeviceStruct? 
btDevice, + bool restartBytesProcessing = true, + }) async { + debugPrint("streamDeviceRecording ${btDevice} ${restartBytesProcessing}"); + if (btDevice != null) { + _updateConnectedDevice(btDevice); + } + await _resetState( + restartBytesProcessing: restartBytesProcessing, + ); + } + + Future stopStreamDeviceRecording() async { + _updateConnectedDevice(null); + await _resetState(); + } + // Socket handling @override void onClosed() { - debugPrint('socket is closed'); + debugPrint('[Provider] Socket is closed'); + + _clean(); + + // Notify + setMemoryCreating(false); + setHasTranscripts(false); + notifyListeners(); } @override diff --git a/app/lib/providers/device_provider.dart b/app/lib/providers/device_provider.dart index 0494f8d59..634f13e32 100644 --- a/app/lib/providers/device_provider.dart +++ b/app/lib/providers/device_provider.dart @@ -3,7 +3,6 @@ import 'dart:async'; import 'package:flutter/material.dart'; import 'package:friend_private/backend/preferences.dart'; import 'package:friend_private/backend/schema/bt_device.dart'; -import 'package:friend_private/providers/websocket_provider.dart'; import 'package:friend_private/providers/capture_provider.dart'; import 'package:friend_private/services/devices.dart'; import 'package:friend_private/services/notification_service.dart'; @@ -13,7 +12,6 @@ import 'package:instabug_flutter/instabug_flutter.dart'; class DeviceProvider extends ChangeNotifier implements IDeviceServiceSubsciption { CaptureProvider? captureProvider; - WebSocketProvider? webSocketProvider; bool isConnecting = false; bool isConnected = false; @@ -25,16 +23,14 @@ class DeviceProvider extends ChangeNotifier implements IDeviceServiceSubsciption Timer? _disconnectNotificationTimer; - void setProviders(CaptureProvider provider, WebSocketProvider wsProvider) { + void setProviders(CaptureProvider provider) { captureProvider = provider; - webSocketProvider = wsProvider; notifyListeners(); } void setConnectedDevice(BTDeviceStruct? 
device) { connectedDevice = device; print('setConnectedDevice: $device'); - captureProvider?.updateConnectedDevice(device); notifyListeners(); } @@ -149,25 +145,10 @@ class DeviceProvider extends ChangeNotifier implements IDeviceServiceSubsciption if (isConnected) { await initiateBleBatteryListener(); } - await captureProvider?.resetState(restartBytesProcessing: true, btDevice: connectedDevice); - // if (captureProvider?.webSocketConnected == false) { - // restartWebSocket(); - // } notifyListeners(); } - Future restartWebSocket() async { - debugPrint('restartWebSocket'); - - await webSocketProvider?.closeWebSocketWithoutReconnect('Restarting WebSocket'); - if (connectedDevice == null) { - return; - } - await captureProvider?.resetState(restartBytesProcessing: true); - notifyListeners(); - } - void updateConnectingStatus(bool value) { isConnecting = value; notifyListeners(); @@ -197,7 +178,7 @@ class DeviceProvider extends ChangeNotifier implements IDeviceServiceSubsciption setConnectedDevice(null); setIsConnected(false); updateConnectingStatus(false); - await captureProvider?.resetState(restartBytesProcessing: false); + await captureProvider?.stopStreamDeviceRecording(); captureProvider?.setAudioBytesConnected(false); print('after resetState inside initiateConnectionListener'); @@ -224,7 +205,7 @@ class DeviceProvider extends ChangeNotifier implements IDeviceServiceSubsciption setConnectedDevice(device); setIsConnected(true); updateConnectingStatus(false); - await captureProvider?.resetState(restartBytesProcessing: true, btDevice: connectedDevice); + await captureProvider?.streamDeviceRecording(restartBytesProcessing: true, btDevice: connectedDevice!); // initiateBleBatteryListener(); // The device is still disconnected for some reason if (connectedDevice != null) { diff --git a/app/lib/providers/speech_profile_provider.dart b/app/lib/providers/speech_profile_provider.dart index 7e2acaf5c..ecacc5946 100644 --- a/app/lib/providers/speech_profile_provider.dart +++ b/app/lib/providers/speech_profile_provider.dart @@ -204,7 +204,7 @@ class SpeechProfileProvider extends ChangeNotifier with MessageNotifierMixin imp await createMemory(); captureProvider?.clearTranscripts(); } - await captureProvider?.resetState(restartBytesProcessing: true); + await captureProvider?.streamDeviceRecording(restartBytesProcessing: true); uploadingProfile = false; profileCompleted = true; text = ''; @@ -297,7 +297,8 @@ class SpeechProfileProvider extends ChangeNotifier with MessageNotifierMixin imp uploadingProfile = false; profileCompleted = false; await webSocketProvider?.closeWebSocketWithoutReconnect('closing'); - await captureProvider?.resetState(restartBytesProcessing: true, isFromSpeechProfile: true); + // TODO: thinh, socket check why? disable temporary + // await captureProvider?.resetState(restartBytesProcessing: true, isFromSpeechProfile: true); notifyListeners(); } diff --git a/app/lib/services/device_connections.dart b/app/lib/services/device_connections.dart index 19c1d828f..aab1fef8f 100644 --- a/app/lib/services/device_connections.dart +++ b/app/lib/services/device_connections.dart @@ -45,6 +45,8 @@ abstract class DeviceConnection { DeviceConnectionState get connectionState => _connectionState; + Function(String deviceId, DeviceConnectionState state)? _connectionStateChangedCallback; + DateTime? 
get pongAt => _pongAt; late StreamSubscription _connectionStateSubscription; @@ -62,8 +64,9 @@ abstract class DeviceConnection { } // Connect + _connectionStateChangedCallback = onConnectionStateChanged; _connectionStateSubscription = bleDevice.connectionState.listen((BluetoothConnectionState state) async { - _onBleConnectionStateChanged(state, onConnectionStateChanged); + _onBleConnectionStateChanged(state); }); await FlutterBluePlus.adapterState.where((val) => val == BluetoothAdapterState.on).first; @@ -82,26 +85,26 @@ abstract class DeviceConnection { _services = await bleDevice.discoverServices(); } - void _onBleConnectionStateChanged( - BluetoothConnectionState state, Function(String deviceId, DeviceConnectionState state)? callback) async { + void _onBleConnectionStateChanged(BluetoothConnectionState state) async { if (state == BluetoothConnectionState.disconnected && _connectionState == DeviceConnectionState.connected) { _connectionState = DeviceConnectionState.disconnected; - await disconnect(callback: callback); + await disconnect(); return; } if (state == BluetoothConnectionState.connected && _connectionState == DeviceConnectionState.disconnected) { _connectionState = DeviceConnectionState.connected; - if (callback != null) { - callback(device.id, _connectionState); + if (_connectionStateChangedCallback != null) { + _connectionStateChangedCallback!(device.id, _connectionState); } } } - Future disconnect({Function(String deviceId, DeviceConnectionState state)? callback}) async { + Future disconnect() async { _connectionState = DeviceConnectionState.disconnected; - if (callback != null) { - callback(device.id, _connectionState); + if (_connectionStateChangedCallback != null) { + _connectionStateChangedCallback!(device.id, _connectionState); + _connectionStateChangedCallback = null; } await bleDevice.disconnect(); _connectionStateSubscription.cancel(); diff --git a/app/lib/services/sockets.dart b/app/lib/services/sockets.dart index 0fffecf57..cf6fa7510 100644 --- a/app/lib/services/sockets.dart +++ b/app/lib/services/sockets.dart @@ -1,3 +1,4 @@ +import 'package:flutter/material.dart'; import 'package:friend_private/backend/schema/bt_device.dart'; import 'package:friend_private/utils/pure_socket.dart'; @@ -5,7 +6,8 @@ abstract class ISocketService { void start(); void stop(); - Future memory({required BleAudioCodec codec, required int sampleRate, bool force = false}); + Future memory( + {required BleAudioCodec codec, required int sampleRate, bool force = false}); TranscripSegmentSocketService speechProfile(); } @@ -36,22 +38,31 @@ class SocketServicePool extends ISocketService { } memoryMutex = true; - if (!force && - _memory?.codec == codec && - _memory?.sampleRate == sampleRate && - _memory?.state == SocketServiceState.connected) { - return _memory; - } + debugPrint("socket memory > $codec $sampleRate $force"); - // new socket - await _memory?.stop(); + try { + if (!force && + _memory?.codec == codec && + _memory?.sampleRate == sampleRate && + _memory?.state == SocketServiceState.connected) { + return _memory; + } - _memory = MemoryTranscripSegmentSocketService.create(sampleRate, codec); - await _memory?.start(); + // new socket + await _memory?.stop(); - memoryMutex = false; + _memory = MemoryTranscripSegmentSocketService.create(sampleRate, codec); + await _memory?.start(); + if (_memory?.state != SocketServiceState.connected) { + return null; + } + + return _memory; + } finally { + memoryMutex = false; + } - return _memory; + return null; } @override diff --git 
a/app/lib/utils/pure_socket.dart b/app/lib/utils/pure_socket.dart index 70dbe3a8e..461f4c74c 100644 --- a/app/lib/utils/pure_socket.dart +++ b/app/lib/utils/pure_socket.dart @@ -53,14 +53,20 @@ class PureCore { PureCore.createInstance() { internetConnection = InternetConnection.createInstance( + /* customCheckOptions: [ InternetCheckOption( - uri: Uri.parse(Env.apiBaseUrl!), - timeout: const Duration( - seconds: 30, - )), + uri: Uri.parse(Env.apiBaseUrl!), + timeout: const Duration( + seconds: 30, + ), + responseStatusFn: (resp) { + return resp.statusCode < 500; + }, + ), ], - ); + */ + ); } } @@ -141,13 +147,14 @@ class PureSocket implements IPureSocket { @override Future disconnect() async { _status = PureSocketStatus.disconnected; + onClosed(); await _cleanUp(); } Future _cleanUp() async { _internetLostDelayTimer?.cancel(); _internetStatusListener?.cancel(); - await _channel?.sink.close(socket_channel_status.goingAway); + await _channel?.sink.close(socket_channel_status.normalClosure); } @override @@ -277,7 +284,7 @@ class TranscripSegmentSocketService implements IPureSocketListener { }) { final recordingsLanguage = SharedPreferencesUtil().recordingsLanguage; var params = - '?language=$recordingsLanguage&sample_rate=$sampleRate&codec=$codec&uid=${SharedPreferencesUtil().uid}&include_speech_profile=$includeSpeechProfile'; + '?language=$recordingsLanguage&sample_rate=$sampleRate&codec=$codec&uid=${SharedPreferencesUtil().uid}&include_speech_profile=$includeSpeechProfile&new_memory_watch=$newMemoryWatch'; String url = '${Env.apiBaseUrl!.replaceAll('https', 'wss')}listen$params'; _socket = PureSocket(url); @@ -285,16 +292,12 @@ class TranscripSegmentSocketService implements IPureSocketListener { } void subscribe(Object context, ITransctipSegmentSocketServiceListener listener) { - if (_listeners.containsKey(context.hashCode)) { - _listeners.remove(context.hashCode); - } + _listeners.remove(context.hashCode); _listeners.putIfAbsent(context.hashCode, () => listener); } void unsubscribe(Object context) { - if (_listeners.containsKey(context.hashCode)) { - _listeners.remove(context.hashCode); - } + _listeners.remove(context.hashCode); } Future start() async { @@ -334,6 +337,7 @@ class TranscripSegmentSocketService implements IPureSocketListener { @override void onMessage(event) { + debugPrint("[TranscriptSegmentService] onMessage ${event}"); if (event == 'ping') return; // Decode json @@ -351,7 +355,7 @@ class TranscripSegmentSocketService implements IPureSocketListener { // Transcript segments if (jsonEvent is List) { var segments = jsonEvent; - if (segments.isNotEmpty) { + if (segments.isEmpty) { return; } _listeners.forEach((k, v) { From a9ccb5b9942a08b80a197d4b325724c5329e03c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Thu, 19 Sep 2024 12:21:22 +0700 Subject: [PATCH 19/88] RecordingDevice instead of connected device for capture provider --- app/lib/providers/capture_provider.dart | 54 ++++++++++++------------- app/lib/providers/device_provider.dart | 2 +- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/app/lib/providers/capture_provider.dart b/app/lib/providers/capture_provider.dart index 9306e8805..29ec61883 100644 --- a/app/lib/providers/capture_provider.dart +++ b/app/lib/providers/capture_provider.dart @@ -20,7 +20,6 @@ import 'package:friend_private/backend/schema/transcript_segment.dart'; import 'package:friend_private/pages/capture/logic/openglass_mixin.dart'; import 'package:friend_private/providers/memory_provider.dart'; import 
'package:friend_private/providers/message_provider.dart'; -import 'package:friend_private/services/devices.dart'; import 'package:friend_private/services/services.dart'; import 'package:friend_private/utils/analytics/growthbook.dart'; import 'package:friend_private/utils/analytics/mixpanel.dart'; @@ -48,7 +47,7 @@ class CaptureProvider extends ChangeNotifier notifyListeners(); } - BTDeviceStruct? connectedDevice; + BTDeviceStruct? _recordingDevice; bool isGlasses = false; List segments = []; @@ -115,9 +114,9 @@ class CaptureProvider extends ChangeNotifier notifyListeners(); } - void _updateConnectedDevice(BTDeviceStruct? device) { - debugPrint('connected device changed from ${connectedDevice?.id} to ${device?.id}'); - connectedDevice = device; + void _updateRecordingDevice(BTDeviceStruct? device) { + debugPrint('connected device changed from ${_recordingDevice?.id} to ${device?.id}'); + _recordingDevice = device; notifyListeners(); } @@ -442,13 +441,13 @@ class CaptureProvider extends ChangeNotifier Future getFileFromDevice(int fileNum) async { storageUtil.fileNum = fileNum; int command = 0; - writeToStorage(connectedDevice!.id, storageUtil.fileNum, command); + writeToStorage(_recordingDevice!.id, storageUtil.fileNum, command); } Future clearFileFromDevice(int fileNum) async { storageUtil.fileNum = fileNum; int command = 1; - writeToStorage(connectedDevice!.id, storageUtil.fileNum, command); + writeToStorage(_recordingDevice!.id, storageUtil.fileNum, command); } void clearTranscripts() { @@ -471,11 +470,11 @@ class CaptureProvider extends ChangeNotifier debugPrint('resetState: restartBytesProcessing=$restartBytesProcessing'); _cleanupCurrentState(); - await startOpenGlass(); await _handleMemoryCreation(restartBytesProcessing); await _ensureSocketConnection(force: true); + await startOpenGlass(); await _initiateFriendAudioStreaming(); // TODO: Commenting this for now as DevKit 2 is not yet used in production // await initiateStorageBytesStreaming(); @@ -556,8 +555,8 @@ class CaptureProvider extends ChangeNotifier } Future _checkCodecChange() async { - if (connectedDevice != null) { - BleAudioCodec newCodec = await _getAudioCodec(connectedDevice!.id); + if (_recordingDevice != null) { + BleAudioCodec newCodec = await _getAudioCodec(_recordingDevice!.id); if (SharedPreferencesUtil().deviceCodec != newCodec) { debugPrint('Device codec changed from ${SharedPreferencesUtil().deviceCodec} to $newCodec'); SharedPreferencesUtil().deviceCodec = newCodec; @@ -577,10 +576,10 @@ class CaptureProvider extends ChangeNotifier } Future _initiateFriendAudioStreaming() async { - debugPrint('connectedDevice: $connectedDevice in initiateFriendAudioStreaming'); - if (connectedDevice == null) return; + debugPrint('_recordingDevice: $_recordingDevice in initiateFriendAudioStreaming'); + if (_recordingDevice == null) return; - BleAudioCodec codec = await _getAudioCodec(connectedDevice!.id); + BleAudioCodec codec = await _getAudioCodec(_recordingDevice!.id); if (SharedPreferencesUtil().deviceCodec != codec) { debugPrint('Device codec changed from ${SharedPreferencesUtil().deviceCodec} to $codec'); SharedPreferencesUtil().deviceCodec = codec; @@ -588,10 +587,10 @@ class CaptureProvider extends ChangeNotifier await _ensureSocketConnection(); } - // Why is the connectedDevice null at this point? + // Why is the _recordingDevice null at this point? 
if (!audioBytesConnected) { - if (connectedDevice != null) { - await streamAudioToWs(connectedDevice!.id, codec); + if (_recordingDevice != null) { + await streamAudioToWs(_recordingDevice!.id, codec); } else { // Is the app in foreground when this happens? Logger.handle(Exception('Device Not Connected'), StackTrace.current, @@ -604,18 +603,18 @@ class CaptureProvider extends ChangeNotifier Future initiateStorageBytesStreaming() async { debugPrint('initiateStorageBytesStreaming'); - if (connectedDevice == null) return; - currentStorageFiles = await _getStorageList(connectedDevice!.id); + if (_recordingDevice == null) return; + currentStorageFiles = await _getStorageList(_recordingDevice!.id); debugPrint('Storage files: $currentStorageFiles'); - await sendStorage(connectedDevice!.id); + await sendStorage(_recordingDevice!.id); notifyListeners(); } Future startOpenGlass() async { - if (connectedDevice == null) return; - isGlasses = await _hasPhotoStreamingCharacteristic(connectedDevice!.id); + if (_recordingDevice == null) return; + isGlasses = await _hasPhotoStreamingCharacteristic(_recordingDevice!.id); if (!isGlasses) return; - await openGlassProcessing(connectedDevice!, (p) {}, setHasTranscripts); + await openGlassProcessing(_recordingDevice!, (p) {}, setHasTranscripts); _socket?.stop(reason: 'reset state open glass'); notifyListeners(); } @@ -665,20 +664,21 @@ class CaptureProvider extends ChangeNotifier } Future streamDeviceRecording({ - BTDeviceStruct? btDevice, + BTDeviceStruct? device, bool restartBytesProcessing = true, }) async { - debugPrint("streamDeviceRecording ${btDevice} ${restartBytesProcessing}"); - if (btDevice != null) { - _updateConnectedDevice(btDevice); + debugPrint("streamDeviceRecording ${device} ${restartBytesProcessing}"); + if (device != null) { + _updateRecordingDevice(device); } + await _resetState( restartBytesProcessing: restartBytesProcessing, ); } Future stopStreamDeviceRecording() async { - _updateConnectedDevice(null); + _updateRecordingDevice(null); await _resetState(); } diff --git a/app/lib/providers/device_provider.dart b/app/lib/providers/device_provider.dart index 634f13e32..3baa933e2 100644 --- a/app/lib/providers/device_provider.dart +++ b/app/lib/providers/device_provider.dart @@ -205,7 +205,7 @@ class DeviceProvider extends ChangeNotifier implements IDeviceServiceSubsciption setConnectedDevice(device); setIsConnected(true); updateConnectingStatus(false); - await captureProvider?.streamDeviceRecording(restartBytesProcessing: true, btDevice: connectedDevice!); + await captureProvider?.streamDeviceRecording(restartBytesProcessing: true, device: connectedDevice); // initiateBleBatteryListener(); // The device is still disconnected for some reason if (connectedDevice != null) { From 76d0eb793002d9b35bb6586b36461d611fafbacf Mon Sep 17 00:00:00 2001 From: Joan Cabezas Date: Thu, 19 Sep 2024 00:30:18 -0700 Subject: [PATCH 20/88] speechmatics initial setup as model selection --- app/lib/backend/http/api/memories.dart | 4 + .../backend/schema/transcript_segment.dart | 20 +++ .../memory_detail/compare_transcripts.dart | 21 +++- app/lib/pages/settings/developer.dart | 2 +- backend/utils/stt/streaming.py | 119 +++++++++++++++++- 5 files changed, 162 insertions(+), 4 deletions(-) diff --git a/app/lib/backend/http/api/memories.dart b/app/lib/backend/http/api/memories.dart index d70fcc05b..3ca3eb8b9 100644 --- a/app/lib/backend/http/api/memories.dart +++ b/app/lib/backend/http/api/memories.dart @@ -188,11 +188,13 @@ class TranscriptsResponse { List deepgram; 
List soniox; List whisperx; + List speechmatics; TranscriptsResponse({ this.deepgram = const [], this.soniox = const [], this.whisperx = const [], + this.speechmatics = const [], }); factory TranscriptsResponse.fromJson(Map json) { @@ -200,6 +202,8 @@ class TranscriptsResponse { deepgram: (json['deepgram'] as List).map((segment) => TranscriptSegment.fromJson(segment)).toList(), soniox: (json['soniox'] as List).map((segment) => TranscriptSegment.fromJson(segment)).toList(), whisperx: (json['whisperx'] as List).map((segment) => TranscriptSegment.fromJson(segment)).toList(), + speechmatics: + (json['speechmatics'] as List).map((segment) => TranscriptSegment.fromJson(segment)).toList(), ); } } diff --git a/app/lib/backend/schema/transcript_segment.dart b/app/lib/backend/schema/transcript_segment.dart index 47f4b1398..3c4eff301 100644 --- a/app/lib/backend/schema/transcript_segment.dart +++ b/app/lib/backend/schema/transcript_segment.dart @@ -129,6 +129,26 @@ class TranscriptSegment { cleanSegments(joinedSimilarSegments); segments.addAll(joinedSimilarSegments); + + // for i, segment in enumerate(segments): + // segments[i].text = ( + // segments[i].text.strip() + // .replace(' ', '') + // .replace(' ,', ',') + // .replace(' .', '.') + // .replace(' ?', '?') + // ) + + // Speechmatics specific issue with punctuation + for (var i = 0; i < segments.length; i++) { + segments[i].text = segments[i] + .text + .replaceAll(' ', '') + .replaceAll(' ,', ',') + .replaceAll(' .', '.') + .replaceAll(' ?', '?') + .trim(); + } } static String segmentsAsString( diff --git a/app/lib/pages/memory_detail/compare_transcripts.dart b/app/lib/pages/memory_detail/compare_transcripts.dart index 913f8a677..86a9e0fb5 100644 --- a/app/lib/pages/memory_detail/compare_transcripts.dart +++ b/app/lib/pages/memory_detail/compare_transcripts.dart @@ -35,7 +35,7 @@ class _CompareTranscriptsPageState extends State { backgroundColor: Theme.of(context).colorScheme.primary, ), body: DefaultTabController( - length: 3, + length: 4, initialIndex: 0, child: Column( children: [ @@ -50,7 +50,12 @@ class _CompareTranscriptsPageState extends State { padding: EdgeInsets.zero, indicatorPadding: EdgeInsets.zero, labelStyle: Theme.of(context).textTheme.titleLarge!.copyWith(fontSize: 18), - tabs: const [Tab(text: 'Deepgram'), Tab(text: 'Soniox'), Tab(text: 'Whisper-x')], + tabs: const [ + Tab(text: 'Deepgram'), + Tab(text: 'Soniox'), + Tab(text: 'SpeechMatics'), + Tab(text: 'Whisper-x'), + ], indicator: BoxDecoration(color: Colors.transparent, borderRadius: BorderRadius.circular(16)), ), Expanded( @@ -84,6 +89,18 @@ class _CompareTranscriptsPageState extends State { ) ], ), + ListView( + shrinkWrap: true, + children: [ + TranscriptWidget( + segments: transcripts?.speechmatics ?? 
[], + horizontalMargin: false, + topMargin: false, + canDisplaySeconds: true, + isMemoryDetail: true, + ) + ], + ), ListView( shrinkWrap: true, children: [ diff --git a/app/lib/pages/settings/developer.dart b/app/lib/pages/settings/developer.dart index 7fc52e8bb..3f8bafa05 100644 --- a/app/lib/pages/settings/developer.dart +++ b/app/lib/pages/settings/developer.dart @@ -127,7 +127,7 @@ class __DeveloperSettingsPageState extends State<_DeveloperSettingsPage> { underline: Container(height: 0, color: Colors.white), isExpanded: true, itemHeight: 48, - items: ['deepgram', 'soniox'].map>((String value) { + items: ['deepgram', 'soniox', 'speechmatics'].map>((String value) { return DropdownMenuItem( value: value, child: Text( diff --git a/backend/utils/stt/streaming.py b/backend/utils/stt/streaming.py index 4593a1554..303ff771e 100644 --- a/backend/utils/stt/streaming.py +++ b/backend/utils/stt/streaming.py @@ -257,7 +257,8 @@ async def on_message(): segments[i]['text'] = segments[i]['text'].strip().replace(' ', '') # print('Soniox:', transcript.replace('', '')) - stream_transcript(segments, stream_id) + if segments: + stream_transcript(segments, stream_id) except websockets.exceptions.ConnectionClosedOK: print("Soniox connection closed normally.") except Exception as e: @@ -276,3 +277,119 @@ async def on_message(): except Exception as e: print(f"Exception in process_audio_soniox: {e}") raise # Re-raise the exception to be handled by the caller + + +LANGUAGE = "en" +CONNECTION_URL = f"wss://eu2.rt.speechmatics.com/v2" + + +async def process_audio_speechmatics(stream_transcript, stream_id: int, language: str, uid: str): + # Create a transcription client + api_key = os.getenv('SPEECHMATICS_API_KEY') + uri = 'wss://eu2.rt.speechmatics.com/v2' + # Validate the language and construct the model name + # has_speech_profile = create_user_speech_profile(uid) # only english too + + request = { + "message": "StartRecognition", + "transcription_config": { + "language": language, + "diarization": "speaker", + "operating_point": "enhanced", + "max_delay_mode": "flexible", + "max_delay": 3, + "enable_partials": False, + "enable_entities": True, + "speaker_diarization_config": {"max_speakers": 4} + }, + "audio_format": {"type": "raw", "encoding": "pcm_s16le", "sample_rate": 16000}, + # "audio_events_config": { + # "types": [ + # "laughter", + # "music", + # "applause" + # ] + # } + } + try: + # Connect to Soniox WebSocket + print("Connecting to Speechmatics WebSocket...") + socket = await websockets.connect(uri, extra_headers={"Authorization": f"Bearer {api_key}"}) + print("Connected to Speechmatics WebSocket.") + + # Send the initial request + await socket.send(json.dumps(request)) + print(f"Sent initial request: {request}") + + # Start listening for messages from Soniox + async def on_message(): + try: + async for message in socket: + response = json.loads(message) + if response['message'] == 'AudioAdded': + continue + if response['message'] == 'AddTranscript': + results = response['results'] + if not results: + continue + segments = [] + for r in results: + # print(r) + if not r['alternatives']: + continue + r_data = r['alternatives'][0] + r_type = r['type'] # word | punctuation + r_start = r['start_time'] + r_end = r['end_time'] + + r_content = r_data['content'] + r_confidence = r_data['confidence'] + if r_confidence < 0.4: + print('Low confidence:', r) + continue + r_speaker = r_data['speaker'][1:] if r_data['speaker'] != 'UU' else '1' + speaker = f"SPEAKER_0{r_speaker}" + # print(r_content, r_speaker, 
[r_start, r_end]) + if not segments: + segments.append({ + 'speaker': speaker, + 'start': r_start, + 'end': r_end, + 'text': r_content, + 'is_user': False, + 'person_id': None, + }) + else: + last_segment = segments[-1] + if last_segment['speaker'] == speaker: + last_segment['text'] += f' {r_content}' + last_segment['end'] += r_end + else: + segments.append({ + 'speaker': speaker, + 'start': r_start, + 'end': r_end, + 'text': r_content, + 'is_user': False, + 'person_id': None, + }) + + if segments: + stream_transcript(segments, stream_id) + # print('---') + else: + print(response) + except websockets.exceptions.ConnectionClosedOK: + print("Speechmatics connection closed normally.") + except Exception as e: + print(f"Error receiving from Speechmatics: {e}") + finally: + if not socket.closed: + await socket.close() + print("Speechmatics WebSocket closed in on_message.") + + asyncio.create_task(on_message()) + return socket + except Exception as e: + print(f"Exception in process_audio_speechmatics: {e}") + raise From 0ecf5c6ed6869a1cefd5059e8ec2e8cfa1ca8976 Mon Sep 17 00:00:00 2001 From: Joan Cabezas Date: Thu, 19 Sep 2024 11:02:52 -0700 Subject: [PATCH 21/88] speechmatics backend setup --- backend/database/memories.py | 2 ++ backend/requirements.txt | 20 +++++++++++++--- backend/routers/transcribe.py | 43 +++++++++++++++++++++++++++++++---- 3 files changed, 57 insertions(+), 8 deletions(-) diff --git a/backend/database/memories.py b/backend/database/memories.py index d5cf60e25..299cb6eb2 100644 --- a/backend/database/memories.py +++ b/backend/database/memories.py @@ -122,11 +122,13 @@ def get_memory_transcripts_by_model(uid: str, memory_id: str): memory_ref = user_ref.collection('memories').document(memory_id) deepgram_ref = memory_ref.collection('deepgram_streaming') soniox_ref = memory_ref.collection('soniox_streaming') + speechmatics_ref = memory_ref.collection('speechmatics_streaming') whisperx_ref = memory_ref.collection('fal_whisperx') return { 'deepgram': list(sorted([doc.to_dict() for doc in deepgram_ref.stream()], key=lambda x: x['start'])), 'soniox': list(sorted([doc.to_dict() for doc in soniox_ref.stream()], key=lambda x: x['start'])), + 'speechmatics': list(sorted([doc.to_dict() for doc in speechmatics_ref.stream()], key=lambda x: x['start'])), 'whisperx': list(sorted([doc.to_dict() for doc in whisperx_ref.stream()], key=lambda x: x['start'])), } diff --git a/backend/requirements.txt b/backend/requirements.txt index b985bcdbe..8d86dd6e0 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -9,6 +9,7 @@ altair==5.4.0 annotated-types==0.7.0 antlr4-python3-runtime==4.9.3 anyio==4.4.0 +assemblyai==0.33.0 asteroid-filterbanks==0.4.0 asttokens==2.4.1 attrs==24.1.0 @@ -33,11 +34,14 @@ decorator==5.1.1 deepgram-sdk==3.4.0 deprecation==2.1.0 distro==1.9.0 +dnspython==2.6.1 docopt==0.6.2 einops==0.8.0 +email_validator==2.2.0 executing==2.0.1 fal_client==0.4.1 fastapi==0.112.0 +fastapi-cli==0.0.5 fastapi-utilities==0.2.0 filelock==3.15.4 firebase-admin==6.5.0 @@ -60,12 +64,14 @@ googleapis-common-protos==1.63.2 groq==0.9.0 grpcio==1.65.4 grpcio-status==1.62.3 +grpcio-tools==1.62.3 grpclib==0.4.7 h11==0.14.0 h2==4.1.0 hpack==4.0.0 httpcore==1.0.5 httplib2==0.22.0 +httptools==0.6.1 httpx==0.27.0 httpx-sse==0.4.0 huggingface-hub==0.24.5 @@ -76,6 +82,7 @@ idna==3.7 ipython==8.26.0 jedi==0.19.1 Jinja2==3.1.4 +jiwer==3.0.4 joblib==1.4.2 jsonpatch==1.33 jsonpointer==3.0.0 @@ -105,6 +112,7 @@ matplotlib-inline==0.1.7 mdurl==0.1.2 modal==0.64.7 monotonic==1.6 
+more-itertools==10.5.0 mplcursors==0.5.3 mpld3==0.5.10 mpmath==1.3.0 @@ -121,6 +129,7 @@ omegaconf==2.3.0 onnxruntime==1.19.0 openai==1.39.0 optuna==3.6.1 +opuslib==3.0.1 orjson==3.10.6 packaging==24.1 pandas==2.2.2 @@ -132,6 +141,7 @@ pinecone-plugin-inference==1.0.3 pinecone-plugin-interface==0.0.7 platformdirs==4.2.2 plotly==5.23.0 +polling2==0.5.0 pooch==1.8.2 portalocker==2.10.1 posthog==3.5.2 @@ -157,6 +167,7 @@ pydub==0.25.1 Pygments==2.18.0 PyJWT==2.9.0 pynndescent==0.5.13 +PyOgg @ git+https://github.com/TeamPyOgg/PyOgg@6871a4f234e8a3a346c4874a12509bfa02c4c63a pyparsing==3.1.2 python-dateutil==2.9.0.post0 python-dotenv==1.0.1 @@ -166,6 +177,7 @@ pytorch-metric-learning==2.6.0 pytz==2024.1 PyYAML==6.0.1 qdrant-client==1.11.0 +rapidfuzz==3.9.7 redis==5.0.8 referencing==0.35.1 regex==2024.7.24 @@ -184,10 +196,13 @@ sigtools==4.0.1 six==1.16.0 smmap==5.0.1 sniffio==1.3.1 +soniox==1.10.1 sortedcontainers==2.4.0 +SoundCard==0.4.3 soundfile==0.12.1 soxr==0.4.0 speechbrain==1.0.0 +speechmatics-python==2.0.1 SQLAlchemy==2.0.32 stack-data==0.6.3 starlette==0.37.2 @@ -195,7 +210,7 @@ streamlit==1.37.1 sympy==1.13.1 synchronicity==0.6.7 tabulate==0.9.0 -tenacity==8.5.0 +tenacity==8.2.3 tensorboardX==2.6.2.2 threadpoolctl==3.5.0 tiktoken==0.7.0 @@ -220,10 +235,9 @@ umap-learn==0.5.6 uritemplate==4.1.1 urllib3==2.2.2 uvicorn==0.30.5 +uvloop==0.20.0 watchdog==4.0.2 watchfiles==0.22.0 wcwidth==0.2.13 websockets==12.0 yarl==1.9.4 -pyogg @ git+https://github.com/TeamPyOgg/PyOgg@6871a4f -opuslib==3.0.1 diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 43fa826e1..a188f1aca 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -15,8 +15,8 @@ from models.memory import Memory, TranscriptSegment from models.message_event import NewMemoryCreated, MessageEvent, NewProcessingMemoryCreated from models.processing_memory import ProcessingMemory -from utils.memories.postprocess_memory import postprocess_memory as postprocess_memory_util from utils.audio import create_wav_from_bytes, merge_wav_files +from utils.memories.postprocess_memory import postprocess_memory as postprocess_memory_util from utils.memories.process_memory import process_memory from utils.other.storage import upload_postprocessing_audio from utils.processing_memories import create_memory_by_processing_memory @@ -95,12 +95,31 @@ def _combine_segments(segments: [], new_segments: [], delta_seconds: int = 0): segments.extend(joined_similar_segments) + # Speechmatics specific issue with punctuation + for i, segment in enumerate(segments): + segments[i].text = ( + segments[i].text.strip() + .replace(' ', '') + .replace(' ,', ',') + .replace(' .', '.') + .replace(' ?', '?') + ) return segments class STTService(str, Enum): deepgram = "deepgram" soniox = "soniox" + speechmatics = "speechmatics" + + @staticmethod + def get_model_name(value): + if value == STTService.deepgram: + return 'deepgram_streaming' + elif value == STTService.soniox: + return 'soniox_streaming' + elif value == STTService.speechmatics: + return 'speechmatics_streaming' async def _websocket_util( @@ -114,6 +133,9 @@ async def _websocket_util( sample_rate != 16000 or codec != 'opus' or language not in soniox_valid_languages): stt_service = STTService.deepgram + if stt_service == STTService.speechmatics and (sample_rate != 16000 or codec != 'opus'): + stt_service = STTService.deepgram + # Check: Why do we need try-catch around websocket.accept? 
try: await websocket.accept() @@ -182,6 +204,7 @@ def stream_audio(audio_buffer): processing_audio_frames.append(audio_buffer) soniox_socket = None + speechmatics_socket = None deepgram_socket = None deepgram_socket2 = None @@ -216,6 +239,10 @@ def stream_audio(audio_buffer): soniox_socket = await process_audio_soniox( stream_transcript, speech_profile_stream_id, language, uid if include_speech_profile else None ) + elif stt_service == STTService.speechmatics: + speechmatics_socket = await process_audio_speechmatics( + stream_transcript, speech_profile_stream_id, language, uid if include_speech_profile else None + ) except Exception as e: print(f"Initial processing error: {e}") @@ -228,7 +255,7 @@ def stream_audio(audio_buffer): decoder = opuslib.Decoder(sample_rate, channels) - async def receive_audio(dg_socket1, dg_socket2, soniox_socket): + async def receive_audio(dg_socket1, dg_socket2, soniox_socket, speechmatics_socket): nonlocal websocket_active nonlocal websocket_close_code nonlocal timer_start @@ -252,6 +279,10 @@ async def receive_audio(dg_socket1, dg_socket2, soniox_socket): decoded_opus = decoder.decode(bytes(data), frame_size=160) await soniox_socket.send(decoded_opus) + if speechmatics_socket is not None: + decoded_opus = decoder.decode(bytes(data), frame_size=160) + await speechmatics_socket.send(decoded_opus) + if deepgram_socket is not None: elapsed_seconds = time.time() - timer_start if elapsed_seconds > duration or not dg_socket2: @@ -281,6 +312,8 @@ async def receive_audio(dg_socket1, dg_socket2, soniox_socket): dg_socket2.finish() if soniox_socket: await soniox_socket.close() + if speechmatics_socket: + await speechmatics_socket.close() # heart beat async def send_heartbeat(): @@ -410,8 +443,7 @@ async def _post_process_memory(memory: Memory): # Process emotional_feedback = processing_memory.emotional_feedback status, new_memory = postprocess_memory_util( - memory.id, file_path, uid, emotional_feedback, - 'deepgram_streaming' if stt_service == STTService.deepgram else 'soniox_streaming' + memory.id, file_path, uid, emotional_feedback, STTService.get_model_name(stt_service) ) if status == 200: memory = new_memory @@ -587,7 +619,8 @@ async def _try_flush_new_memory(time_validate: bool = True): processing_memory = None try: - receive_task = asyncio.create_task(receive_audio(deepgram_socket, deepgram_socket2, soniox_socket)) + receive_task = asyncio.create_task( + receive_audio(deepgram_socket, deepgram_socket2, soniox_socket, speechmatics_socket)) heartbeat_task = asyncio.create_task(send_heartbeat()) # Run task From 22fbf89dc7645224a0b30dad01d74ab06e32dca1 Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Thu, 19 Sep 2024 23:35:02 +0530 Subject: [PATCH 22/88] clear chat func in backend --- backend/database/chat.py | 17 +++++++++++++++++ backend/routers/chat.py | 7 +++++++ 2 files changed, 24 insertions(+) diff --git a/backend/database/chat.py b/backend/database/chat.py index 3c4c13329..413604e71 100644 --- a/backend/database/chat.py +++ b/backend/database/chat.py @@ -82,3 +82,20 @@ def get_messages(uid: str, limit: int = 20, offset: int = 0, include_memories: b ] return messages + + +def clear_chat(uid,batch_size): + user_ref = db.collection('users').document(uid) + messages_ref = user_ref.collection('messages') + if batch_size == 0: + return + docs = messages_ref.list_documents(page_size=batch_size) + deleted = 0 + + for doc in docs: + print(f"Deleting doc {doc.id} => {doc.get().to_dict()}") + doc.delete() + deleted = 
deleted + 1 + + if deleted >= batch_size: + return clear_chat(uid,batch_size) \ No newline at end of file diff --git a/backend/routers/chat.py b/backend/routers/chat.py index 7b427fb36..3d9b6efde 100644 --- a/backend/routers/chat.py +++ b/backend/routers/chat.py @@ -55,6 +55,13 @@ def send_message( ai_message.memories = memories if len(memories) < 5 else memories[:5] return ai_message +@router.delete('/v1/clear-chat', tags=['chat'], response_model=Message) +def clear_chat(uid: str = Depends(auth.get_current_user_uid)): + + chat_db.clear_chat(uid, 400) + return initial_message_util(uid) + + def initial_message_util(uid: str, plugin_id: Optional[str] = None): plugin = get_plugin_by_id(plugin_id) From a280eada630cd9f54324d3ef20103421e8be5289 Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Thu, 19 Sep 2024 23:36:00 +0530 Subject: [PATCH 23/88] clear chat func on frontend --- app/lib/backend/http/api/messages.dart | 10 ++++++++++ app/lib/providers/message_provider.dart | 6 ++++++ 2 files changed, 16 insertions(+) diff --git a/app/lib/backend/http/api/messages.dart b/app/lib/backend/http/api/messages.dart index ea33c8362..5eb6f2747 100644 --- a/app/lib/backend/http/api/messages.dart +++ b/app/lib/backend/http/api/messages.dart @@ -23,6 +23,16 @@ Future> getMessagesServer() async { return []; } +Future> clearChatServer() async { + var response = await makeApiCall(url: '${Env.apiBaseUrl}v1/clear-chat', headers: {}, method: 'DELETE', body: ''); + if (response == null) throw Exception('Failed to delete chat'); + if (response.statusCode == 200) { + return [ServerMessage.fromJson(jsonDecode(response.body))]; + } else { + throw Exception('Failed to delete chat'); + } +} + Future sendMessageServer(String text, {String? pluginId}) { return makeApiCall( url: '${Env.apiBaseUrl}v1/messages?plugin_id=$pluginId', diff --git a/app/lib/providers/message_provider.dart b/app/lib/providers/message_provider.dart index 7e747cafc..e12952ab4 100644 --- a/app/lib/providers/message_provider.dart +++ b/app/lib/providers/message_provider.dart @@ -41,6 +41,12 @@ class MessageProvider extends ChangeNotifier { return messages; } + Future clearChat() async { + var mes = await clearChatServer(); + messages = mes; + notifyListeners(); + } + void addMessage(ServerMessage message) { messages.insert(0, message); notifyListeners(); From 13b929fdefd184edb3e20969ad4e78acd1bdd68e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Fri, 20 Sep 2024 06:17:28 +0700 Subject: [PATCH 24/88] Integrate speech profile provider with pure socket --- app/lib/main.dart | 15 +-- app/lib/pages/capture/widgets/widgets.dart | 2 +- app/lib/pages/home/page.dart | 4 +- app/lib/pages/memories/widgets/capture.dart | 2 - .../memories/widgets/processing_capture.dart | 2 - .../onboarding/speech_profile_widget.dart | 2 +- app/lib/pages/speech_profile/page.dart | 2 +- .../providers/speech_profile_provider.dart | 126 ++++++++++-------- app/lib/providers/websocket_provider.dart | 1 + app/lib/services/sockets.dart | 54 ++++---- 10 files changed, 115 insertions(+), 95 deletions(-) diff --git a/app/lib/main.dart b/app/lib/main.dart index 4e0069187..c8b6beb51 100644 --- a/app/lib/main.dart +++ b/app/lib/main.dart @@ -164,15 +164,14 @@ class _MyAppState extends State with WidgetsBindingObserver { update: (BuildContext context, value, MessageProvider? previous) => (previous?..updatePluginProvider(value)) ?? 
MessageProvider(), ), - ChangeNotifierProvider(create: (context) => WebSocketProvider()), - ChangeNotifierProxyProvider3( + ChangeNotifierProxyProvider2( create: (context) => CaptureProvider(), - update: (BuildContext context, memory, message, wsProvider, CaptureProvider? previous) => + update: (BuildContext context, memory, message, CaptureProvider? previous) => (previous?..updateProviderInstances(memory, message)) ?? CaptureProvider(), ), - ChangeNotifierProxyProvider2( + ChangeNotifierProxyProvider( create: (context) => DeviceProvider(), - update: (BuildContext context, captureProvider, wsProvider, DeviceProvider? previous) => + update: (BuildContext context, captureProvider, DeviceProvider? previous) => (previous?..setProviders(captureProvider)) ?? DeviceProvider(), ), ChangeNotifierProxyProvider( @@ -181,10 +180,10 @@ class _MyAppState extends State with WidgetsBindingObserver { (previous?..setDeviceProvider(value)) ?? OnboardingProvider(), ), ListenableProvider(create: (context) => HomeProvider()), - ChangeNotifierProxyProvider3( + ChangeNotifierProxyProvider2( create: (context) => SpeechProfileProvider(), - update: (BuildContext context, device, capture, wsProvider, SpeechProfileProvider? previous) => - (previous?..setProviders(device, capture, wsProvider)) ?? SpeechProfileProvider(), + update: (BuildContext context, device, capture, SpeechProfileProvider? previous) => + (previous?..setProviders(device, capture)) ?? SpeechProfileProvider(), ), ChangeNotifierProxyProvider2( create: (context) => MemoryDetailProvider(), diff --git a/app/lib/pages/capture/widgets/widgets.dart b/app/lib/pages/capture/widgets/widgets.dart index 55f0f162c..cd50d96fe 100644 --- a/app/lib/pages/capture/widgets/widgets.dart +++ b/app/lib/pages/capture/widgets/widgets.dart @@ -228,7 +228,7 @@ class SpeechProfileCardWidget extends StatelessWidget { if (hasSpeakerProfile != SharedPreferencesUtil().hasSpeakerProfile) { if (context.mounted) { // TODO: is the websocket restarting once the user comes back? 
- // TODO: thinh, socket speech profile + // TODO: thinh, socket change settings // context.read().restartWebSocket(); } } diff --git a/app/lib/pages/home/page.dart b/app/lib/pages/home/page.dart index f92ff0f28..361d6ec94 100644 --- a/app/lib/pages/home/page.dart +++ b/app/lib/pages/home/page.dart @@ -13,6 +13,7 @@ import 'package:friend_private/pages/home/device.dart'; import 'package:friend_private/pages/memories/page.dart'; import 'package:friend_private/pages/plugins/page.dart'; import 'package:friend_private/pages/settings/page.dart'; +import 'package:friend_private/providers/capture_provider.dart'; import 'package:friend_private/providers/connectivity_provider.dart'; import 'package:friend_private/providers/device_provider.dart'; import 'package:friend_private/providers/home_provider.dart'; @@ -21,6 +22,7 @@ import 'package:friend_private/providers/memory_provider.dart'; import 'package:friend_private/providers/message_provider.dart'; import 'package:friend_private/providers/plugin_provider.dart'; import 'package:friend_private/services/notification_service.dart'; +import 'package:friend_private/services/services.dart'; import 'package:friend_private/utils/analytics/mixpanel.dart'; import 'package:friend_private/utils/audio/foreground.dart'; import 'package:friend_private/utils/other/temp.dart'; @@ -544,7 +546,7 @@ class _HomePageState extends State with WidgetsBindingObserver, Ticker if (language != SharedPreferencesUtil().recordingsLanguage || hasSpeech != SharedPreferencesUtil().hasSpeakerProfile || transcriptModel != SharedPreferencesUtil().transcriptionModel) { - // TODO: thinh, socket speech profile + // TODO: thinh, socket change settings // context.read().restartWebSocket(); } }, diff --git a/app/lib/pages/memories/widgets/capture.dart b/app/lib/pages/memories/widgets/capture.dart index f81497050..76196a8ef 100644 --- a/app/lib/pages/memories/widgets/capture.dart +++ b/app/lib/pages/memories/widgets/capture.dart @@ -75,7 +75,6 @@ class LiteCaptureWidgetState extends State @override void dispose() { WidgetsBinding.instance.removeObserver(this); - // context.read().closeWebSocket(); super.dispose(); } @@ -138,7 +137,6 @@ class LiteCaptureWidgetState extends State builder: (c) => getDialog( context, () async { - context.read().closeWebSocketWithoutReconnect('Firmware change detected'); var connectedDevice = deviceProvider.connectedDevice; var codec = await _getAudioCodec(connectedDevice!.id); await context.read().changeAudioRecordProfile(codec); diff --git a/app/lib/pages/memories/widgets/processing_capture.dart b/app/lib/pages/memories/widgets/processing_capture.dart index 3142af71f..68a9a9a83 100644 --- a/app/lib/pages/memories/widgets/processing_capture.dart +++ b/app/lib/pages/memories/widgets/processing_capture.dart @@ -94,8 +94,6 @@ class _MemoryCaptureWidgetState extends State { () async { Navigator.pop(context); provider.updateRecordingState(RecordingState.initialising); - // TODO: thinh, socket check why we need to close socket provider here, disable temporary - //context.read().closeWebSocketWithoutReconnect('Recording with phone mic'); await provider.changeAudioRecordProfile(BleAudioCodec.pcm16, 16000); await provider.streamRecording(); MixpanelManager().phoneMicRecordingStarted(); diff --git a/app/lib/pages/onboarding/speech_profile_widget.dart b/app/lib/pages/onboarding/speech_profile_widget.dart index acbb51f98..01712aeac 100644 --- a/app/lib/pages/onboarding/speech_profile_widget.dart +++ b/app/lib/pages/onboarding/speech_profile_widget.dart @@ -207,7 +207,7 @@ 
class _SpeechProfileWidgetState extends State with TickerPr await provider.initialise(true); provider.forceCompletionTimer = Timer(Duration(seconds: provider.maxDuration), () async { - provider.finalize(true); + provider.finalize(); }); provider.updateStartedRecording(true); }, diff --git a/app/lib/pages/speech_profile/page.dart b/app/lib/pages/speech_profile/page.dart index 99a66ef9a..bf46896d1 100644 --- a/app/lib/pages/speech_profile/page.dart +++ b/app/lib/pages/speech_profile/page.dart @@ -325,7 +325,7 @@ class _SpeechProfilePageState extends State with TickerProvid // 1.5 minutes seems reasonable provider.forceCompletionTimer = Timer(Duration(seconds: provider.maxDuration), () { - provider.finalize(false); + provider.finalize(); }); provider.updateStartedRecording(true); }, diff --git a/app/lib/providers/speech_profile_provider.dart b/app/lib/providers/speech_profile_provider.dart index ecacc5946..a77191bd8 100644 --- a/app/lib/providers/speech_profile_provider.dart +++ b/app/lib/providers/speech_profile_provider.dart @@ -10,22 +10,23 @@ import 'package:friend_private/backend/http/cloud_storage.dart'; import 'package:friend_private/backend/preferences.dart'; import 'package:friend_private/backend/schema/bt_device.dart'; import 'package:friend_private/backend/schema/memory.dart'; +import 'package:friend_private/backend/schema/message_event.dart'; import 'package:friend_private/backend/schema/structured.dart'; import 'package:friend_private/backend/schema/transcript_segment.dart'; import 'package:friend_private/providers/capture_provider.dart'; import 'package:friend_private/providers/device_provider.dart'; -import 'package:friend_private/providers/websocket_provider.dart'; import 'package:friend_private/services/devices.dart'; import 'package:friend_private/services/services.dart'; import 'package:friend_private/utils/audio/wav_bytes.dart'; import 'package:friend_private/utils/memories/process.dart'; -import 'package:friend_private/utils/websockets.dart'; +import 'package:friend_private/utils/pure_socket.dart'; import 'package:uuid/uuid.dart'; -class SpeechProfileProvider extends ChangeNotifier with MessageNotifierMixin implements IDeviceServiceSubsciption { +class SpeechProfileProvider extends ChangeNotifier + with MessageNotifierMixin + implements IDeviceServiceSubsciption, ITransctipSegmentSocketServiceListener { DeviceProvider? deviceProvider; CaptureProvider? captureProvider; - WebSocketProvider? webSocketProvider; bool? permissionEnabled; bool loading = false; BTDeviceStruct? device; @@ -38,6 +39,8 @@ class SpeechProfileProvider extends ChangeNotifier with MessageNotifierMixin imp WavBytesUtil audioStorage = WavBytesUtil(codec: BleAudioCodec.opus); StreamSubscription? _bleBytesStream; + TranscripSegmentSocketService? _socket; + bool startedRecording = false; double percentageCompleted = 0; bool uploadingProfile = false; @@ -50,6 +53,8 @@ class SpeechProfileProvider extends ChangeNotifier with MessageNotifierMixin imp String text = ''; String message = ''; + late bool _isFromOnboarding; + /// only used during onboarding ///// String loadingText = 'Uploading your voice profile....'; ServerMemory? 
memory; @@ -71,10 +76,9 @@ class SpeechProfileProvider extends ChangeNotifier with MessageNotifierMixin imp notifyListeners(); } - void setProviders(DeviceProvider provider, CaptureProvider captureProvider, WebSocketProvider wsProvider) { + void setProviders(DeviceProvider provider, CaptureProvider captureProvider) { deviceProvider = provider; this.captureProvider = captureProvider; - webSocketProvider = wsProvider; notifyListeners(); } @@ -87,14 +91,15 @@ class SpeechProfileProvider extends ChangeNotifier with MessageNotifierMixin imp } Future initialise(bool isFromOnboarding) async { + _isFromOnboarding = isFromOnboarding; setInitialising(true); device = deviceProvider?.connectedDevice; await captureProvider!.resetForSpeechProfile(); - await initiateWebsocket(isFromOnboarding); + await _initiateWebsocket(force: true); // _bleBytesStream = captureProvider?.bleBytesStream; if (device != null) await initiateFriendAudioStreaming(); - if (webSocketProvider?.wsConnectionState != WebsocketConnectionStatus.connected) { + if (_socket?.state != SocketServiceState.connected) { // wait for websocket to connect await Future.delayed(Duration(seconds: 2)); } @@ -120,57 +125,34 @@ class SpeechProfileProvider extends ChangeNotifier with MessageNotifierMixin imp ServiceManager.instance().device.subscribe(this, this); } - Future initiateWebsocket(bool isFromOnboarding) async { - await webSocketProvider?.initWebSocket( - codec: BleAudioCodec.opus, - sampleRate: 16000, - includeSpeechProfile: false, - newMemoryWatch: false, - onConnectionSuccess: () { - print('Websocket connected in speech profile'); - notifyListeners(); - }, - onConnectionFailed: (err) { - notifyError('WS_ERR'); - }, - onConnectionClosed: (int? closeCode, String? closeReason) {}, - onConnectionError: (err) { - notifyError('WS_ERR'); - }, - onMessageReceived: (List newSegments) { - if (newSegments.isEmpty) return; - if (segments.isEmpty) { - audioStorage.removeFramesRange(fromSecond: 0, toSecond: newSegments[0].start.toInt()); - } - streamStartedAtSecond ??= newSegments[0].start; - - TranscriptSegment.combineSegments( - segments, - newSegments, - toRemoveSeconds: streamStartedAtSecond ?? 0, - ); - updateProgressMessage(); - _validateSingleSpeaker(); - _handleCompletion(isFromOnboarding); - notifyInfo('SCROLL_DOWN'); - debugPrint('Memory creation timer restarted'); - }, - ); + Future _initiateWebsocket({bool force = false}) async { + _socket = await ServiceManager.instance() + .socket + .speechProfile(codec: BleAudioCodec.opus, sampleRate: 16000, force: force); + if (_socket == null) { + throw Exception("Can not create new speech profile socket"); + } + + // Ok + _socket?.subscribe(this, this); + print('Websocket connected in speech profile'); + // TODO: thinh, socket ? + //notifyListeners(); } - _handleCompletion(bool isFromOnboarding) async { + _handleCompletion() async { if (uploadingProfile || profileCompleted) return; String text = segments.map((e) => e.text).join(' ').trim(); int wordsCount = text.split(' ').length; percentageCompleted = (wordsCount / targetWordsCount).clamp(0, 1); notifyListeners(); if (percentageCompleted == 1) { - await finalize(isFromOnboarding); + await finalize(); } notifyListeners(); } - Future finalize(bool isFromOnboarding) async { + Future finalize() async { if (uploadingProfile || profileCompleted) return; int duration = segments.isEmpty ? 
0 : segments.last.end.toInt(); @@ -185,7 +167,7 @@ class SpeechProfileProvider extends ChangeNotifier with MessageNotifierMixin imp } uploadingProfile = true; notifyListeners(); - await webSocketProvider?.closeWebSocketWithoutReconnect('finalizing'); + await _socket?.stop(reason: 'finalizing'); forceCompletionTimer?.cancel(); connectionStateListener?.cancel(); _bleBytesStream?.cancel(); @@ -200,7 +182,7 @@ class SpeechProfileProvider extends ChangeNotifier with MessageNotifierMixin imp updateLoadingText('Personalizing your experience...'); SharedPreferencesUtil().hasSpeakerProfile = true; - if (isFromOnboarding) { + if (_isFromOnboarding) { await createMemory(); captureProvider?.clearTranscripts(); } @@ -230,9 +212,11 @@ class SpeechProfileProvider extends ChangeNotifier with MessageNotifierMixin imp onAudioBytesReceived: (List value) { if (value.isEmpty) return; audioStorage.storeFramePacket(value); + value.removeRange(0, 3); - if (webSocketProvider?.wsConnectionState == WebsocketConnectionStatus.connected) { - webSocketProvider?.websocketChannel?.sink.add(value); + if (_socket?.state == SocketServiceState.connected) { + debugPrint("Socket stream ${value.length}"); + _socket?.send(value); } }, ); @@ -296,9 +280,7 @@ class SpeechProfileProvider extends ChangeNotifier with MessageNotifierMixin imp percentageCompleted = 0; uploadingProfile = false; profileCompleted = false; - await webSocketProvider?.closeWebSocketWithoutReconnect('closing'); - // TODO: thinh, socket check why? disable temporary - // await captureProvider?.resetState(restartBytesProcessing: true, isFromSpeechProfile: true); + await _socket?.stop(reason: 'closing'); notifyListeners(); } @@ -356,6 +338,7 @@ class SpeechProfileProvider extends ChangeNotifier with MessageNotifierMixin imp connectionStateListener?.cancel(); _bleBytesStream?.cancel(); forceCompletionTimer?.cancel(); + _socket?.unsubscribe(this); ServiceManager.instance().device.unsubscribe(this); super.dispose(); } @@ -387,4 +370,39 @@ class SpeechProfileProvider extends ChangeNotifier with MessageNotifierMixin imp @override void onStatusChanged(DeviceServiceStatus status) {} + + @override + void onClosed() { + // TODO: implement onClosed + } + + @override + void onError(Object err) { + notifyError('WS_ERR'); + } + + @override + void onMessageEventReceived(ServerMessageEvent event) { + // TODO: implement onMessageEventReceived + } + + @override + void onSegmentReceived(List newSegments) { + if (newSegments.isEmpty) return; + if (segments.isEmpty) { + audioStorage.removeFramesRange(fromSecond: 0, toSecond: newSegments[0].start.toInt()); + } + streamStartedAtSecond ??= newSegments[0].start; + + TranscriptSegment.combineSegments( + segments, + newSegments, + toRemoveSeconds: streamStartedAtSecond ?? 
0, + ); + updateProgressMessage(); + _validateSingleSpeaker(); + _handleCompletion(); + notifyInfo('SCROLL_DOWN'); + debugPrint('Memory creation timer restarted'); + } } diff --git a/app/lib/providers/websocket_provider.dart b/app/lib/providers/websocket_provider.dart index 1d138b34c..b316956ed 100644 --- a/app/lib/providers/websocket_provider.dart +++ b/app/lib/providers/websocket_provider.dart @@ -10,6 +10,7 @@ import 'package:friend_private/utils/websockets.dart'; import 'package:internet_connection_checker_plus/internet_connection_checker_plus.dart'; import 'package:web_socket_channel/io.dart'; +@Deprecated("Use the socket service") class WebSocketProvider with ChangeNotifier { WebsocketConnectionStatus wsConnectionState = WebsocketConnectionStatus.notConnected; bool websocketReconnecting = false; diff --git a/app/lib/services/sockets.dart b/app/lib/services/sockets.dart index cf6fa7510..4cf3e0a34 100644 --- a/app/lib/services/sockets.dart +++ b/app/lib/services/sockets.dart @@ -8,14 +8,14 @@ abstract class ISocketService { Future memory( {required BleAudioCodec codec, required int sampleRate, bool force = false}); - TranscripSegmentSocketService speechProfile(); + Future speechProfile( + {required BleAudioCodec codec, required int sampleRate, bool force = false}); } abstract interface class ISocketServiceSubsciption {} class SocketServicePool extends ISocketService { - TranscripSegmentSocketService? _memory; - TranscripSegmentSocketService? _speechProfile; + TranscripSegmentSocketService? _socket; @override void start() { @@ -24,50 +24,54 @@ class SocketServicePool extends ISocketService { @override void stop() async { - await _memory?.stop(); - await _speechProfile?.stop(); + await _socket?.stop(); } // Warn: Should use a better solution to prevent race conditions - bool memoryMutex = false; - @override - Future memory( + bool mutex = false; + Future socket( {required BleAudioCodec codec, required int sampleRate, bool force = false}) async { - while (memoryMutex) { + while (mutex) { await Future.delayed(const Duration(milliseconds: 50)); } - memoryMutex = true; - - debugPrint("socket memory > $codec $sampleRate $force"); + mutex = true; try { if (!force && - _memory?.codec == codec && - _memory?.sampleRate == sampleRate && - _memory?.state == SocketServiceState.connected) { - return _memory; + _socket?.codec == codec && + _socket?.sampleRate == sampleRate && + _socket?.state == SocketServiceState.connected) { + return _socket; } // new socket - await _memory?.stop(); + await _socket?.stop(); - _memory = MemoryTranscripSegmentSocketService.create(sampleRate, codec); - await _memory?.start(); - if (_memory?.state != SocketServiceState.connected) { + _socket = MemoryTranscripSegmentSocketService.create(sampleRate, codec); + await _socket?.start(); + if (_socket?.state != SocketServiceState.connected) { return null; } - return _memory; + return _socket; } finally { - memoryMutex = false; + mutex = false; } return null; } @override - TranscripSegmentSocketService speechProfile() { - // TODO: implement speechProfile - throw UnimplementedError(); + Future memory( + {required BleAudioCodec codec, required int sampleRate, bool force = false}) async { + debugPrint("socket memory > $codec $sampleRate $force"); + return await socket(codec: codec, sampleRate: sampleRate, force: force); + } + + @override + Future speechProfile( + {required BleAudioCodec codec, required int sampleRate, bool force = false}) async { + debugPrint("socket speech profile > $codec $sampleRate $force"); + return await 
socket(codec: codec, sampleRate: sampleRate, force: force); } } From 5f89dab4ec798a33894f9d1069b01371632adf05 Mon Sep 17 00:00:00 2001 From: Nik Shevchenko <43514161+kodjima33@users.noreply.github.com> Date: Thu, 19 Sep 2024 18:51:02 -0700 Subject: [PATCH 25/88] Update Flash_device.md --- docs/_get_started/Flash_device.md | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/_get_started/Flash_device.md b/docs/_get_started/Flash_device.md index 2f9bc4ba9..b2ee810ce 100644 --- a/docs/_get_started/Flash_device.md +++ b/docs/_get_started/Flash_device.md @@ -1,6 +1,6 @@ --- layout: default -title: Flashing FRIEND Firmware +title: Update FRIEND Firmware nav_order: 3 --- # Video Tutorial @@ -13,9 +13,13 @@ This guide will walk you through the process of flashing the latest firmware ont ## Downloading the Firmware -1. Go to the [FRIEND GitHub repository](https://github.com/BasedHardware/Omi) and navigate to the "Devices > FRIEND > firmware" section. +1. Go to the [FRIEND GitHub repository](https://github.com/BasedHardware/Omi) and navigate to the " FRIEND > firmware" section. 2. Find the latest firmware release and bootloader, then download the corresponding `.uf2` files. +Or download these files + - **Bootloader:** [bootloader0.9.0.uf2](https://github.com/ebowwa/omi/blob/firmware-flashing-readme/devices/Friend/firmware/bootloader/bootloader0.9.0.uf2) + - **Firmware:** [firmware1.0.4.uf2](https://github.com/ebowwa/omi/blob/firmware-flashing-readme/devices/Friend/firmware/firmware1.0.4.uf2) + ## Putting FRIEND into DFU Mode 1. **Locate the DFU Button:** Find the small pin-sized button on the FRIEND device's circuit board (refer to the image below if needed). @@ -45,4 +49,4 @@ You have successfully flashed the latest firmware onto your FRIEND device. You c Once you've installed the app, follow the in-app instructions to connect your FRIEND device and start exploring its features. -i just added this video to the repo docs/images/updating_your_friend.mov add it to this \ No newline at end of file +i just added this video to the repo docs/images/updating_your_friend.mov add it to this From 2b33e86140fe181ae07d7af06ca4ded51d242244 Mon Sep 17 00:00:00 2001 From: Nik Shevchenko <43514161+kodjima33@users.noreply.github.com> Date: Thu, 19 Sep 2024 18:52:20 -0700 Subject: [PATCH 26/88] Create install_firmware.md --- docs/_assembly/install_firmware.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 docs/_assembly/install_firmware.md diff --git a/docs/_assembly/install_firmware.md b/docs/_assembly/install_firmware.md new file mode 100644 index 000000000..fd7a3f0d8 --- /dev/null +++ b/docs/_assembly/install_firmware.md @@ -0,0 +1 @@ +We've moved! please navigate [here](https://docs.omi.me/get_started/Flash_device/) From 5919da03af8b382efa4e712755569bdb14ab4b85 Mon Sep 17 00:00:00 2001 From: Nik Shevchenko <43514161+kodjima33@users.noreply.github.com> Date: Thu, 19 Sep 2024 18:58:17 -0700 Subject: [PATCH 27/88] Update install_firmware.md --- docs/_assembly/install_firmware.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/_assembly/install_firmware.md b/docs/_assembly/install_firmware.md index fd7a3f0d8..3ac6a266b 100644 --- a/docs/_assembly/install_firmware.md +++ b/docs/_assembly/install_firmware.md @@ -1 +1,7 @@ +--- +layout: default +title: Install firmware (old) +nav_order: 3 +--- + We've moved! 
please navigate [here](https://docs.omi.me/get_started/Flash_device/) From c698c3cf4093eee23f729445bf5aa54c98c3d0a9 Mon Sep 17 00:00:00 2001 From: Nik Shevchenko <43514161+kodjima33@users.noreply.github.com> Date: Thu, 19 Sep 2024 19:05:18 -0700 Subject: [PATCH 28/88] Rename install_firmware.md to Install_firmware.md --- docs/_assembly/{install_firmware.md => Install_firmware.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename docs/_assembly/{install_firmware.md => Install_firmware.md} (100%) diff --git a/docs/_assembly/install_firmware.md b/docs/_assembly/Install_firmware.md similarity index 100% rename from docs/_assembly/install_firmware.md rename to docs/_assembly/Install_firmware.md From f8c612f1a299cc226d2f348a3a570f325582db05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Fri, 20 Sep 2024 09:05:29 +0700 Subject: [PATCH 29/88] Notify record profile changes --- app/lib/pages/capture/widgets/widgets.dart | 5 ++--- app/lib/pages/home/page.dart | 5 +++-- app/lib/providers/capture_provider.dart | 10 ++++++---- app/lib/providers/speech_profile_provider.dart | 5 ----- app/lib/utils/pure_socket.dart | 6 +++--- 5 files changed, 14 insertions(+), 17 deletions(-) diff --git a/app/lib/pages/capture/widgets/widgets.dart b/app/lib/pages/capture/widgets/widgets.dart index cd50d96fe..d131f5ae7 100644 --- a/app/lib/pages/capture/widgets/widgets.dart +++ b/app/lib/pages/capture/widgets/widgets.dart @@ -4,6 +4,7 @@ import 'package:friend_private/backend/schema/bt_device.dart'; import 'package:friend_private/backend/schema/transcript_segment.dart'; import 'package:friend_private/pages/capture/connect.dart'; import 'package:friend_private/pages/speech_profile/page.dart'; +import 'package:friend_private/providers/capture_provider.dart'; import 'package:friend_private/providers/connectivity_provider.dart'; import 'package:friend_private/providers/device_provider.dart'; import 'package:friend_private/providers/home_provider.dart'; @@ -227,9 +228,7 @@ class SpeechProfileCardWidget extends StatelessWidget { await routeToPage(context, const SpeechProfilePage()); if (hasSpeakerProfile != SharedPreferencesUtil().hasSpeakerProfile) { if (context.mounted) { - // TODO: is the websocket restarting once the user comes back? - // TODO: thinh, socket change settings - // context.read().restartWebSocket(); + context.read().onRecordProfileSettingChanged(); } } }, diff --git a/app/lib/pages/home/page.dart b/app/lib/pages/home/page.dart index 361d6ec94..2afa8183f 100644 --- a/app/lib/pages/home/page.dart +++ b/app/lib/pages/home/page.dart @@ -546,8 +546,9 @@ class _HomePageState extends State with WidgetsBindingObserver, Ticker if (language != SharedPreferencesUtil().recordingsLanguage || hasSpeech != SharedPreferencesUtil().hasSpeakerProfile || transcriptModel != SharedPreferencesUtil().transcriptionModel) { - // TODO: thinh, socket change settings - // context.read().restartWebSocket(); + if (context.mounted) { + context.read().onRecordProfileSettingChanged(); + } } }, ), diff --git a/app/lib/providers/capture_provider.dart b/app/lib/providers/capture_provider.dart index 29ec61883..3a9fcc548 100644 --- a/app/lib/providers/capture_provider.dart +++ b/app/lib/providers/capture_provider.dart @@ -332,6 +332,10 @@ class CaptureProvider extends ChangeNotifier } } + Future onRecordProfileSettingChanged() async { + await _resetState(restartBytesProcessing: true); + } + Future changeAudioRecordProfile([ BleAudioCodec? audioCodec, int? 
sampleRate, @@ -353,14 +357,13 @@ class CaptureProvider extends ChangeNotifier debugPrint('is ws null: ${_socket == null}'); - // TODO: thinh, socket + // Get memory socket _socket = await ServiceManager.instance().socket.memory(codec: codec, sampleRate: sampleRate, force: force); if (_socket == null) { throw Exception("Can not create new memory socket"); } - - // Ok _socket?.subscribe(this, this); + if (segments.isNotEmpty) { // means that it was a reconnection, so we need to reset streamStartedAtSecond = null; @@ -458,7 +461,6 @@ class CaptureProvider extends ChangeNotifier Future resetForSpeechProfile() async { closeBleStream(); - // TODO: thinh, socket, check why we need reset for speech profile here await _socket?.stop(reason: 'reset for speech profile'); setAudioBytesConnected(false); notifyListeners(); diff --git a/app/lib/providers/speech_profile_provider.dart b/app/lib/providers/speech_profile_provider.dart index a77191bd8..51c894b91 100644 --- a/app/lib/providers/speech_profile_provider.dart +++ b/app/lib/providers/speech_profile_provider.dart @@ -132,12 +132,7 @@ class SpeechProfileProvider extends ChangeNotifier if (_socket == null) { throw Exception("Can not create new speech profile socket"); } - - // Ok _socket?.subscribe(this, this); - print('Websocket connected in speech profile'); - // TODO: thinh, socket ? - //notifyListeners(); } _handleCompletion() async { diff --git a/app/lib/utils/pure_socket.dart b/app/lib/utils/pure_socket.dart index 461f4c74c..941ef189a 100644 --- a/app/lib/utils/pure_socket.dart +++ b/app/lib/utils/pure_socket.dart @@ -147,7 +147,7 @@ class PureSocket implements IPureSocket { @override Future disconnect() async { _status = PureSocketStatus.disconnected; - onClosed(); + onClosed(); await _cleanUp(); } @@ -283,8 +283,8 @@ class TranscripSegmentSocketService implements IPureSocketListener { this.newMemoryWatch = true, }) { final recordingsLanguage = SharedPreferencesUtil().recordingsLanguage; - var params = - '?language=$recordingsLanguage&sample_rate=$sampleRate&codec=$codec&uid=${SharedPreferencesUtil().uid}&include_speech_profile=$includeSpeechProfile&new_memory_watch=$newMemoryWatch'; + var params = '?language=$recordingsLanguage&sample_rate=$sampleRate&codec=$codec&uid=${SharedPreferencesUtil().uid}' + '&include_speech_profile=$includeSpeechProfile&new_memory_watch=$newMemoryWatch&stt_service=${SharedPreferencesUtil().transcriptionModel}'; String url = '${Env.apiBaseUrl!.replaceAll('https', 'wss')}listen$params'; _socket = PureSocket(url); From eed8a6b420823eda553976d096b3978a0f2f15f6 Mon Sep 17 00:00:00 2001 From: Nik Shevchenko <43514161+kodjima33@users.noreply.github.com> Date: Thu, 19 Sep 2024 19:11:32 -0700 Subject: [PATCH 30/88] Update Compile_firmware.md --- docs/_assembly/Compile_firmware.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/_assembly/Compile_firmware.md b/docs/_assembly/Compile_firmware.md index bd6d7278b..3677e208f 100644 --- a/docs/_assembly/Compile_firmware.md +++ b/docs/_assembly/Compile_firmware.md @@ -11,11 +11,10 @@ Important: If you purchased an assembled device please skip this step If you purchased an unassembled Friend device or built it yourself using our hardware guide, follow the steps below to flash the firmware: -### Official firmware: +### Want to install a pre-built firmware? Navigate [here](https://docs.omi.me/get_started/Flash_device/) -Go to [Releases](https://github.com/BasedHardware/Omi/releases) in Github, and find the latest official firmware release. 
To use this firmware, simply download it and skip to step 6. If you would like to build the firmware yourself please follow all the steps below. -### Build your own firmware: +## Build your own firmware: 1. Set up nRF Connect by following the tutorial in this video: [https://youtu.be/EAJdOqsL9m8](https://youtu.be/EAJdOqsL9m8) From d137b8a1264ebfcc6dec07e088df24fba3408a8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Fri, 20 Sep 2024 10:04:13 +0700 Subject: [PATCH 31/88] Notify user when internet connection lost --- app/lib/utils/pure_socket.dart | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/app/lib/utils/pure_socket.dart b/app/lib/utils/pure_socket.dart index 941ef189a..3c357e356 100644 --- a/app/lib/utils/pure_socket.dart +++ b/app/lib/utils/pure_socket.dart @@ -147,14 +147,21 @@ class PureSocket implements IPureSocket { @override Future disconnect() async { _status = PureSocketStatus.disconnected; + if (_status == PureSocketStatus.connected) { + // Warn: should not use await cause dead end by socket closed. + _channel?.sink.close(socket_channel_status.normalClosure); + } onClosed(); - await _cleanUp(); } Future _cleanUp() async { _internetLostDelayTimer?.cancel(); _internetStatusListener?.cancel(); - await _channel?.sink.close(socket_channel_status.normalClosure); + } + + Future stop() async { + await disconnect(); + await _cleanUp(); } @override @@ -187,6 +194,7 @@ class PureSocket implements IPureSocket { } void _reconnect() async { + debugPrint("[Socket] reconnect...${_retries+1}..."); const int initialBackoffTimeMs = 1000; // 1 second const double multiplier = 1.5; const int maxRetries = 7; @@ -217,6 +225,7 @@ class PureSocket implements IPureSocket { @override void onInternetSatusChanged(InternetStatus status) { + debugPrint("[Socket] Internet connection changed $status"); _internetStatus = status; switch (status) { case InternetStatus.connected: @@ -308,7 +317,7 @@ class TranscripSegmentSocketService implements IPureSocketListener { } Future stop({String? reason}) async { - await _socket.disconnect(); + await _socket.stop(); _listeners.clear(); if (reason != null) { From 21e038ccad326b491e0fa74215daf3c653c70303 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Fri, 20 Sep 2024 11:04:29 +0700 Subject: [PATCH 32/88] Move capture provider away from speech profile provider --- app/lib/main.dart | 6 ++--- app/lib/pages/speech_profile/page.dart | 23 +++++++++++++++++-- app/lib/providers/capture_provider.dart | 10 +++++--- app/lib/providers/device_provider.dart | 2 +- .../providers/speech_profile_provider.dart | 11 +++------ app/lib/utils/pure_socket.dart | 4 ++-- 6 files changed, 37 insertions(+), 19 deletions(-) diff --git a/app/lib/main.dart b/app/lib/main.dart index c8b6beb51..663084819 100644 --- a/app/lib/main.dart +++ b/app/lib/main.dart @@ -180,10 +180,10 @@ class _MyAppState extends State with WidgetsBindingObserver { (previous?..setDeviceProvider(value)) ?? OnboardingProvider(), ), ListenableProvider(create: (context) => HomeProvider()), - ChangeNotifierProxyProvider2( + ChangeNotifierProxyProvider( create: (context) => SpeechProfileProvider(), - update: (BuildContext context, device, capture, SpeechProfileProvider? previous) => - (previous?..setProviders(device, capture)) ?? SpeechProfileProvider(), + update: (BuildContext context, device, SpeechProfileProvider? previous) => + (previous?..setProviders(device)) ?? 
SpeechProfileProvider(), ), ChangeNotifierProxyProvider2( create: (context) => MemoryDetailProvider(), diff --git a/app/lib/pages/speech_profile/page.dart b/app/lib/pages/speech_profile/page.dart index bf46896d1..bdda4a00c 100644 --- a/app/lib/pages/speech_profile/page.dart +++ b/app/lib/pages/speech_profile/page.dart @@ -6,6 +6,7 @@ import 'package:friend_private/backend/preferences.dart'; import 'package:friend_private/backend/schema/bt_device.dart'; import 'package:friend_private/pages/home/page.dart'; import 'package:friend_private/pages/speech_profile/user_speech_samples.dart'; +import 'package:friend_private/providers/capture_provider.dart'; import 'package:friend_private/providers/speech_profile_provider.dart'; import 'package:friend_private/services/services.dart'; import 'package:friend_private/utils/other/temp.dart'; @@ -30,6 +31,7 @@ class _SpeechProfilePageState extends State with TickerProvid WidgetsBinding.instance.addPostFrameCallback((timeStamp) async { await context.read().updateDevice(); }); + super.initState(); } @@ -71,10 +73,15 @@ class _SpeechProfilePageState extends State with TickerProvid if (context.read().isInitialised) { WidgetsBinding.instance.addPostFrameCallback((timeStamp) async { await context.read().close(); + + // Restart device recording + if (mounted && context.mounted) { + await Provider.of(context, listen: false).streamDeviceRecording(restartBytesProcessing: true); + } }); } }, - child: Consumer(builder: (context, provider, child) { + child: Consumer2(builder: (context, provider, _, child) { return MessageListener( showInfo: (info) { if (info == 'SCROLL_DOWN') { @@ -320,12 +327,24 @@ class _SpeechProfilePageState extends State with TickerProvid ); return; } + + // Stop device recoding + if (mounted && context.mounted) { + await Provider.of(context, listen: false) + .stopStreamDeviceRecording(); + } + await provider.initialise(false); - // provider.initiateWebsocket(false); // 1.5 minutes seems reasonable provider.forceCompletionTimer = Timer(Duration(seconds: provider.maxDuration), () { provider.finalize(); + + // Restart device recording + if (mounted && context.mounted) { + Provider.of(context, listen: false) + .streamDeviceRecording(); + } }); provider.updateStartedRecording(true); }, diff --git a/app/lib/providers/capture_provider.dart b/app/lib/providers/capture_provider.dart index 3a9fcc548..1724f08ff 100644 --- a/app/lib/providers/capture_provider.dart +++ b/app/lib/providers/capture_provider.dart @@ -679,9 +679,13 @@ class CaptureProvider extends ChangeNotifier ); } - Future stopStreamDeviceRecording() async { - _updateRecordingDevice(null); - await _resetState(); + Future stopStreamDeviceRecording({bool cleanDevice = false}) async { + if (cleanDevice) { + _updateRecordingDevice(null); + } + _cleanupCurrentState(); + await _socket?.stop(reason: 'stop stream device recording'); + await _handleMemoryCreation(false); } // Socket handling diff --git a/app/lib/providers/device_provider.dart b/app/lib/providers/device_provider.dart index 3baa933e2..936594f8b 100644 --- a/app/lib/providers/device_provider.dart +++ b/app/lib/providers/device_provider.dart @@ -178,7 +178,7 @@ class DeviceProvider extends ChangeNotifier implements IDeviceServiceSubsciption setConnectedDevice(null); setIsConnected(false); updateConnectingStatus(false); - await captureProvider?.stopStreamDeviceRecording(); + await captureProvider?.stopStreamDeviceRecording(cleanDevice: true); captureProvider?.setAudioBytesConnected(false); print('after resetState inside 
initiateConnectionListener'); diff --git a/app/lib/providers/speech_profile_provider.dart b/app/lib/providers/speech_profile_provider.dart index 51c894b91..736f157e8 100644 --- a/app/lib/providers/speech_profile_provider.dart +++ b/app/lib/providers/speech_profile_provider.dart @@ -26,7 +26,6 @@ class SpeechProfileProvider extends ChangeNotifier with MessageNotifierMixin implements IDeviceServiceSubsciption, ITransctipSegmentSocketServiceListener { DeviceProvider? deviceProvider; - CaptureProvider? captureProvider; bool? permissionEnabled; bool loading = false; BTDeviceStruct? device; @@ -76,9 +75,8 @@ class SpeechProfileProvider extends ChangeNotifier notifyListeners(); } - void setProviders(DeviceProvider provider, CaptureProvider captureProvider) { + void setProviders(DeviceProvider provider) { deviceProvider = provider; - this.captureProvider = captureProvider; notifyListeners(); } @@ -94,10 +92,8 @@ class SpeechProfileProvider extends ChangeNotifier _isFromOnboarding = isFromOnboarding; setInitialising(true); device = deviceProvider?.connectedDevice; - await captureProvider!.resetForSpeechProfile(); await _initiateWebsocket(force: true); - // _bleBytesStream = captureProvider?.bleBytesStream; if (device != null) await initiateFriendAudioStreaming(); if (_socket?.state != SocketServiceState.connected) { // wait for websocket to connect @@ -179,9 +175,9 @@ class SpeechProfileProvider extends ChangeNotifier SharedPreferencesUtil().hasSpeakerProfile = true; if (_isFromOnboarding) { await createMemory(); - captureProvider?.clearTranscripts(); + // TODO: thinh, socket + //captureProvider?.clearTranscripts(); } - await captureProvider?.streamDeviceRecording(restartBytesProcessing: true); uploadingProfile = false; profileCompleted = true; text = ''; @@ -210,7 +206,6 @@ class SpeechProfileProvider extends ChangeNotifier value.removeRange(0, 3); if (_socket?.state == SocketServiceState.connected) { - debugPrint("Socket stream ${value.length}"); _socket?.send(value); } }, diff --git a/app/lib/utils/pure_socket.dart b/app/lib/utils/pure_socket.dart index 3c357e356..33189e4d3 100644 --- a/app/lib/utils/pure_socket.dart +++ b/app/lib/utils/pure_socket.dart @@ -146,11 +146,11 @@ class PureSocket implements IPureSocket { @override Future disconnect() async { - _status = PureSocketStatus.disconnected; if (_status == PureSocketStatus.connected) { // Warn: should not use await cause dead end by socket closed. 
_channel?.sink.close(socket_channel_status.normalClosure); } + _status = PureSocketStatus.disconnected; onClosed(); } @@ -194,7 +194,7 @@ } void _reconnect() async { - debugPrint("[Socket] reconnect...${_retries+1}..."); + debugPrint("[Socket] reconnect...${_retries + 1}..."); const int initialBackoffTimeMs = 1000; // 1 second const double multiplier = 1.5; const int maxRetries = 7; From f8654cc578950863014613e6e2ede76cf0a7f2f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Fri, 20 Sep 2024 13:07:36 +0700 Subject: [PATCH 33/88] Move capture provider away from speech profile provider # onboarding --- app/lib/pages/home/page.dart | 4 + .../onboarding/speech_profile_widget.dart | 34 +++++++- app/lib/pages/speech_profile/page.dart | 36 ++++---- app/lib/providers/onboarding_provider.dart | 10 +++ .../providers/speech_profile_provider.dart | 82 ++++++++++--------- 5 files changed, 109 insertions(+), 57 deletions(-) diff --git a/app/lib/pages/home/page.dart b/app/lib/pages/home/page.dart index 2afa8183f..dde45dcb6 100644 --- a/app/lib/pages/home/page.dart +++ b/app/lib/pages/home/page.dart @@ -134,6 +134,10 @@ class _HomePageState extends State with WidgetsBindingObserver, Ticker ForegroundUtil.startForegroundTask(); if (mounted) { await context.read().setUserPeople(); + + // Start stream recording + await Provider.of(context, listen: false) + .streamDeviceRecording(device: context.read().connectedDevice); } }); diff --git a/app/lib/pages/onboarding/speech_profile_widget.dart b/app/lib/pages/onboarding/speech_profile_widget.dart index 01712aeac..e05a4f8b8 100644 --- a/app/lib/pages/onboarding/speech_profile_widget.dart +++ b/app/lib/pages/onboarding/speech_profile_widget.dart @@ -3,7 +3,10 @@ import 'dart:async'; import 'package:flutter/material.dart'; import 'package:flutter_provider_utilities/flutter_provider_utilities.dart'; import 'package:friend_private/backend/preferences.dart'; +import 'package:friend_private/backend/schema/bt_device.dart'; +import 'package:friend_private/providers/capture_provider.dart'; import 'package:friend_private/providers/speech_profile_provider.dart'; +import 'package:friend_private/services/services.dart'; import 'package:friend_private/widgets/dialog.dart'; import 'package:gradient_borders/box_borders/gradient_box_border.dart'; import 'package:provider/provider.dart'; @@ -49,13 +52,35 @@ class _SpeechProfileWidgetState extends State with TickerPr @override Widget build(BuildContext context) { + Future restartDeviceRecording() async { + debugPrint("restartDeviceRecording $mounted"); + + // Restart device recording, clear transcripts + if (mounted) { + Provider.of(context, listen: false).clearTranscripts(); + Provider.of(context, listen: false).streamDeviceRecording( + device: Provider.of(context, listen: false).deviceProvider?.connectedDevice, + ); + } + } + + Future stopDeviceRecording() async { + debugPrint("stopDeviceRecording $mounted"); + + // Stop device recording + if (mounted) { + await Provider.of(context, listen: false).stopStreamDeviceRecording(); + } + } + return PopScope( canPop: true, - onPopInvoked: (didPop) { + onPopInvoked: (didPop) async { context.read().close(); + restartDeviceRecording(); }, - child: Consumer( - builder: (context, provider, child) { + child: Consumer2( + builder: (context, provider, _, child) { return MessageListener( showInfo: (info) { if (info == 'SCROLL_DOWN') { @@ -204,7 +229,8 @@ class _SpeechProfileWidgetState extends State with TickerPr ), 
child: TextButton( onPressed: () async { - await provider.initialise(true); + await stopDeviceRecording(); + await provider.initialise(true, finalizedCallback: restartDeviceRecording); provider.forceCompletionTimer = Timer(Duration(seconds: provider.maxDuration), () async { provider.finalize(); diff --git a/app/lib/pages/speech_profile/page.dart b/app/lib/pages/speech_profile/page.dart index bdda4a00c..198c27234 100644 --- a/app/lib/pages/speech_profile/page.dart +++ b/app/lib/pages/speech_profile/page.dart @@ -67,6 +67,23 @@ class _SpeechProfilePageState extends State with TickerProvid @override Widget build(BuildContext context) { + Future restartDeviceRecording() async { + debugPrint("restartDeviceRecording $mounted"); + if (mounted) { + Provider.of(context, listen: false).clearTranscripts(); + Provider.of(context, listen: false).streamDeviceRecording( + device: Provider.of(context, listen: false).deviceProvider?.connectedDevice, + ); + } + } + + Future stopDeviceRecording() async { + debugPrint("stopDeviceRecording $mounted"); + if (mounted) { + await Provider.of(context, listen: false).stopStreamDeviceRecording(); + } + } + return PopScope( canPop: true, onPopInvoked: (didPop) { @@ -75,9 +92,7 @@ class _SpeechProfilePageState extends State with TickerProvid await context.read().close(); // Restart device recording - if (mounted && context.mounted) { - await Provider.of(context, listen: false).streamDeviceRecording(restartBytesProcessing: true); - } + restartDeviceRecording(); }); } }, @@ -328,23 +343,12 @@ class _SpeechProfilePageState extends State with TickerProvid return; } - // Stop device recoding - if (mounted && context.mounted) { - await Provider.of(context, listen: false) - .stopStreamDeviceRecording(); - } - - await provider.initialise(false); + await stopDeviceRecording(); + await provider.initialise(false, finalizedCallback: restartDeviceRecording); // 1.5 minutes seems reasonable provider.forceCompletionTimer = Timer(Duration(seconds: provider.maxDuration), () { provider.finalize(); - - // Restart device recording - if (mounted && context.mounted) { - Provider.of(context, listen: false) - .streamDeviceRecording(); - } }); provider.updateStartedRecording(true); }, diff --git a/app/lib/providers/onboarding_provider.dart b/app/lib/providers/onboarding_provider.dart index 70d7a183d..d7b90723e 100644 --- a/app/lib/providers/onboarding_provider.dart +++ b/app/lib/providers/onboarding_provider.dart @@ -179,6 +179,7 @@ class OnboardingProvider extends BaseProvider with MessageNotifierMixin implemen deviceProvider!.setConnectedDevice(cDevice); SharedPreferencesUtil().btDeviceStruct = cDevice; SharedPreferencesUtil().deviceName = cDevice.name; + SharedPreferencesUtil().deviceCodec = await _getAudioCodec(device.id); deviceProvider!.setIsConnected(true); } //TODO: should'nt update codec here, becaause then the prev connection codec and the current codec will @@ -196,6 +197,7 @@ class OnboardingProvider extends BaseProvider with MessageNotifierMixin implemen await Future.delayed(const Duration(seconds: 2)); SharedPreferencesUtil().btDeviceStruct = connectedDevice!; SharedPreferencesUtil().deviceName = connectedDevice.name; + SharedPreferencesUtil().deviceCodec = await _getAudioCodec(device.id); foundDevicesMap.clear(); deviceList.clear(); if (isFromOnboarding) { @@ -250,6 +252,14 @@ class OnboardingProvider extends BaseProvider with MessageNotifierMixin implemen return connection?.device; } + Future _getAudioCodec(String deviceId) async { + var connection = await 
ServiceManager.instance().device.ensureConnection(deviceId); + if (connection == null) { + return BleAudioCodec.pcm8; + } + return connection.getAudioCodec(); + } + @override void dispose() { //TODO: This does not get called when the page is popped diff --git a/app/lib/providers/speech_profile_provider.dart b/app/lib/providers/speech_profile_provider.dart index 736f157e8..8efec243b 100644 --- a/app/lib/providers/speech_profile_provider.dart +++ b/app/lib/providers/speech_profile_provider.dart @@ -53,6 +53,7 @@ class SpeechProfileProvider extends ChangeNotifier String message = ''; late bool _isFromOnboarding; + late Function? _finalizedCallback; /// only used during onboarding ///// String loadingText = 'Uploading your voice profile....'; @@ -88,8 +89,9 @@ class SpeechProfileProvider extends ChangeNotifier notifyListeners(); } - Future initialise(bool isFromOnboarding) async { + Future initialise(bool isFromOnboarding, {Function? finalizedCallback}) async { _isFromOnboarding = isFromOnboarding; + _finalizedCallback = finalizedCallback; setInitialising(true); device = deviceProvider?.connectedDevice; await _initiateWebsocket(force: true); @@ -144,45 +146,49 @@ class SpeechProfileProvider extends ChangeNotifier } Future finalize() async { - if (uploadingProfile || profileCompleted) return; - - int duration = segments.isEmpty ? 0 : segments.last.end.toInt(); - if (duration < 5 || duration > 120) { - notifyError('INVALID_RECORDING'); - } + try { + if (uploadingProfile || profileCompleted) return; - String text = segments.map((e) => e.text).join(' ').trim(); - if (text.split(' ').length < (targetWordsCount / 2)) { - // 25 words - notifyError('TOO_SHORT'); - } - uploadingProfile = true; - notifyListeners(); - await _socket?.stop(reason: 'finalizing'); - forceCompletionTimer?.cancel(); - connectionStateListener?.cancel(); - _bleBytesStream?.cancel(); + int duration = segments.isEmpty ? 
0 : segments.last.end.toInt(); + if (duration < 5 || duration > 120) { + notifyError('INVALID_RECORDING'); + } - updateLoadingText('Memorizing your voice...'); - List> raw = List.from(audioStorage.rawPackets); - var data = await audioStorage.createWavFile(filename: 'speaker_profile.wav'); - try { - await uploadProfile(data.item1); - await uploadProfileBytes(raw, duration); - } catch (e) {} - - updateLoadingText('Personalizing your experience...'); - SharedPreferencesUtil().hasSpeakerProfile = true; - if (_isFromOnboarding) { - await createMemory(); - // TODO: thinh, socket - //captureProvider?.clearTranscripts(); + String text = segments.map((e) => e.text).join(' ').trim(); + if (text.split(' ').length < (targetWordsCount / 2)) { + // 25 words + notifyError('TOO_SHORT'); + } + uploadingProfile = true; + notifyListeners(); + await _socket?.stop(reason: 'finalizing'); + forceCompletionTimer?.cancel(); + connectionStateListener?.cancel(); + _bleBytesStream?.cancel(); + + updateLoadingText('Memorizing your voice...'); + List> raw = List.from(audioStorage.rawPackets); + var data = await audioStorage.createWavFile(filename: 'speaker_profile.wav'); + try { + await uploadProfile(data.item1); + await uploadProfileBytes(raw, duration); + } catch (e) {} + + updateLoadingText('Personalizing your experience...'); + SharedPreferencesUtil().hasSpeakerProfile = true; + if (_isFromOnboarding) { + await createMemory(); + } + uploadingProfile = false; + profileCompleted = true; + text = ''; + updateLoadingText("You're all set!"); + notifyListeners(); + } finally { + if (_finalizedCallback != null) { + _finalizedCallback!(); + } } - uploadingProfile = false; - profileCompleted = true; - text = ''; - updateLoadingText("You're all set!"); - notifyListeners(); } // TODO: use connection directly @@ -328,8 +334,10 @@ class SpeechProfileProvider extends ChangeNotifier connectionStateListener?.cancel(); _bleBytesStream?.cancel(); forceCompletionTimer?.cancel(); + _finalizedCallback = null; _socket?.unsubscribe(this); ServiceManager.instance().device.unsubscribe(this); + super.dispose(); } From bb1105371db5a1b01e5138f04d760f723b3342d4 Mon Sep 17 00:00:00 2001 From: Joan Cabezas Date: Thu, 19 Sep 2024 23:41:22 -0700 Subject: [PATCH 34/88] speechmatics model recognizes speaker id --- backend/routers/transcribe.py | 32 ++++++++++++++++++++++++++++---- backend/utils/other/storage.py | 3 ++- backend/utils/stt/streaming.py | 30 ++++++++++++++++++++++++++---- 3 files changed, 56 insertions(+), 9 deletions(-) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index a188f1aca..ed594df50 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -7,6 +7,7 @@ import requests from fastapi import APIRouter from fastapi.websockets import WebSocketDisconnect, WebSocket +from pydub import AudioSegment from starlette.websockets import WebSocketState import database.memories as memories_db @@ -205,6 +206,7 @@ def stream_audio(audio_buffer): soniox_socket = None speechmatics_socket = None + speechmatics_socket2 = None deepgram_socket = None deepgram_socket2 = None @@ -240,9 +242,20 @@ def stream_audio(audio_buffer): stream_transcript, speech_profile_stream_id, language, uid if include_speech_profile else None ) elif stt_service == STTService.speechmatics: + file_path = None + if language == 'en' and codec == 'opus' and include_speech_profile: + file_path = get_profile_audio_if_exists(uid) + duration = AudioSegment.from_wav(file_path).duration_seconds + 5 if file_path else 0 + 
speechmatics_socket = await process_audio_speechmatics( - stream_transcript, speech_profile_stream_id, language, uid if include_speech_profile else None + stream_transcript, speech_profile_stream_id, language, preseconds=duration ) + if duration: + # speechmatics_socket2 = await process_audio_speechmatics( + # stream_transcript, speech_profile_stream_id, language, preseconds=duration + # ) + await send_initial_file_path(file_path, speechmatics_socket) + print('speech_profile speechmatics duration', duration) except Exception as e: print(f"Initial processing error: {e}") @@ -255,7 +268,7 @@ def stream_audio(audio_buffer): decoder = opuslib.Decoder(sample_rate, channels) - async def receive_audio(dg_socket1, dg_socket2, soniox_socket, speechmatics_socket): + async def receive_audio(dg_socket1, dg_socket2, soniox_socket, speechmatics_socket1): nonlocal websocket_active nonlocal websocket_close_code nonlocal timer_start @@ -279,9 +292,18 @@ async def receive_audio(dg_socket1, dg_socket2, soniox_socket, speechmatics_sock decoded_opus = decoder.decode(bytes(data), frame_size=160) await soniox_socket.send(decoded_opus) - if speechmatics_socket is not None: + if speechmatics_socket1 is not None: decoded_opus = decoder.decode(bytes(data), frame_size=160) - await speechmatics_socket.send(decoded_opus) + await speechmatics_socket1.send(decoded_opus) + + # elapsed_seconds = time.time() - timer_start + # if elapsed_seconds > duration or not dg_socket2: + # if speechmatics_socket2: + # print('Killing socket2 speechmatics') + # speechmatics_socket2.close() + # speechmatics_socket2 = None + # else: + # speechmatics_socket2.send(decoded_opus) if deepgram_socket is not None: elapsed_seconds = time.time() - timer_start @@ -314,6 +336,8 @@ async def receive_audio(dg_socket1, dg_socket2, soniox_socket, speechmatics_sock await soniox_socket.close() if speechmatics_socket: await speechmatics_socket.close() + if speechmatics_socket2: + await speechmatics_socket2.close() # heart beat async def send_heartbeat(): diff --git a/backend/utils/other/storage.py b/backend/utils/other/storage.py index 7eb78b2c9..b0e13b915 100644 --- a/backend/utils/other/storage.py +++ b/backend/utils/other/storage.py @@ -149,6 +149,7 @@ def create_signed_postprocessing_audio_url(file_path: str): blob = bucket.blob(file_path) return _get_signed_url(blob, 15) + def upload_postprocessing_audio_bytes(file_path: str, audio_buffer: bytes): bucket = storage_client.bucket(postprocessing_audio_bucket) blob = bucket.blob(file_path) @@ -162,13 +163,13 @@ def upload_sdcard_audio(file_path: str): blob.upload_from_filename(file_path) return f'https://storage.googleapis.com/{postprocessing_audio_bucket}/sdcard/{file_path}' + def download_postprocessing_audio(file_path: str, destination_file_path: str): bucket = storage_client.bucket(postprocessing_audio_bucket) blob = bucket.blob(file_path) blob.download_to_filename(destination_file_path) - # ************************************************ # ************* MEMORIES RECORDINGS ************** # ************************************************ diff --git a/backend/utils/stt/streaming.py b/backend/utils/stt/streaming.py index 303ff771e..03f96ec18 100644 --- a/backend/utils/stt/streaming.py +++ b/backend/utils/stt/streaming.py @@ -61,6 +61,22 @@ # return segments +async def send_initial_file_path(file_path: str, transcript_socket): + print('send_initial_file_path') + start = time.time() + # Reading and sending in chunks + with open(file_path, "rb") as file: + while True: + chunk = file.read(320) + if not 
chunk: + break + # print('Uploading', len(chunk)) + await transcript_socket.send(bytes(chunk)) + await asyncio.sleep(0.0001) # if it takes too long to transcribe + + print('send_initial_file_path', time.time() - start) + + async def send_initial_file(data: List[List[int]], transcript_socket): print('send_initial_file2') start = time.time() @@ -283,7 +299,7 @@ async def on_message(): CONNECTION_URL = f"wss://eu2.rt.speechmatics.com/v2" -async def process_audio_speechmatics(stream_transcript, stream_id: int, language: str, uid: str): +async def process_audio_speechmatics(stream_transcript, stream_id: int, language: str, preseconds: int = 0): # Create a transcription client api_key = os.getenv('SPEECHMATICS_API_KEY') uri = 'wss://eu2.rt.speechmatics.com/v2' @@ -334,9 +350,10 @@ async def on_message(): continue segments = [] for r in results: - # print(r) + print(r) if not r['alternatives']: continue + r_data = r['alternatives'][0] r_type = r['type'] # word | punctuation r_start = r['start_time'] @@ -349,6 +366,11 @@ async def on_message(): continue r_speaker = r_data['speaker'][1:] if r_data['speaker'] != 'UU' else '1' speaker = f"SPEAKER_0{r_speaker}" + + is_user = True if r_speaker == '1' and preseconds > 0 else False + if r_start < preseconds: + print('Skipping word', r_start, r_content) + continue # print(r_content, r_speaker, [r_start, r_end]) if not segments: segments.append({ @@ -356,7 +378,7 @@ async def on_message(): 'start': r_start, 'end': r_end, 'text': r_content, - 'is_user': False, + 'is_user': is_user, 'person_id': None, }) else: @@ -370,7 +392,7 @@ async def on_message(): 'start': r_start, 'end': r_end, 'text': r_content, - 'is_user': False, + 'is_user': is_user, 'person_id': None, }) From 819188fd5fb8bb6b26c27bfff249adcdd4df5873 Mon Sep 17 00:00:00 2001 From: Joan Cabezas Date: Fri, 20 Sep 2024 00:35:11 -0700 Subject: [PATCH 35/88] whisper x postprocessing extra logs --- backend/routers/transcribe.py | 3 ++- backend/utils/memories/postprocess_memory.py | 18 +++++++++++------- backend/utils/stt/vad.py | 2 ++ 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index ed594df50..681e8e898 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -133,10 +133,11 @@ async def _websocket_util( if stt_service == STTService.soniox and ( sample_rate != 16000 or codec != 'opus' or language not in soniox_valid_languages): stt_service = STTService.deepgram - if stt_service == STTService.speechmatics and (sample_rate != 16000 or codec != 'opus'): stt_service = STTService.deepgram + # At some point try running all the models together to easily compare + # Check: Why do we need try-catch around websocket.accept? 
try: await websocket.accept() diff --git a/backend/utils/memories/postprocess_memory.py b/backend/utils/memories/postprocess_memory.py index 4cb1670cd..be9762f8a 100644 --- a/backend/utils/memories/postprocess_memory.py +++ b/backend/utils/memories/postprocess_memory.py @@ -19,32 +19,36 @@ def postprocess_memory(memory_id: str, file_path: str, uid: str, emotional_feedback: bool, streaming_model: str): memory_data = _get_memory_by_id(uid, memory_id) if not memory_data: - return (404, "Memory not found") + return 404, "Memory not found" memory = Memory(**memory_data) if memory.discarded: print('postprocess_memory: Memory is discarded') - return (400, "Memory is discarded") + return 400, "Memory is discarded" if memory.postprocessing is not None and memory.postprocessing.status != PostProcessingStatus.not_started: print(f'postprocess_memory: Memory can\'t be post-processed again {memory.postprocessing.status}') - return (400, "Memory can't be post-processed again") + return 400, "Memory can't be post-processed again" aseg = AudioSegment.from_wav(file_path) if aseg.duration_seconds < 10: # TODO: validate duration more accurately, segment.last.end - segment.first.start - 10 # TODO: fix app, sometimes audio uploaded is wrong, is too short. print('postprocess_memory: Audio duration is too short, seems wrong.') memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.canceled) - return (500, "Audio duration is too short, seems wrong.") + return 500, "Audio duration is too short, seems wrong." memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.in_progress) try: # Calling VAD to avoid processing empty parts and getting hallucinations from whisper. + # TODO: use these logs to determine whether whisperx is failing because of the VAD results. 
+ print('previous to vad_is_empty (segments duration):', + memory.transcript_segments[-1].end - memory.transcript_segments[0].start) vad_segments = vad_is_empty(file_path, return_segments=True) if vad_segments: start = vad_segments[0]['start'] end = vad_segments[-1]['end'] + print('vad_is_empty file result segments:', start, end) aseg = AudioSegment.from_wav(file_path) aseg = aseg[max(0, (start - 1) * 1000):min((end + 1) * 1000, aseg.duration_seconds * 1000)] aseg.export(file_path, format="wav") @@ -90,7 +94,7 @@ def postprocess_memory(memory_id: str, file_path: str, uid: str, emotional_feedb memory.postprocessing = MemoryPostProcessing( status=PostProcessingStatus.failed, model=PostProcessingModel.fal_whisperx) # TODO: consider doing process_memory, if any segment still matched to user or people - return (200, memory) + return 200, memory # Reprocess memory with improved transcription result: Memory = process_memory(uid, memory.language, memory, force_process=True) @@ -101,13 +105,13 @@ def postprocess_memory(memory_id: str, file_path: str, uid: str, emotional_feedb except Exception as e: print(e) memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.failed, fail_reason=str(e)) - return (500, str(e)) + return 500, str(e) memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.completed) result.postprocessing = MemoryPostProcessing( status=PostProcessingStatus.completed, model=PostProcessingModel.fal_whisperx) - return (200, result) + return 200, result def _get_memory_by_id(uid: str, memory_id: str) -> dict: diff --git a/backend/utils/stt/vad.py b/backend/utils/stt/vad.py index a5ca374d8..9870cbacf 100644 --- a/backend/utils/stt/vad.py +++ b/backend/utils/stt/vad.py @@ -38,6 +38,8 @@ def is_audio_empty(file_path, sample_rate=8000): def vad_is_empty(file_path, return_segments: bool = False): """Uses vad_modal/vad.py deployment (Best quality)""" try: + file_duration = AudioSegment.from_wav(file_path).duration_seconds + print('vad_is_empty file duration:', file_duration) with open(file_path, 'rb') as file: files = {'file': (file_path.split('/')[-1], file, 'audio/wav')} response = requests.post(os.getenv('HOSTED_VAD_API_URL'), files=files) From 4344e29ab6f8a454c9e499fb28383fb432ba8ac5 Mon Sep 17 00:00:00 2001 From: Joan Cabezas Date: Fri, 20 Sep 2024 00:39:29 -0700 Subject: [PATCH 36/88] commented `get_people_with_speech_samples` speech profile endpoint --- backend/modal/speech_profile_modal.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/modal/speech_profile_modal.py b/backend/modal/speech_profile_modal.py index 89ed38513..521b3be68 100644 --- a/backend/modal/speech_profile_modal.py +++ b/backend/modal/speech_profile_modal.py @@ -130,8 +130,8 @@ def endpoint(uid: str, audio_file: UploadFile = File(...), segments: str = Form( segments_data = json.loads(segments) transcript_segments = [TranscriptSegment(**segment) for segment in segments_data] - people = get_people_with_speech_samples(uid) - + # people = get_people_with_speech_samples(uid) + people = [] try: result = classify_segments(audio_file.filename, profile_path, people, transcript_segments) print(result) From ea92c2cef4e15cf90dd00cdbef134df2d1e72ce0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Fri, 20 Sep 2024 14:44:21 +0700 Subject: [PATCH 37/88] With keep-alived provider --- app/lib/providers/capture_provider.dart | 25 +++++++++++++++++++++++++ app/lib/utils/pure_socket.dart | 2 +- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git 
a/app/lib/providers/capture_provider.dart b/app/lib/providers/capture_provider.dart index 1724f08ff..5148feabf 100644 --- a/app/lib/providers/capture_provider.dart +++ b/app/lib/providers/capture_provider.dart @@ -41,6 +41,8 @@ class CaptureProvider extends ChangeNotifier MessageProvider? messageProvider; TranscripSegmentSocketService? _socket; + Timer? _keepAliveTimer; + void updateProviderInstances(MemoryProvider? mp, MessageProvider? p) { memoryProvider = mp; messageProvider = p; @@ -636,6 +638,7 @@ class CaptureProvider extends ChangeNotifier _bleBytesStream?.cancel(); _memoryCreationTimer?.cancel(); _socket?.unsubscribe(this); + _keepAliveTimer?.cancel(); super.dispose(); } @@ -700,12 +703,34 @@ class CaptureProvider extends ChangeNotifier setMemoryCreating(false); setHasTranscripts(false); notifyListeners(); + + // Keep alived + _startKeepAlivedServices(); + } + + void _startKeepAlivedServices() { + if (_recordingDevice != null && _socket?.state != SocketServiceState.connected) { + _keepAliveTimer?.cancel(); + _keepAliveTimer = Timer.periodic(const Duration(seconds: 15), (t) async { + debugPrint("[Provider] keep alived..."); + + if (_recordingDevice == null || _socket?.state == SocketServiceState.connected) { + t.cancel(); + return; + } + + await _initiateWebsocket(); + }); + } } @override void onError(Object err) { debugPrint('err: $err'); notifyListeners(); + + // Keep alived + _startKeepAlivedServices(); } @override diff --git a/app/lib/utils/pure_socket.dart b/app/lib/utils/pure_socket.dart index 33189e4d3..5660212e7 100644 --- a/app/lib/utils/pure_socket.dart +++ b/app/lib/utils/pure_socket.dart @@ -225,7 +225,7 @@ class PureSocket implements IPureSocket { @override void onInternetSatusChanged(InternetStatus status) { - debugPrint("[Socket] Internet connection changed $status"); + debugPrint("[Socket] Internet connection changed $status socket $_status"); _internetStatus = status; switch (status) { case InternetStatus.connected: From 92a73f9593c24de98639917dca1e4cc36fe4ff5a Mon Sep 17 00:00:00 2001 From: Joan Cabezas Date: Fri, 20 Sep 2024 02:05:03 -0700 Subject: [PATCH 38/88] transcript segments performance script --- backend/models/transcript_segment.py | 40 ++++++ backend/routers/transcribe.py | 43 +------ .../stt/k_compare_transcripts_performance.py | 121 ++++++++++++++++-- backend/utils/stt/streaming.py | 5 +- 4 files changed, 158 insertions(+), 51 deletions(-) diff --git a/backend/models/transcript_segment.py b/backend/models/transcript_segment.py index 6fbb99d0a..617de0dd2 100644 --- a/backend/models/transcript_segment.py +++ b/backend/models/transcript_segment.py @@ -42,6 +42,46 @@ def can_display_seconds(segments): return False return True + @staticmethod + def combine_segments(segments: [], new_segments: [], delta_seconds: int = 0): + if not new_segments or len(new_segments) == 0: + return segments + + joined_similar_segments = [] + for new_segment in new_segments: + if delta_seconds > 0: + new_segment.start += delta_seconds + new_segment.end += delta_seconds + + if (joined_similar_segments and + (joined_similar_segments[-1].speaker == new_segment.speaker or + (joined_similar_segments[-1].is_user and new_segment.is_user))): + joined_similar_segments[-1].text += f' {new_segment.text}' + joined_similar_segments[-1].end = new_segment.end + else: + joined_similar_segments.append(new_segment) + + if (segments and + (segments[-1].speaker == joined_similar_segments[0].speaker or + (segments[-1].is_user and joined_similar_segments[0].is_user)) and + 
(joined_similar_segments[0].start - segments[-1].end < 30)): + segments[-1].text += f' {joined_similar_segments[0].text}' + segments[-1].end = joined_similar_segments[0].end + joined_similar_segments.pop(0) + + segments.extend(joined_similar_segments) + + # Speechmatics specific issue with punctuation + for i, segment in enumerate(segments): + segments[i].text = ( + segments[i].text.strip() + .replace(' ', '') + .replace(' ,', ',') + .replace(' .', '.') + .replace(' ?', '?') + ) + return segments + class ImprovedTranscriptSegment(BaseModel): speaker_id: int = Field(..., description='The correctly assigned speaker id') diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 681e8e898..bf10897f3 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -68,45 +68,6 @@ # def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: # -def _combine_segments(segments: [], new_segments: [], delta_seconds: int = 0): - if not new_segments or len(new_segments) == 0: - return segments - - joined_similar_segments = [] - for new_segment in new_segments: - if delta_seconds > 0: - new_segment.start += delta_seconds - new_segment.end += delta_seconds - - if (joined_similar_segments and - (joined_similar_segments[-1].speaker == new_segment.speaker or - (joined_similar_segments[-1].is_user and new_segment.is_user))): - joined_similar_segments[-1].text += f' {new_segment.text}' - joined_similar_segments[-1].end = new_segment.end - else: - joined_similar_segments.append(new_segment) - - if (segments and - (segments[-1].speaker == joined_similar_segments[0].speaker or - (segments[-1].is_user and joined_similar_segments[0].is_user)) and - (joined_similar_segments[0].start - segments[-1].end < 30)): - segments[-1].text += f' {joined_similar_segments[0].text}' - segments[-1].end = joined_similar_segments[0].end - joined_similar_segments.pop(0) - - segments.extend(joined_similar_segments) - - # Speechmatics specific issue with punctuation - for i, segment in enumerate(segments): - segments[i].text = ( - segments[i].text.strip() - .replace(' ', '') - .replace(' ,', ',') - .replace(' .', '.') - .replace(' ?', '?') - ) - return segments - class STTService(str, Enum): deepgram = "deepgram" @@ -187,7 +148,7 @@ def stream_transcript(segments, stream_id): delta_seconds = 0 if processing_memory and processing_memory.timer_start > 0: delta_seconds = timer_start - processing_memory.timer_start - memory_transcript_segements = _combine_segments( + memory_transcript_segements = TranscriptSegment.combine_segments( memory_transcript_segements, list(map(lambda m: TranscriptSegment(**m), segments)), delta_seconds ) @@ -412,7 +373,7 @@ async def _create_processing_memory(): processing_memory.timer_starts.append(timer_start) # Transcript with delta - memory_transcript_segements = _combine_segments( + memory_transcript_segements = TranscriptSegment.combine_segments( processing_memory.transcript_segments, memory_transcript_segements, timer_start - processing_memory.timer_start ) diff --git a/backend/scripts/stt/k_compare_transcripts_performance.py b/backend/scripts/stt/k_compare_transcripts_performance.py index 7f26a6914..bdb57c2d8 100644 --- a/backend/scripts/stt/k_compare_transcripts_performance.py +++ b/backend/scripts/stt/k_compare_transcripts_performance.py @@ -1,17 +1,122 @@ # STEPS -# - get all users -# - get all memories non discarded -# - filter memories with audio recording available -# - filter again by ones that have whisperx + deepgram segments or soniox segments 
-# - store local json with data - -# - P2 +# - Download all files. +# - get each of those memories # - read local json with each memory audio file # - call whisper groq (whisper-largev3) # - Create a table df, well printed, with each transcript result side by side # - prompt for computing WER using groq whisper as baseline (if better, but most likely) # - Run for deepgram vs soniox, and generate comparison result - +import asyncio # - P3 # - Include speechmatics to the game +import json +import os +from typing import Dict, List + +import firebase_admin +from dotenv import load_dotenv +from pydub import AudioSegment + +load_dotenv('../../.dev.env') +os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '../../' + os.getenv('GOOGLE_APPLICATION_CREDENTIALS') + +firebase_admin.initialize_app() + +from models.transcript_segment import TranscriptSegment +from utils.stt.streaming import process_audio_dg, process_audio_soniox, process_audio_speechmatics + + +def store_model_result(memory_id: str, model: str, result: List[Dict]): + file_path = 'results.json' + if os.path.exists(file_path): + with open(file_path, 'r') as f: + results = json.load(f) + else: + results = {} + + if memory_id not in results: + results[memory_id] = {} + + results[memory_id][model] = result + # save it + with open(file_path, 'w') as f: + json.dump(results, f) + + +def add_model_result_segments(memory_id: str, model: str, result: List[Dict]): + file_path = 'results.json' + if os.path.exists(file_path): + with open(file_path, 'r') as f: + results = json.load(f) + else: + results = {} + + if memory_id not in results: + results[memory_id] = {} + + if model not in results[memory_id]: + results[memory_id][model] = [] + + segments = [TranscriptSegment(**s) for s in results[memory_id][model]] + new_segments = [TranscriptSegment(**s) for s in result] + + segments = TranscriptSegment.combine_segments(segments, new_segments) + store_model_result(memory_id, model, [s.dict() for s in segments]) + + +async def process_memories_audio_files(): + uids = os.listdir('_temp2') + for uid in uids: + memories = os.listdir(f'_temp2/{uid}') + memories = [f'_temp2/{uid}/{memory}' for memory in memories] + # memories_id = [] + # for file_path in memories: + # if AudioSegment.from_wav(file_path).frame_rate != 16000: + # continue + # memory_id = file_path.split('.')[0] + # memories_id.append(memory_id) + + # memories_data = get_memories_by_id(uid, memories_id) + for file_path in memories: + memory_id = file_path.split('/')[-1].split('.')[0] + print(memory_id) + + def stream_transcript_deepgram(new_segments, _): + print(new_segments) + add_model_result_segments(memory_id, 'deepgram', new_segments) + + def stream_transcript_soniox(new_segments, _): + print(new_segments) + add_model_result_segments(memory_id, 'soniox', new_segments) + + def stream_transcript_speechmatics(new_segments, _): + print(new_segments) + add_model_result_segments(memory_id, 'speechmatics', new_segments) + + socket = await process_audio_dg(stream_transcript_deepgram, '1', 'en', 16000, 'pcm16', 1, 0) + socket_soniox = await process_audio_soniox(stream_transcript_soniox, '1', 'en', None) + socket_speechmatics = await process_audio_speechmatics(stream_transcript_speechmatics, '1', 'en', 0) + duration = AudioSegment.from_wav(file_path).duration_seconds + print('duration', duration) + with open(file_path, "rb") as file: + + while True: + chunk = file.read(320) + if not chunk: + break + # print('Uploading', len(chunk)) + # TODO: Race conditions here? 
+ socket.send(bytes(chunk)) + await socket_soniox.send(bytes(chunk)) + await socket_speechmatics.send(bytes(chunk)) + await asyncio.sleep(0.0001) + print('File sent') + # - call for whisper-x + # - store in a json file and cache + await asyncio.sleep(duration) # TODO: await duration + break + break + +if __name__ == '__main__': + asyncio.run(process_memories_audio_files()) diff --git a/backend/utils/stt/streaming.py b/backend/utils/stt/streaming.py index 03f96ec18..e111c825c 100644 --- a/backend/utils/stt/streaming.py +++ b/backend/utils/stt/streaming.py @@ -198,7 +198,7 @@ async def process_audio_soniox(stream_transcript, stream_id: int, language: str, if language not in soniox_valid_languages: raise ValueError(f"Unsupported language '{language}'. Supported languages are: {soniox_valid_languages}") - has_speech_profile = create_user_speech_profile(uid) # only english too + has_speech_profile = create_user_speech_profile(uid) if uid else False # only english too # Construct the initial request with all required and optional parameters request = { @@ -232,6 +232,7 @@ async def on_message(): try: async for message in soniox_socket: response = json.loads(message) + # print(response) fw = response['fw'] if not fw: continue @@ -350,7 +351,7 @@ async def on_message(): continue segments = [] for r in results: - print(r) + # print(r) if not r['alternatives']: continue From f6165f2625be768fc8027dd18b0340705cc8c548 Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Fri, 20 Sep 2024 17:24:29 +0530 Subject: [PATCH 39/88] create events for memories when they are fetched --- app/lib/providers/memory_provider.dart | 36 ++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/app/lib/providers/memory_provider.dart b/app/lib/providers/memory_provider.dart index b0b12c02b..7aaaf47ed 100644 --- a/app/lib/providers/memory_provider.dart +++ b/app/lib/providers/memory_provider.dart @@ -1,8 +1,11 @@ +import 'package:collection/collection.dart'; import 'package:flutter/foundation.dart'; import 'package:friend_private/backend/http/api/memories.dart'; import 'package:friend_private/backend/preferences.dart'; import 'package:friend_private/backend/schema/memory.dart'; +import 'package:friend_private/backend/schema/structured.dart'; import 'package:friend_private/utils/analytics/mixpanel.dart'; +import 'package:friend_private/utils/features/calendar.dart'; class MemoryProvider extends ChangeNotifier { List memories = []; @@ -88,11 +91,22 @@ class MemoryProvider extends ChangeNotifier { setLoadingMemories(true); var mem = await getMemories(); memories = mem; + createEventsForMemories(); setLoadingMemories(false); notifyListeners(); return memories; } + void createEventsForMemories() { + for (var memory in memories) { + if (memory.structured.events.isNotEmpty && + !memory.structured.events.first.created && + memory.startedAt!.isAfter(DateTime.now().add(const Duration(days: -1)))) { + _handleCalendarCreation(memory); + } + } + } + Future getMoreMemoriesFromServer() async { if (memories.length % 50 != 0) return; if (isLoadingMemories) return; @@ -123,6 +137,28 @@ class MemoryProvider extends ChangeNotifier { notifyListeners(); } + _handleCalendarCreation(ServerMemory memory) { + if (!SharedPreferencesUtil().calendarEnabled) return; + if (SharedPreferencesUtil().calendarType != 'auto') return; + + List events = memory.structured.events; + if (events.isEmpty) return; + + List indexes = events.mapIndexed((index, e) => index).toList(); + 
setMemoryEventsState(memory.id, indexes, indexes.map((_) => true).toList()); + for (var i = 0; i < events.length; i++) { + print('Creating event: ${events[i].title}'); + if (events[i].created) continue; + events[i].created = true; + CalendarUtil().createEvent( + events[i].title, + events[i].startsAt, + events[i].duration, + description: events[i].description, + ); + } + } + ///////////////////////////////////////////////////////////////// ////////// Delete Memory With Undo Functionality /////////////// From 551a26a7a7ba9f5423f56b31eea0aceefb43a6cb Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Fri, 20 Sep 2024 17:25:31 +0530 Subject: [PATCH 40/88] call _handleCalendarCreation in _processOnMemoryCreated func --- app/lib/providers/capture_provider.dart | 2 ++ 1 file changed, 2 insertions(+) diff --git a/app/lib/providers/capture_provider.dart b/app/lib/providers/capture_provider.dart index 5148feabf..50c74f3b8 100644 --- a/app/lib/providers/capture_provider.dart +++ b/app/lib/providers/capture_provider.dart @@ -204,6 +204,7 @@ class CaptureProvider extends ChangeNotifier // Notify setMemoryCreating(false); setHasTranscripts(false); + _handleCalendarCreation(memory); notifyListeners(); return; } @@ -324,6 +325,7 @@ class CaptureProvider extends ChangeNotifier List indexes = events.mapIndexed((index, e) => index).toList(); setMemoryEventsState(memory.id, indexes, indexes.map((_) => true).toList()); for (var i = 0; i < events.length; i++) { + if (events[i].created) continue; events[i].created = true; CalendarUtil().createEvent( events[i].title, From eb1b0aa053b796230271c44e5080dac4dd6ee774 Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Fri, 20 Sep 2024 19:33:40 +0530 Subject: [PATCH 41/88] delete messages using firebase batch on backend --- backend/database/chat.py | 42 ++++++++++++++++++++++++---------------- backend/routers/chat.py | 3 +-- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/backend/database/chat.py b/backend/database/chat.py index 413604e71..aa27609c8 100644 --- a/backend/database/chat.py +++ b/backend/database/chat.py @@ -2,6 +2,7 @@ from datetime import datetime, timezone from typing import Optional +from fastapi import HTTPException from google.cloud import firestore from models.chat import Message @@ -82,20 +83,27 @@ def get_messages(uid: str, limit: int = 20, offset: int = 0, include_memories: b ] return messages - - -def clear_chat(uid,batch_size): - user_ref = db.collection('users').document(uid) - messages_ref = user_ref.collection('messages') - if batch_size == 0: - return - docs = messages_ref.list_documents(page_size=batch_size) - deleted = 0 - - for doc in docs: - print(f"Deleting doc {doc.id} => {doc.get().to_dict()}") - doc.delete() - deleted = deleted + 1 - - if deleted >= batch_size: - return clear_chat(uid,batch_size) \ No newline at end of file + + +async def batch_delete_messages(parent_doc_ref, batch_size=450): + # batch size is 450 because firebase can perform upto 500 operations in a batch + messages_ref = parent_doc_ref.collection('messages') + while True: + docs = messages_ref.limit(batch_size).stream() + docs_list = list(docs) + if not docs_list: + break + batch = db.batch() + for doc in docs_list: + batch.delete(doc.reference) + batch.commit() + + +async def clear_chat( uid: str): + try: + user_ref = db.collection('users').document(uid) + if not user_ref.get().exists: + raise HTTPException(status_code=404, detail="User not found") + 
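# Illustrative aside on the batch size: Firestore commits at most 500 writes per
# batch, so batch_delete_messages above pages through the messages collection 450
# documents at a time and commits until a page comes back empty. The same pattern
# as a generic helper (assuming the google-cloud-firestore client):
def _purge_collection(db, coll_ref, page_size=450):
    while True:
        docs = list(coll_ref.limit(page_size).stream())   # fetch one page
        if not docs:
            return
        batch = db.batch()
        for doc in docs:
            batch.delete(doc.reference)   # or batch.update(doc.reference, {'deleted': True}) for a soft delete
        batch.commit()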
await batch_delete_messages(user_ref) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error deleting messages: {str(e)}") \ No newline at end of file diff --git a/backend/routers/chat.py b/backend/routers/chat.py index 3d9b6efde..2836f1c19 100644 --- a/backend/routers/chat.py +++ b/backend/routers/chat.py @@ -57,8 +57,7 @@ def send_message( @router.delete('/v1/clear-chat', tags=['chat'], response_model=Message) def clear_chat(uid: str = Depends(auth.get_current_user_uid)): - - chat_db.clear_chat(uid, 400) + chat_db.clear_chat(uid) return initial_message_util(uid) From 9cbb312de99fd3b79283c7eaa8aef44726ce95df Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Fri, 20 Sep 2024 19:34:12 +0530 Subject: [PATCH 42/88] improve experience --- app/lib/providers/message_provider.dart | 32 +++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/app/lib/providers/message_provider.dart b/app/lib/providers/message_provider.dart index e12952ab4..838288ce4 100644 --- a/app/lib/providers/message_provider.dart +++ b/app/lib/providers/message_provider.dart @@ -10,11 +10,25 @@ class MessageProvider extends ChangeNotifier { List messages = []; bool isLoadingMessages = false; + bool hasCachedMessages = false; + bool isClearingChat = false; + + String firstTimeLoadingText = ''; void updatePluginProvider(PluginProvider p) { pluginProvider = p; } + void setHasCachedMessages(bool value) { + hasCachedMessages = value; + notifyListeners(); + } + + void setClearingChat(bool value) { + isClearingChat = value; + notifyListeners(); + } + void setLoadingMessages(bool value) { isLoadingMessages = value; notifyListeners(); @@ -32,9 +46,25 @@ class MessageProvider extends ChangeNotifier { notifyListeners(); } + void setMessagesFromCache() { + if (SharedPreferencesUtil().cachedMessages.isNotEmpty) { + setHasCachedMessages(true); + messages = SharedPreferencesUtil().cachedMessages; + } + notifyListeners(); + } + Future> getMessagesFromServer() async { + if (!hasCachedMessages) { + firstTimeLoadingText = 'Reading your memories...'; + notifyListeners(); + } setLoadingMessages(true); var mes = await getMessagesServer(); + if (!hasCachedMessages) { + firstTimeLoadingText = 'Learning from your memories...'; + notifyListeners(); + } messages = mes; setLoadingMessages(false); notifyListeners(); @@ -42,8 +72,10 @@ class MessageProvider extends ChangeNotifier { } Future clearChat() async { + setClearingChat(true); var mes = await clearChatServer(); messages = mes; + setClearingChat(false); notifyListeners(); } From 83c747d14993718bdbf93566a7c0c3f4b41436f9 Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Fri, 20 Sep 2024 19:34:23 +0530 Subject: [PATCH 43/88] fetch messages early --- app/lib/main.dart | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/app/lib/main.dart b/app/lib/main.dart index e55d2bd69..d3e49c1f5 100644 --- a/app/lib/main.dart +++ b/app/lib/main.dart @@ -275,6 +275,12 @@ class _DeciderWidgetState extends State { if (context.read().isConnected) { NotificationService.instance.saveNotificationToken(); } + + if (context.read().user != null) { + context.read().setMessagesFromCache(); + + context.read().refreshMessages(); + } }); super.initState(); } From 2c52e5bd50dfbb66873a7e2af6e2f1b2a89ac68c Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Fri, 20 Sep 2024 19:37:20 +0530 Subject: [PATCH 44/88] improve ux by showing cached messages 
while fetching messages in bg --- app/lib/pages/chat/page.dart | 293 ++++++++++++++++++++--------------- 1 file changed, 169 insertions(+), 124 deletions(-) diff --git a/app/lib/pages/chat/page.dart b/app/lib/pages/chat/page.dart index 223aca73a..eb77b2d9e 100644 --- a/app/lib/pages/chat/page.dart +++ b/app/lib/pages/chat/page.dart @@ -9,6 +9,7 @@ import 'package:friend_private/backend/schema/memory.dart'; import 'package:friend_private/backend/schema/message.dart'; import 'package:friend_private/backend/schema/plugin.dart'; import 'package:friend_private/pages/chat/widgets/ai_message.dart'; +import 'package:friend_private/pages/chat/widgets/animated_mini_banner.dart'; import 'package:friend_private/pages/chat/widgets/user_message.dart'; import 'package:friend_private/providers/connectivity_provider.dart'; import 'package:friend_private/providers/home_provider.dart'; @@ -50,7 +51,7 @@ class ChatPageState extends State with AutomaticKeepAliveClientMixin { void initState() { plugins = prefs.pluginsList; SchedulerBinding.instance.addPostFrameCallback((_) async { - await context.read().refreshMessages(); + // await context.read().refreshMessages(); scrollToBottom(); }); // _initDailySummary(); @@ -70,134 +71,176 @@ class ChatPageState extends State with AutomaticKeepAliveClientMixin { print('ChatPage build'); return Consumer2( builder: (context, provider, connectivityProvider, child) { - return Stack( - children: [ - Align( - alignment: Alignment.topCenter, - child: provider.isLoadingMessages - ? const Padding( - padding: EdgeInsets.only(top: 32.0), - child: CircularProgressIndicator( - color: Colors.white, - ), - ) - : (provider.messages.isEmpty) - ? Text( - connectivityProvider.isConnected - ? 'No messages yet!\nWhy don\'t you start a conversation?' - : 'Please check your internet connection and try again', - textAlign: TextAlign.center, - style: const TextStyle(color: Colors.white)) - : ListView.builder( - shrinkWrap: true, - reverse: true, - controller: scrollController, - // physics: const NeverScrollableScrollPhysics(), - itemCount: provider.messages.length, - itemBuilder: (context, chatIndex) { - final message = provider.messages[chatIndex]; - double topPadding = chatIndex == provider.messages.length - 1 ? 24 : 16; - double bottomPadding = chatIndex == 0 - ? Platform.isAndroid - ? 200 - : 170 - : 0; - return Padding( - key: ValueKey(message.id), - padding: EdgeInsets.only(bottom: bottomPadding, left: 18, right: 18, top: topPadding), - child: message.sender == MessageSender.ai - ? AIMessage( - message: message, - sendMessage: _sendMessageUtil, - displayOptions: provider.messages.length <= 1, - pluginSender: plugins.firstWhereOrNull((e) => e.id == message.pluginId), - updateMemory: (ServerMemory memory) { - context.read().updateMemory(memory); - }, - ) - : HumanMessage(message: message), - ); - }, - ), + return Scaffold( + backgroundColor: Theme.of(context).colorScheme.primary, + appBar: AnimatedMiniBanner( + showAppBar: provider.isLoadingMessages, + child: Container( + width: double.infinity, + height: 10, + color: Colors.green, + child: const Center( + child: Text( + 'Syncing messages with server...', + style: TextStyle(color: Colors.white, fontSize: 14), + ), + ), ), - Consumer(builder: (context, home, child) { - return Align( - alignment: Alignment.bottomCenter, - child: Container( - width: double.maxFinite, - padding: const EdgeInsets.symmetric(horizontal: 16, vertical: 2), - margin: EdgeInsets.only(left: 32, right: 32, bottom: home.isChatFieldFocused ? 
40 : 120), - decoration: const BoxDecoration( - color: Colors.black, - borderRadius: BorderRadius.all(Radius.circular(16)), - border: GradientBoxBorder( - gradient: LinearGradient(colors: [ - Color.fromARGB(127, 208, 208, 208), - Color.fromARGB(127, 188, 99, 121), - Color.fromARGB(127, 86, 101, 182), - Color.fromARGB(127, 126, 190, 236) - ]), - width: 1, - ), - shape: BoxShape.rectangle, - ), - child: TextField( - enabled: true, - controller: textController, - // textCapitalization: TextCapitalization.sentences, - obscureText: false, - focusNode: home.chatFieldFocusNode, - // canRequestFocus: true, - textAlign: TextAlign.start, - textAlignVertical: TextAlignVertical.center, - decoration: InputDecoration( - hintText: 'Ask your Friend anything', - hintStyle: const TextStyle(fontSize: 14.0, color: Colors.grey), - focusedBorder: InputBorder.none, - enabledBorder: InputBorder.none, - suffixIcon: IconButton( - splashColor: Colors.transparent, - splashRadius: 1, - onPressed: loading - ? null - : () async { - String message = textController.text; - if (message.isEmpty) return; - if (connectivityProvider.isConnected) { - _sendMessageUtil(message); - } else { - ScaffoldMessenger.of(context).showSnackBar( - const SnackBar( - content: Text('Please check your internet connection and try again'), - duration: Duration(seconds: 2), - ), - ); - } - }, - icon: loading - ? const SizedBox( - width: 16, - height: 16, - child: CircularProgressIndicator( - valueColor: AlwaysStoppedAnimation(Colors.white), - ), + ), + body: Stack( + children: [ + Align( + alignment: Alignment.topCenter, + child: provider.isLoadingMessages && !provider.hasCachedMessages + ? Column( + children: [ + const SizedBox(height: 100), + const CircularProgressIndicator( + valueColor: AlwaysStoppedAnimation(Colors.white), + ), + const SizedBox(height: 16), + Text( + provider.firstTimeLoadingText, + style: const TextStyle(color: Colors.white), + ), + ], + ) + : provider.isClearingChat + ? const Column( + children: [ + SizedBox(height: 100), + CircularProgressIndicator( + valueColor: AlwaysStoppedAnimation(Colors.white), + ), + SizedBox(height: 16), + Text( + "Deleting your messages from Omi's memory...", + style: TextStyle(color: Colors.white), + ), + ], + ) + : (provider.messages.isEmpty) + ? Padding( + padding: const EdgeInsets.only(top: 16.0), + child: Text( + connectivityProvider.isConnected + ? 'No messages yet!\nWhy don\'t you start a conversation?' + : 'Please check your internet connection and try again', + textAlign: TextAlign.center, + style: const TextStyle(color: Colors.white)), ) - : const Icon( - Icons.send_rounded, - color: Color(0xFFF7F4F4), - size: 24.0, + : ListView.builder( + shrinkWrap: true, + reverse: true, + controller: scrollController, + // physics: const NeverScrollableScrollPhysics(), + itemCount: provider.messages.length, + itemBuilder: (context, chatIndex) { + final message = provider.messages[chatIndex]; + double topPadding = chatIndex == provider.messages.length - 1 ? 24 : 16; + double bottomPadding = chatIndex == 0 + ? Platform.isAndroid + ? 200 + : 170 + : 0; + return Padding( + key: ValueKey(message.id), + padding: + EdgeInsets.only(bottom: bottomPadding, left: 18, right: 18, top: topPadding), + child: message.sender == MessageSender.ai + ? 
AIMessage( + message: message, + sendMessage: _sendMessageUtil, + displayOptions: provider.messages.length <= 1, + pluginSender: plugins.firstWhereOrNull((e) => e.id == message.pluginId), + updateMemory: (ServerMemory memory) { + context.read().updateMemory(memory); + }, + ) + : HumanMessage(message: message), + ); + }, ), + ), + Consumer(builder: (context, home, child) { + return Align( + alignment: Alignment.bottomCenter, + child: Container( + width: double.maxFinite, + padding: const EdgeInsets.symmetric(horizontal: 16, vertical: 2), + margin: EdgeInsets.only(left: 32, right: 32, bottom: home.isChatFieldFocused ? 40 : 120), + decoration: const BoxDecoration( + color: Colors.black, + borderRadius: BorderRadius.all(Radius.circular(16)), + border: GradientBoxBorder( + gradient: LinearGradient(colors: [ + Color.fromARGB(127, 208, 208, 208), + Color.fromARGB(127, 188, 99, 121), + Color.fromARGB(127, 86, 101, 182), + Color.fromARGB(127, 126, 190, 236) + ]), + width: 1, ), + shape: BoxShape.rectangle, + ), + child: TextField( + enabled: true, + controller: textController, + // textCapitalization: TextCapitalization.sentences, + obscureText: false, + focusNode: home.chatFieldFocusNode, + // canRequestFocus: true, + textAlign: TextAlign.start, + textAlignVertical: TextAlignVertical.center, + decoration: InputDecoration( + hintText: 'Ask your Friend anything', + hintStyle: const TextStyle(fontSize: 14.0, color: Colors.grey), + focusedBorder: InputBorder.none, + enabledBorder: InputBorder.none, + suffixIcon: IconButton( + splashColor: Colors.transparent, + splashRadius: 1, + onPressed: loading + ? null + : () async { + String message = textController.text; + if (message.isEmpty) return; + if (connectivityProvider.isConnected) { + _sendMessageUtil(message); + } else { + ScaffoldMessenger.of(context).showSnackBar( + const SnackBar( + content: Text('Please check your internet connection and try again'), + duration: Duration(seconds: 2), + ), + ); + } + }, + icon: loading + ? 
const SizedBox( + width: 16, + height: 16, + child: CircularProgressIndicator( + valueColor: AlwaysStoppedAnimation(Colors.white), + ), + ) + : const Icon( + Icons.send_rounded, + color: Color(0xFFF7F4F4), + size: 24.0, + ), + ), + ), + // maxLines: 8, + // minLines: 1, + // keyboardType: TextInputType.multiline, + style: TextStyle(fontSize: 14.0, color: Colors.grey.shade200), ), - // maxLines: 8, - // minLines: 1, - // keyboardType: TextInputType.multiline, - style: TextStyle(fontSize: 14.0, color: Colors.grey.shade200), ), - ), - ); - }), - ], + ); + }), + ], + ), ); }, ); @@ -223,7 +266,9 @@ class ChatPageState extends State with AutomaticKeepAliveClientMixin { changeLoadingState(); scrollToBottom(); ServerMessage message = await getInitialPluginMessage(plugin?.id); - context.read().addMessage(message); + if (mounted) { + context.read().addMessage(message); + } scrollToBottom(); changeLoadingState(); } From 4c4fdae2dd7ed0016d965b1b41122401675a6114 Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Fri, 20 Sep 2024 20:00:04 +0530 Subject: [PATCH 45/88] animated mini banner widget --- .../chat/widgets/animated_mini_banner.dart | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 app/lib/pages/chat/widgets/animated_mini_banner.dart diff --git a/app/lib/pages/chat/widgets/animated_mini_banner.dart b/app/lib/pages/chat/widgets/animated_mini_banner.dart new file mode 100644 index 000000000..ccff5e758 --- /dev/null +++ b/app/lib/pages/chat/widgets/animated_mini_banner.dart @@ -0,0 +1,20 @@ +import 'package:flutter/material.dart'; + +class AnimatedMiniBanner extends StatelessWidget implements PreferredSizeWidget { + const AnimatedMiniBanner({super.key, required this.showAppBar, required this.child}); + + final bool showAppBar; + final Widget child; + + @override + Widget build(BuildContext context) { + return AnimatedContainer( + height: showAppBar ? 
kToolbarHeight : 0, + duration: const Duration(milliseconds: 400), + child: child, + ); + } + + @override + Size get preferredSize => const Size.fromHeight(30); +} From f38bd5852d6d1ac556462090543d176ea3df584f Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Fri, 20 Sep 2024 22:06:12 +0530 Subject: [PATCH 46/88] improve memory in chat loading speed by only fetching it if it does not exist in memoryProvider --- app/lib/pages/chat/widgets/ai_message.dart | 76 +++++++++++++++------- app/lib/providers/memory_provider.dart | 28 +++++++- 2 files changed, 80 insertions(+), 24 deletions(-) diff --git a/app/lib/pages/chat/widgets/ai_message.dart b/app/lib/pages/chat/widgets/ai_message.dart index 393b89bc8..95159218a 100644 --- a/app/lib/pages/chat/widgets/ai_message.dart +++ b/app/lib/pages/chat/widgets/ai_message.dart @@ -9,7 +9,9 @@ import 'package:friend_private/backend/preferences.dart'; import 'package:friend_private/backend/schema/memory.dart'; import 'package:friend_private/backend/schema/message.dart'; import 'package:friend_private/backend/schema/plugin.dart'; +import 'package:friend_private/pages/memory_detail/memory_detail_provider.dart'; import 'package:friend_private/pages/memory_detail/page.dart'; +import 'package:friend_private/providers/memory_provider.dart'; import 'package:friend_private/utils/analytics/mixpanel.dart'; import 'package:friend_private/providers/connectivity_provider.dart'; import 'package:friend_private/utils/other/temp.dart'; @@ -127,30 +129,58 @@ class _AIMessageState extends State { onTap: () async { final connectivityProvider = Provider.of(context, listen: false); if (connectivityProvider.isConnected) { - if (memoryDetailLoading[data.$1]) return; - setState(() => memoryDetailLoading[data.$1] = true); + var memProvider = Provider.of(context, listen: false); + var idx = memProvider.memoriesWithDates.indexWhere((e) { + if (e.runtimeType == ServerMemory) { + return e.id == data.$2.id; + } + return false; + }); - ServerMemory? m = await getMemoryById(data.$2.id); - if (m == null) return; - MixpanelManager().chatMessageMemoryClicked(m); - setState(() => memoryDetailLoading[data.$1] = false); - await Navigator.of(context) - .push(MaterialPageRoute(builder: (c) => MemoryDetailPage(memory: m))); - if (SharedPreferencesUtil().modifiedMemoryDetails?.id == m.id) { - ServerMemory modifiedDetails = SharedPreferencesUtil().modifiedMemoryDetails!; - widget.updateMemory(SharedPreferencesUtil().modifiedMemoryDetails!); - var copy = List.from(widget.message.memories); - copy[data.$1] = MessageMemory( - modifiedDetails.id, - modifiedDetails.createdAt, - MessageMemoryStructured( - modifiedDetails.structured.title, - modifiedDetails.structured.emoji, - )); - widget.message.memories.clear(); - widget.message.memories.addAll(copy); - SharedPreferencesUtil().modifiedMemoryDetails = null; - setState(() {}); + if (idx != -1) { + context.read().updateMemory(idx); + var m = memProvider.memoriesWithDates[idx]; + MixpanelManager().chatMessageMemoryClicked(m); + await Navigator.of(context).push( + MaterialPageRoute( + builder: (c) => MemoryDetailPage( + memory: m, + ), + ), + ); + } else { + if (memoryDetailLoading[data.$1]) return; + setState(() => memoryDetailLoading[data.$1] = true); + ServerMemory? 
m = await getMemoryById(data.$2.id); + if (m == null) return; + idx = memProvider.addMemoryWithDate(m); + MixpanelManager().chatMessageMemoryClicked(m); + setState(() => memoryDetailLoading[data.$1] = false); + context.read().updateMemory(idx); + await Navigator.of(context).push( + MaterialPageRoute( + builder: (c) => MemoryDetailPage( + memory: m, + ), + ), + ); + //TODO: Not needed anymore I guess because memories are stored in provider and read from there only + if (SharedPreferencesUtil().modifiedMemoryDetails?.id == m.id) { + ServerMemory modifiedDetails = SharedPreferencesUtil().modifiedMemoryDetails!; + widget.updateMemory(SharedPreferencesUtil().modifiedMemoryDetails!); + var copy = List.from(widget.message.memories); + copy[data.$1] = MessageMemory( + modifiedDetails.id, + modifiedDetails.createdAt, + MessageMemoryStructured( + modifiedDetails.structured.title, + modifiedDetails.structured.emoji, + )); + widget.message.memories.clear(); + widget.message.memories.addAll(copy); + SharedPreferencesUtil().modifiedMemoryDetails = null; + setState(() {}); + } } } else { ScaffoldMessenger.of(context).showSnackBar( diff --git a/app/lib/providers/memory_provider.dart b/app/lib/providers/memory_provider.dart index b0b12c02b..b5808cbbe 100644 --- a/app/lib/providers/memory_provider.dart +++ b/app/lib/providers/memory_provider.dart @@ -106,8 +106,34 @@ class MemoryProvider extends ChangeNotifier { void addMemory(ServerMemory memory) { memories.insert(0, memory); - filterMemories(''); + initFilteredMemories(); + notifyListeners(); + } + + int addMemoryWithDate(ServerMemory memory) { + int idx; + var date = memoriesWithDates.indexWhere((element) => + element is DateTime && + element.day == memory.createdAt.day && + element.month == memory.createdAt.month && + element.year == memory.createdAt.year); + if (date != -1) { + var hour = memoriesWithDates[date + 1].createdAt.hour; + var newHour = memory.createdAt.hour; + if (newHour > hour) { + memoriesWithDates.insert(date + 1, memory); + idx = date + 1; + } else { + memoriesWithDates.insert(date + 2, memory); + idx = date + 2; + } + } else { + memoriesWithDates.add(memory.createdAt); + memoriesWithDates.add(memory); + idx = memoriesWithDates.length - 1; + } notifyListeners(); + return idx; } void updateMemory(ServerMemory memory, [int? index]) { From f945d41a24e83a34686d2b3c38c06f4026c4006f Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Fri, 20 Sep 2024 22:49:40 +0530 Subject: [PATCH 47/88] minor ui improvements --- app/lib/pages/chat/page.dart | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/app/lib/pages/chat/page.dart b/app/lib/pages/chat/page.dart index eb77b2d9e..e7f0f538f 100644 --- a/app/lib/pages/chat/page.dart +++ b/app/lib/pages/chat/page.dart @@ -120,14 +120,16 @@ class ChatPageState extends State with AutomaticKeepAliveClientMixin { ], ) : (provider.messages.isEmpty) - ? Padding( - padding: const EdgeInsets.only(top: 16.0), - child: Text( - connectivityProvider.isConnected - ? 'No messages yet!\nWhy don\'t you start a conversation?' - : 'Please check your internet connection and try again', - textAlign: TextAlign.center, - style: const TextStyle(color: Colors.white)), + ? Center( + child: Padding( + padding: const EdgeInsets.only(bottom: 32.0), + child: Text( + connectivityProvider.isConnected + ? 'No messages yet!\nWhy don\'t you start a conversation?' 
+ : 'Please check your internet connection and try again', + textAlign: TextAlign.center, + style: const TextStyle(color: Colors.white)), + ), ) : ListView.builder( shrinkWrap: true, From 405252cc6de1ca2265ca53c7737890f0c0f85641 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Sat, 21 Sep 2024 07:20:14 +0700 Subject: [PATCH 48/88] Handle socket > channel ready exception --- app/lib/utils/pure_socket.dart | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/app/lib/utils/pure_socket.dart b/app/lib/utils/pure_socket.dart index 5660212e7..bf6adfbb5 100644 --- a/app/lib/utils/pure_socket.dart +++ b/app/lib/utils/pure_socket.dart @@ -1,5 +1,6 @@ import 'dart:async'; import 'dart:convert'; +import 'dart:io'; import 'dart:math'; import 'package:flutter/material.dart'; @@ -122,7 +123,19 @@ class PureSocket implements IPureSocket { } _status = PureSocketStatus.connecting; - await _channel?.ready; + dynamic err; + try { + await channel.ready; + } on SocketException catch (e) { + err = e; + } on WebSocketChannelException catch (e) { + err = e; + } + if (err != null) { + print("Error: $err"); + _status = PureSocketStatus.notConnected; + return false; + } _status = PureSocketStatus.connected; _retries = 0; From 4a297cbc1741b1a55baf803b7a16ec6b4d9c09b0 Mon Sep 17 00:00:00 2001 From: Salman Mian Date: Fri, 20 Sep 2024 17:21:42 -0700 Subject: [PATCH 49/88] Update Flash_device.md --- docs/_get_started/Flash_device.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/_get_started/Flash_device.md b/docs/_get_started/Flash_device.md index b2ee810ce..40f328c85 100644 --- a/docs/_get_started/Flash_device.md +++ b/docs/_get_started/Flash_device.md @@ -4,7 +4,7 @@ title: Update FRIEND Firmware nav_order: 3 --- # Video Tutorial -For a visual walkthrough of the flashing process, watch the [Updating Your FRIEND](https://github.com/ebowwa/omi/blob/firmware-flashing-readme/docs/images/updating_your_friend.mov) video. +For a visual walkthrough of the flashing process, watch the [Updating Your FRIEND](https://github.com/BasedHardware/omi/blob/main/docs/images/updating_your_friend.mov) video. # Flashing FRIEND Firmware` @@ -17,8 +17,8 @@ This guide will walk you through the process of flashing the latest firmware ont 2. Find the latest firmware release and bootloader, then download the corresponding `.uf2` files. Or download these files - - **Bootloader:** [bootloader0.9.0.uf2](https://github.com/ebowwa/omi/blob/firmware-flashing-readme/devices/Friend/firmware/bootloader/bootloader0.9.0.uf2) - - **Firmware:** [firmware1.0.4.uf2](https://github.com/ebowwa/omi/blob/firmware-flashing-readme/devices/Friend/firmware/firmware1.0.4.uf2) + - **Bootloader:** [bootloader0.9.0.uf2](https://github.com/BasedHardware/omi/releases/download/v1.0.3-firmware/update-xiao_nrf52840_ble_sense_bootloader-0.9.0_nosd.uf2) + - **Firmware:** [firmware1.0.4.uf2](https://github.com/BasedHardware/omi/releases/download/v1.0.4-firmware/friend-xiao_nrf52840_ble_sense-1.0.4.uf2) ## Putting FRIEND into DFU Mode @@ -34,11 +34,11 @@ Or download these files 1. Locate the `.uf2` files you downloaded earlier. 2. 
Drag and drop the bootloader `.uf2` file onto the `/Volumes/XIAO-SENSE` drive: - - **Bootloader:** [bootloader0.9.0.uf2](https://github.com/ebowwa/omi/blob/firmware-flashing-readme/devices/Friend/firmware/bootloader/bootloader0.9.0.uf2) + - **Bootloader:** [bootloader0.9.0.uf2](https://github.com/BasedHardware/omi/releases/download/v1.0.3-firmware/update-xiao_nrf52840_ble_sense_bootloader-0.9.0_nosd.uf2) 3. The device will automatically eject itself once the bootloader flashing process is complete. 4. After the device forcibly ejects, set the FRIEND device back into DFU mode by double-tapping the reset button. 5. Drag and drop the FRIEND firmware file onto the `/Volumes/XIAO-SENSE` drive: - - **Firmware:** [firmware1.0.4.uf2](https://github.com/ebowwa/omi/blob/firmware-flashing-readme/devices/Friend/firmware/firmware1.0.4.uf2) + - **Firmware:** [firmware1.0.4.uf2](https://github.com/BasedHardware/omi/releases/download/v1.0.4-firmware/friend-xiao_nrf52840_ble_sense-1.0.4.uf2) ## Congratulations! From 3c087591daddcea4ee98a1a2ae022dc71c6cc680 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Sat, 21 Sep 2024 08:32:51 +0700 Subject: [PATCH 50/88] Bump version to 1.0.39+135 --- app/pubspec.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/pubspec.yaml b/app/pubspec.yaml index 4990f937a..d968f8318 100644 --- a/app/pubspec.yaml +++ b/app/pubspec.yaml @@ -3,7 +3,7 @@ description: A new Flutter project. publish_to: 'none' # Remove this line if you wish to publish to pub.dev -version: 1.0.38+134 +version: 1.0.39+135 environment: sdk: ">=3.0.0 <4.0.0" From ef9e070edcee9aa98673e09aff6572b2cc55c2e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Sat, 21 Sep 2024 09:30:56 +0700 Subject: [PATCH 51/88] Log stt service in websocket router --- backend/routers/transcribe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index bf10897f3..446e0885e 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -89,7 +89,7 @@ async def _websocket_util( channels: int = 1, include_speech_profile: bool = True, new_memory_watch: bool = False, stt_service: STTService = STTService.deepgram, ): - print('websocket_endpoint', uid, language, sample_rate, codec, channels, include_speech_profile, new_memory_watch) + print('websocket_endpoint', uid, language, sample_rate, codec, channels, include_speech_profile, new_memory_watch, stt_service) if stt_service == STTService.soniox and ( sample_rate != 16000 or codec != 'opus' or language not in soniox_valid_languages): From 43b57e89de15f56fb0188ad0f1113a60ea62052b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Sat, 21 Sep 2024 10:54:47 +0700 Subject: [PATCH 52/88] Increase modal timeout to 35m, transcribe soft timeout to 30m --- backend/main.py | 2 +- backend/routers/transcribe.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/backend/main.py b/backend/main.py index ab21cdb0a..1b2362dcd 100644 --- a/backend/main.py +++ b/backend/main.py @@ -54,7 +54,7 @@ memory=(512, 1024), cpu=2, allow_concurrent_inputs=10, - timeout=60 * 10, + timeout=60 * 35, ) @asgi_app() def api(): diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 446e0885e..aab9c789e 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -1,4 +1,7 @@ import threading +import asyncio +import time +from typing import List import uuid from datetime import datetime, timezone 
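# Illustrative arithmetic for how the two timeouts in this commit relate: the Modal
# app's hard limit is raised to 35 minutes and the websocket soft timeout below to
# 30 minutes, which keeps the required "< MODAL_TIME_OUT - 3m" margin:
MODAL_TIMEOUT_SECONDS = 60 * 35          # 2100 s hard limit on the ASGI app
SOFT_TIMEOUT_SECONDS = 1800              # 30 min soft timeout in the websocket handler
assert SOFT_TIMEOUT_SECONDS < MODAL_TIMEOUT_SECONDS - 3 * 60   # 1800 < 1920, a 5-minute margin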
from enum import Enum @@ -126,8 +129,8 @@ async def _websocket_util( speech_profile_stream_id = 2 loop = asyncio.get_event_loop() - # Soft timeout - timeout_seconds = 420 # 7m ~ MODAL_TIME_OUT - 3m + # Soft timeout, should < MODAL_TIME_OUT - 3m + timeout_seconds = 1800 # 30m started_at = time.time() def stream_transcript(segments, stream_id): From a526cce390b2f9d77f947d175292f94f1297c461 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Sat, 21 Sep 2024 11:21:52 +0700 Subject: [PATCH 53/88] Optmize process updates patch by basic fields using only --- backend/database/processing_memories.py | 14 ++++++++++++++ backend/models/processing_memory.py | 9 ++++++++- backend/utils/processing_memories.py | 8 ++++---- 3 files changed, 26 insertions(+), 5 deletions(-) diff --git a/backend/database/processing_memories.py b/backend/database/processing_memories.py index 7d839e2f1..7a2b9d932 100644 --- a/backend/database/processing_memories.py +++ b/backend/database/processing_memories.py @@ -32,6 +32,20 @@ def get_processing_memories_by_id(uid, processing_memory_ids): memories.append(doc.to_dict()) return memories +def get_basic_processing_memories_by_id(uid, processing_memory_ids): + user_ref = db.collection('users').document(uid) + memories_ref = user_ref.collection('processing_memories') + + doc_refs = [memories_ref.document(str(processing_memory_id)) for processing_memory_id in processing_memory_ids] + docs = db.get_all(doc_refs, field_paths=["id", "created_at", "geolocation", "emotional_feedback", "timer_start"],) + + memories = [] + for doc in docs: + if doc.exists: + memories.append(doc.to_dict()) + return memories + + def update_processing_memory_segments(uid: str, id: str, segments: List[dict]): user_ref = db.collection('users').document(uid) memory_ref = user_ref.collection('processing_memories').document(id) diff --git a/backend/models/processing_memory.py b/backend/models/processing_memory.py index 7eb1f6656..c5ac29f08 100644 --- a/backend/models/processing_memory.py +++ b/backend/models/processing_memory.py @@ -23,6 +23,13 @@ class ProcessingMemory(BaseModel): memory_id: Optional[str] = None message_ids: List[str] = [] +class BasicProcessingMemory(BaseModel): + id: str + timer_start: float + created_at: datetime + geolocation: Optional[Geolocation] = None + emotional_feedback: Optional[bool] = False + class UpdateProcessingMemory(BaseModel): id: Optional[str] = None @@ -31,4 +38,4 @@ class UpdateProcessingMemory(BaseModel): class UpdateProcessingMemoryResponse(BaseModel): - result: ProcessingMemory + result: BasicProcessingMemory diff --git a/backend/utils/processing_memories.py b/backend/utils/processing_memories.py index d31874c41..5ba994d67 100644 --- a/backend/utils/processing_memories.py +++ b/backend/utils/processing_memories.py @@ -1,7 +1,7 @@ import uuid from datetime import datetime, timezone -from models.processing_memory import ProcessingMemory, UpdateProcessingMemory +from models.processing_memory import ProcessingMemory, UpdateProcessingMemory, BasicProcessingMemory from models.memory import CreateMemory, PostProcessingModel, PostProcessingStatus, MemoryPostProcessing, TranscriptSegment from utils.memories.process_memory import process_memory from utils.memories.location import get_google_maps_location @@ -55,13 +55,13 @@ async def create_memory_by_processing_memory(uid: str, processing_memory_id: str return (memory, messages, processing_memory) -def update_basic_processing_memory(uid: str, update_processing_memory: UpdateProcessingMemory,) -> ProcessingMemory: +def 
update_basic_processing_memory(uid: str, update_processing_memory: UpdateProcessingMemory,) -> BasicProcessingMemory: # Fetch new - processing_memories = processing_memories_db.get_processing_memories_by_id(uid, [update_processing_memory.id]) + processing_memories = processing_memories_db.get_basic_processing_memories_by_id(uid, [update_processing_memory.id]) if len(processing_memories) == 0: print("processing memory is not found") return - processing_memory = ProcessingMemory(**processing_memories[0]) + processing_memory = BasicProcessingMemory(**processing_memories[0]) # geolocation if update_processing_memory.geolocation: From 217c083919224ff0aaace75bc2aa58f929e76d3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Sat, 21 Sep 2024 11:42:31 +0700 Subject: [PATCH 54/88] Simplify the processing memory query by get one doc directly --- backend/database/processing_memories.py | 16 +++------------- backend/routers/transcribe.py | 6 +++--- backend/utils/processing_memories.py | 6 +++--- 3 files changed, 9 insertions(+), 19 deletions(-) diff --git a/backend/database/processing_memories.py b/backend/database/processing_memories.py index 7a2b9d932..db1a4bd76 100644 --- a/backend/database/processing_memories.py +++ b/backend/database/processing_memories.py @@ -32,19 +32,9 @@ def get_processing_memories_by_id(uid, processing_memory_ids): memories.append(doc.to_dict()) return memories -def get_basic_processing_memories_by_id(uid, processing_memory_ids): - user_ref = db.collection('users').document(uid) - memories_ref = user_ref.collection('processing_memories') - - doc_refs = [memories_ref.document(str(processing_memory_id)) for processing_memory_id in processing_memory_ids] - docs = db.get_all(doc_refs, field_paths=["id", "created_at", "geolocation", "emotional_feedback", "timer_start"],) - - memories = [] - for doc in docs: - if doc.exists: - memories.append(doc.to_dict()) - return memories - +def get_processing_memory_by_id(uid, processing_memory_id): + memory_ref = db.collection('users').document(uid).collection('processing_memories').document(processing_memory_id) + return memory_ref.get().to_dict() def update_processing_memory_segments(uid: str, id: str, segments: List[dict]): user_ref = db.collection('users').document(uid) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index aab9c789e..b0de417c2 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -462,11 +462,11 @@ async def _create_memory(): await _create_processing_memory() else: # or ensure synced processing transcript - processing_memories = processing_memories_db.get_processing_memories_by_id(uid, [processing_memory.id]) - if len(processing_memories) == 0: + processing_memory_data = processing_memories_db.get_processing_memory_by_id(uid, processing_memory.id) + if not processing_memory_data: print("processing memory is not found") return - processing_memory = ProcessingMemory(**processing_memories[0]) + processing_memory = ProcessingMemory(**processing_memory_data) processing_memory_synced = len(memory_transcript_segements) processing_memory.transcript_segments = memory_transcript_segements[:processing_memory_synced] diff --git a/backend/utils/processing_memories.py b/backend/utils/processing_memories.py index 5ba994d67..4cff65121 100644 --- a/backend/utils/processing_memories.py +++ b/backend/utils/processing_memories.py @@ -57,11 +57,11 @@ async def create_memory_by_processing_memory(uid: str, processing_memory_id: str def update_basic_processing_memory(uid: str, 
update_processing_memory: UpdateProcessingMemory,) -> BasicProcessingMemory: # Fetch new - processing_memories = processing_memories_db.get_basic_processing_memories_by_id(uid, [update_processing_memory.id]) - if len(processing_memories) == 0: + processing_memory = processing_memories_db.get_processing_memory_by_id(uid, update_processing_memory.id) + if not processing_memory: print("processing memory is not found") return - processing_memory = BasicProcessingMemory(**processing_memories[0]) + processing_memory = BasicProcessingMemory(**processing_memory) # geolocation if update_processing_memory.geolocation: From a1cc30b450d3e84e1be074772b43baad193e585b Mon Sep 17 00:00:00 2001 From: Abdul Rahman ArM <39548998+invarrow@users.noreply.github.com> Date: Sat, 21 Sep 2024 11:47:13 +0530 Subject: [PATCH 55/88] Update README.md fixed fastapi run command in readme --- plugins/example/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/example/README.md b/plugins/example/README.md index dbff0f747..1379e28b9 100644 --- a/plugins/example/README.md +++ b/plugins/example/README.md @@ -37,7 +37,7 @@ Fill in the values for each variable as needed. 2. Install the requirements.txt by running: `pip install -r requirements.txt` -3. Run the project: `fastapi run dev` This will start the FastAPI application. +3. Run the project: `fastapi run` or `fastapi dev` This will start the FastAPI application. ## Project Structure @@ -64,4 +64,4 @@ Fill in the values for each variable as needed. - `/notion-crm`: Store memory in Notion database - `/news-checker`: Check news based on conversation transcript -For more details on how to use these endpoints, refer to the code documentation or contact the project maintainer. \ No newline at end of file +For more details on how to use these endpoints, refer to the code documentation or contact the project maintainer. From 7ed77f117654b989acf66b39576cf81560d91649 Mon Sep 17 00:00:00 2001 From: Abdul Rahman ArM <39548998+invarrow@users.noreply.github.com> Date: Sat, 21 Sep 2024 11:55:18 +0530 Subject: [PATCH 56/88] Delete docs/_developer/Plugins.md redundant page: Plugins page already exists --- docs/_developer/Plugins.md | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 docs/_developer/Plugins.md diff --git a/docs/_developer/Plugins.md b/docs/_developer/Plugins.md deleted file mode 100644 index 11412bf9c..000000000 --- a/docs/_developer/Plugins.md +++ /dev/null @@ -1,7 +0,0 @@ ---- -layout: default -title: Plugins -nav_order: 5 ---- - -We've moved! 
Go to https://basedhardware.com/plugins From 8be59e497641a8ad6ec8780cd200a54201414d7a Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Sat, 21 Sep 2024 15:25:58 +0530 Subject: [PATCH 57/88] show delete and refresh on scroll --- app/lib/pages/chat/page.dart | 91 +++++++++++++++---- .../chat/widgets/animated_mini_banner.dart | 7 +- 2 files changed, 79 insertions(+), 19 deletions(-) diff --git a/app/lib/pages/chat/page.dart b/app/lib/pages/chat/page.dart index e7f0f538f..5608c3469 100644 --- a/app/lib/pages/chat/page.dart +++ b/app/lib/pages/chat/page.dart @@ -2,6 +2,7 @@ import 'dart:io'; import 'package:collection/collection.dart'; import 'package:flutter/material.dart'; +import 'package:flutter/rendering.dart'; import 'package:flutter/scheduler.dart'; import 'package:friend_private/backend/http/api/messages.dart'; import 'package:friend_private/backend/preferences.dart'; @@ -30,7 +31,10 @@ class ChatPage extends StatefulWidget { class ChatPageState extends State with AutomaticKeepAliveClientMixin { TextEditingController textController = TextEditingController(); - ScrollController scrollController = ScrollController(); + late ScrollController scrollController; + + bool _showDeleteOption = false; + bool isScrollingDown = false; var prefs = SharedPreferencesUtil(); late List plugins; @@ -50,11 +54,28 @@ class ChatPageState extends State with AutomaticKeepAliveClientMixin { @override void initState() { plugins = prefs.pluginsList; + scrollController = ScrollController(); + scrollController.addListener(() { + if (scrollController.position.userScrollDirection == ScrollDirection.reverse) { + if (!isScrollingDown) { + isScrollingDown = true; + _showDeleteOption = true; + setState(() {}); + } + } + + if (scrollController.position.userScrollDirection == ScrollDirection.forward) { + if (isScrollingDown) { + isScrollingDown = false; + _showDeleteOption = false; + setState(() {}); + } + } + }); SchedulerBinding.instance.addPostFrameCallback((_) async { - // await context.read().refreshMessages(); scrollToBottom(); }); - // _initDailySummary(); + ; super.initState(); } @@ -73,20 +94,58 @@ class ChatPageState extends State with AutomaticKeepAliveClientMixin { builder: (context, provider, connectivityProvider, child) { return Scaffold( backgroundColor: Theme.of(context).colorScheme.primary, - appBar: AnimatedMiniBanner( - showAppBar: provider.isLoadingMessages, - child: Container( - width: double.infinity, - height: 10, - color: Colors.green, - child: const Center( - child: Text( - 'Syncing messages with server...', - style: TextStyle(color: Colors.white, fontSize: 14), + appBar: provider.isLoadingMessages + ? 
AnimatedMiniBanner( + showAppBar: provider.isLoadingMessages, + child: Container( + width: double.infinity, + height: 10, + color: Colors.green, + child: const Center( + child: Text( + 'Syncing messages with server...', + style: TextStyle(color: Colors.white, fontSize: 14), + ), + ), + ), + ) + : AnimatedMiniBanner( + showAppBar: _showDeleteOption, + height: 80, + child: Container( + width: double.infinity, + height: 40, + color: Theme.of(context).primaryColor, + child: Row( + children: [ + const SizedBox(width: 20), + InkWell( + onTap: () async { + await context.read().refreshMessages(); + }, + child: const Text( + 'Refresh Chat', + style: TextStyle(color: Colors.white, fontSize: 14), + ), + ), + const Spacer(), + InkWell( + onTap: () async { + setState(() { + _showDeleteOption = false; + }); + await context.read().clearChat(); + }, + child: const Text( + 'Clear Chat', + style: TextStyle(color: Colors.white, fontSize: 14), + ), + ), + const SizedBox(width: 20), + ], + ), + ), ), - ), - ), - ), body: Stack( children: [ Align( diff --git a/app/lib/pages/chat/widgets/animated_mini_banner.dart b/app/lib/pages/chat/widgets/animated_mini_banner.dart index ccff5e758..b7289f3b1 100644 --- a/app/lib/pages/chat/widgets/animated_mini_banner.dart +++ b/app/lib/pages/chat/widgets/animated_mini_banner.dart @@ -1,20 +1,21 @@ import 'package:flutter/material.dart'; class AnimatedMiniBanner extends StatelessWidget implements PreferredSizeWidget { - const AnimatedMiniBanner({super.key, required this.showAppBar, required this.child}); + const AnimatedMiniBanner({super.key, required this.showAppBar, required this.child, this.height = 30}); final bool showAppBar; final Widget child; + final double height; @override Widget build(BuildContext context) { return AnimatedContainer( height: showAppBar ? 
kToolbarHeight : 0, - duration: const Duration(milliseconds: 400), + duration: const Duration(milliseconds: 300), child: child, ); } @override - Size get preferredSize => const Size.fromHeight(30); + Size get preferredSize => Size.fromHeight(height); } From 47e8f15b0b51cd00b0ac5d379d2a32779f40662e Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Sat, 21 Sep 2024 15:26:24 +0530 Subject: [PATCH 58/88] soft delete messages --- backend/database/chat.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/backend/database/chat.py b/backend/database/chat.py index aa27609c8..309e7ba11 100644 --- a/backend/database/chat.py +++ b/backend/database/chat.py @@ -60,6 +60,10 @@ def get_messages(uid: str, limit: int = 20, offset: int = 0, include_memories: b # Fetch messages and collect memory IDs for doc in messages_ref.stream(): message = doc.to_dict() + + if message.get('deleted') is True: + continue + messages.append(message) memories_id.update(message.get('memories_id', [])) @@ -92,16 +96,18 @@ async def batch_delete_messages(parent_doc_ref, batch_size=450): docs = messages_ref.limit(batch_size).stream() docs_list = list(docs) if not docs_list: + print("No more messages to delete") break batch = db.batch() for doc in docs_list: - batch.delete(doc.reference) + batch.update(doc.reference, {'deleted': True}) batch.commit() -async def clear_chat( uid: str): +async def clear_chat(uid: str): try: user_ref = db.collection('users').document(uid) + print(f"Deleting messages for user: {uid}") if not user_ref.get().exists: raise HTTPException(status_code=404, detail="User not found") await batch_delete_messages(user_ref) From 6ddd63ac6380e80b51feca1ac699fa7f5f791a09 Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Sat, 21 Sep 2024 23:40:08 +0530 Subject: [PATCH 59/88] Improve permission handling --- .../permissions/permissions_widget.dart | 70 +++++++++---------- 1 file changed, 34 insertions(+), 36 deletions(-) diff --git a/app/lib/pages/onboarding/permissions/permissions_widget.dart b/app/lib/pages/onboarding/permissions/permissions_widget.dart index 019f2ac35..d82999233 100644 --- a/app/lib/pages/onboarding/permissions/permissions_widget.dart +++ b/app/lib/pages/onboarding/permissions/permissions_widget.dart @@ -166,46 +166,44 @@ class _PermissionsWidgetState extends State { await provider.askForBackgroundPermissions(); } } - await Permission.notification.request().then((value) async { - if (value.isGranted) { - provider.updateNotificationPermission(true); - } - if (await Permission.location.serviceStatus.isEnabled) { - await Permission.locationWhenInUse.request().then((value) async { - if (value.isGranted) { - await Permission.locationAlways.request().then((value) async { - print('Location permission: ${value.isGranted}'); + await Permission.notification.request().then( + (value) async { + if (value.isGranted) { + provider.updateNotificationPermission(true); + } + if (await Permission.location.serviceStatus.isEnabled) { + await Permission.locationWhenInUse.request().then( + (value) async { if (value.isGranted) { - provider.setLoading(false); - if (Platform.isAndroid) { - widget.goNext(); - } - provider.updateLocationPermission(true); - } - value.isGranted - ? 
() { + await Permission.locationAlways.request().then( + (value) async { + if (value.isGranted) { + provider.updateLocationPermission(true); widget.goNext(); provider.setLoading(false); + } else { + Future.delayed(const Duration(milliseconds: 2500), () async { + if (await Permission.locationAlways.status.isGranted) { + provider.updateLocationPermission(true); + } + widget.goNext(); + provider.setLoading(false); + }); } - : Future.delayed(const Duration(milliseconds: 2500), () async { - if (await Permission.locationAlways.status.isGranted) { - provider.updateLocationPermission(true); - } - print('Location permission222222: ${value.isGranted}'); - widget.goNext(); - provider.setLoading(false); - }); - }); - } else { - widget.goNext(); - provider.setLoading(false); - } - }); - } else { - widget.goNext(); - provider.setLoading(false); - } - }); + }, + ); + } else { + widget.goNext(); + provider.setLoading(false); + } + }, + ); + } else { + widget.goNext(); + provider.setLoading(false); + } + }, + ); }, child: const Text( 'Continue', From 60b9be1ab30faef01177476228002a8bea905702 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Sun, 22 Sep 2024 06:30:26 +0700 Subject: [PATCH 60/88] Preparing status, fix recording issue after onboarding with no devices --- app/lib/backend/preferences.dart | 2 ++ app/lib/pages/memories/widgets/processing_capture.dart | 6 +++++- app/lib/providers/capture_provider.dart | 9 +++++++-- app/lib/providers/device_provider.dart | 7 +++++-- app/lib/providers/onboarding_provider.dart | 2 -- 5 files changed, 19 insertions(+), 7 deletions(-) diff --git a/app/lib/backend/preferences.dart b/app/lib/backend/preferences.dart index 3f5193619..c87b718f0 100644 --- a/app/lib/backend/preferences.dart +++ b/app/lib/backend/preferences.dart @@ -47,6 +47,8 @@ class SharedPreferencesUtil { set deviceCodec(BleAudioCodec value) => saveString('deviceCodec', mapCodecToName(value)); + Future setDeviceCodec(BleAudioCodec value) => saveString('deviceCodec', mapCodecToName(value)); + BleAudioCodec get deviceCodec => mapNameToCodec(getString('deviceCodec') ?? ''); String get openAIApiKey => getString('openaiApiKey') ?? 
''; diff --git a/app/lib/pages/memories/widgets/processing_capture.dart b/app/lib/pages/memories/widgets/processing_capture.dart index 68a9a9a83..9526e9027 100644 --- a/app/lib/pages/memories/widgets/processing_capture.dart +++ b/app/lib/pages/memories/widgets/processing_capture.dart @@ -26,6 +26,7 @@ class MemoryCaptureWidget extends StatefulWidget { } class _MemoryCaptureWidgetState extends State { + @override Widget build(BuildContext context) { return Consumer3( @@ -120,9 +121,12 @@ class _MemoryCaptureWidgetState extends State { } else if (captureProvider.memoryCreating) { stateText = "Processing"; isConnected = deviceProvider.connectedDevice != null; - } else if (deviceProvider.connectedDevice != null || captureProvider.recordingState == RecordingState.record) { + } else if (captureProvider.recordingDeviceServiceReady && captureProvider.transcriptServiceReady) { stateText = "Listening"; isConnected = true; + } else if (captureProvider.recordingDeviceServiceReady || captureProvider.transcriptServiceReady) { + stateText = "Preparing"; + isConnected = true; } var isUsingPhoneMic = captureProvider.recordingState == RecordingState.record || diff --git a/app/lib/providers/capture_provider.dart b/app/lib/providers/capture_provider.dart index 50c74f3b8..cabb25cf3 100644 --- a/app/lib/providers/capture_provider.dart +++ b/app/lib/providers/capture_provider.dart @@ -71,6 +71,10 @@ class CaptureProvider extends ChangeNotifier RecordingState recordingState = RecordingState.stop; + bool get transcriptServiceReady => _socket?.state == SocketServiceState.connected; + + bool get recordingDeviceServiceReady => _recordingDevice != null || recordingState == RecordingState.record; + // ----------------------- // Memory creation variables double? streamStartedAtSecond; @@ -478,6 +482,7 @@ class CaptureProvider extends ChangeNotifier _cleanupCurrentState(); await _handleMemoryCreation(restartBytesProcessing); + await _recheckCodecChange(); await _ensureSocketConnection(force: true); await startOpenGlass(); @@ -560,12 +565,12 @@ class CaptureProvider extends ChangeNotifier return connection.hasPhotoStreamingCharacteristic(); } - Future _checkCodecChange() async { + Future _recheckCodecChange() async { if (_recordingDevice != null) { BleAudioCodec newCodec = await _getAudioCodec(_recordingDevice!.id); if (SharedPreferencesUtil().deviceCodec != newCodec) { debugPrint('Device codec changed from ${SharedPreferencesUtil().deviceCodec} to $newCodec'); - SharedPreferencesUtil().deviceCodec = newCodec; + await SharedPreferencesUtil().setDeviceCodec(newCodec); return true; } } diff --git a/app/lib/providers/device_provider.dart b/app/lib/providers/device_provider.dart index 936594f8b..ce064d659 100644 --- a/app/lib/providers/device_provider.dart +++ b/app/lib/providers/device_provider.dart @@ -23,6 +23,10 @@ class DeviceProvider extends ChangeNotifier implements IDeviceServiceSubsciption Timer? 
_disconnectNotificationTimer; + DeviceProvider() { + ServiceManager.instance().device.subscribe(this, this); + } + void setProviders(CaptureProvider provider) { captureProvider = provider; notifyListeners(); @@ -114,7 +118,6 @@ class DeviceProvider extends ChangeNotifier implements IDeviceServiceSubsciption } Future scanAndConnectToDevice() async { - ServiceManager.instance().device.subscribe(this, this); updateConnectingStatus(true); if (isConnected) { if (connectedDevice == null) { @@ -205,7 +208,7 @@ class DeviceProvider extends ChangeNotifier implements IDeviceServiceSubsciption setConnectedDevice(device); setIsConnected(true); updateConnectingStatus(false); - await captureProvider?.streamDeviceRecording(restartBytesProcessing: true, device: connectedDevice); + await captureProvider?.streamDeviceRecording(restartBytesProcessing: true, device: device); // initiateBleBatteryListener(); // The device is still disconnected for some reason if (connectedDevice != null) { diff --git a/app/lib/providers/onboarding_provider.dart b/app/lib/providers/onboarding_provider.dart index d7b90723e..7ddcd703f 100644 --- a/app/lib/providers/onboarding_provider.dart +++ b/app/lib/providers/onboarding_provider.dart @@ -179,7 +179,6 @@ class OnboardingProvider extends BaseProvider with MessageNotifierMixin implemen deviceProvider!.setConnectedDevice(cDevice); SharedPreferencesUtil().btDeviceStruct = cDevice; SharedPreferencesUtil().deviceName = cDevice.name; - SharedPreferencesUtil().deviceCodec = await _getAudioCodec(device.id); deviceProvider!.setIsConnected(true); } //TODO: should'nt update codec here, becaause then the prev connection codec and the current codec will @@ -197,7 +196,6 @@ class OnboardingProvider extends BaseProvider with MessageNotifierMixin implemen await Future.delayed(const Duration(seconds: 2)); SharedPreferencesUtil().btDeviceStruct = connectedDevice!; SharedPreferencesUtil().deviceName = connectedDevice.name; - SharedPreferencesUtil().deviceCodec = await _getAudioCodec(device.id); foundDevicesMap.clear(); deviceList.clear(); if (isFromOnboarding) { From 7a810dd0503eb645c69502270b4c8adcf0d1ee79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Sun, 22 Sep 2024 06:42:37 +0700 Subject: [PATCH 61/88] Ensure stranscript service ready state --- app/lib/providers/capture_provider.dart | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/app/lib/providers/capture_provider.dart b/app/lib/providers/capture_provider.dart index cabb25cf3..6a9acd975 100644 --- a/app/lib/providers/capture_provider.dart +++ b/app/lib/providers/capture_provider.dart @@ -71,7 +71,8 @@ class CaptureProvider extends ChangeNotifier RecordingState recordingState = RecordingState.stop; - bool get transcriptServiceReady => _socket?.state == SocketServiceState.connected; + bool _transcriptServiceReady = false; + bool get transcriptServiceReady => _transcriptServiceReady; bool get recordingDeviceServiceReady => _recordingDevice != null || recordingState == RecordingState.record; @@ -371,6 +372,7 @@ class CaptureProvider extends ChangeNotifier throw Exception("Can not create new memory socket"); } _socket?.subscribe(this, this); + _transcriptServiceReady = true; if (segments.isNotEmpty) { // means that it was a reconnection, so we need to reset @@ -702,6 +704,7 @@ class CaptureProvider extends ChangeNotifier @override void onClosed() { + _transcriptServiceReady = false; debugPrint('[Provider] Socket is closed'); _clean(); @@ -733,6 +736,7 @@ class CaptureProvider extends ChangeNotifier 
@override void onError(Object err) { + _transcriptServiceReady = false; debugPrint('err: $err'); notifyListeners(); From 06f2e523d5f28514111aa235747ed96db70f1c18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Sun, 22 Sep 2024 06:44:41 +0700 Subject: [PATCH 62/88] Bump version to 1.0.39+136 --- app/pubspec.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/pubspec.yaml b/app/pubspec.yaml index d968f8318..d10b6cedb 100644 --- a/app/pubspec.yaml +++ b/app/pubspec.yaml @@ -3,7 +3,7 @@ description: A new Flutter project. publish_to: 'none' # Remove this line if you wish to publish to pub.dev -version: 1.0.39+135 +version: 1.0.39+136 environment: sdk: ">=3.0.0 <4.0.0" From 5ff87f1da654b5676cdf3d622cc112f318b0a3a5 Mon Sep 17 00:00:00 2001 From: Joan Cabezas Date: Sat, 21 Sep 2024 20:47:33 -0700 Subject: [PATCH 63/88] deepgram opus uses pcm16 instead --- backend/routers/transcribe.py | 56 +++++++++++------------------------ 1 file changed, 17 insertions(+), 39 deletions(-) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 446e0885e..0c824fb2f 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -89,7 +89,8 @@ async def _websocket_util( channels: int = 1, include_speech_profile: bool = True, new_memory_watch: bool = False, stt_service: STTService = STTService.deepgram, ): - print('websocket_endpoint', uid, language, sample_rate, codec, channels, include_speech_profile, new_memory_watch, stt_service) + print('websocket_endpoint', uid, language, sample_rate, codec, channels, include_speech_profile, new_memory_watch, + stt_service) if stt_service == STTService.soniox and ( sample_rate != 16000 or codec != 'opus' or language not in soniox_valid_languages): @@ -168,7 +169,6 @@ def stream_audio(audio_buffer): soniox_socket = None speechmatics_socket = None - speechmatics_socket2 = None deepgram_socket = None deepgram_socket2 = None @@ -178,44 +178,33 @@ def stream_audio(audio_buffer): # audio_buffer = None duration = 0 try: - # Soniox - if stt_service == STTService.deepgram: - if language == 'en' and codec == 'opus' and include_speech_profile: - second_per_frame: float = 0.01 - speech_profile = get_user_speech_profile(uid) - duration = len(speech_profile) * second_per_frame - print('speech_profile', len(speech_profile), duration) - if duration: - duration += 10 - else: - speech_profile, duration = [], 0 + file_path, duration = None, 0 + if language == 'en' and (codec == 'opus' or codec == 'pcm16') and include_speech_profile: + file_path = get_profile_audio_if_exists(uid) + duration = AudioSegment.from_wav(file_path).duration_seconds + 5 if file_path else 0 + # Deepgram + if stt_service == STTService.deepgram: + deepgram_codec_value = 'pcm16' if codec == 'opus' else codec deepgram_socket = await process_audio_dg( - stream_transcript, memory_stream_id, language, sample_rate, codec, channels, preseconds=duration + stream_transcript, memory_stream_id, language, sample_rate, deepgram_codec_value, channels, + preseconds=duration ) if duration: deepgram_socket2 = await process_audio_dg( - stream_transcript, speech_profile_stream_id, language, sample_rate, codec, channels + stream_transcript, speech_profile_stream_id, language, sample_rate, deepgram_codec_value, channels ) - await send_initial_file(speech_profile, deepgram_socket) + await send_initial_file_path(file_path, deepgram_socket) elif stt_service == STTService.soniox: soniox_socket = await process_audio_soniox( stream_transcript, speech_profile_stream_id, 
language, uid if include_speech_profile else None ) elif stt_service == STTService.speechmatics: - file_path = None - if language == 'en' and codec == 'opus' and include_speech_profile: - file_path = get_profile_audio_if_exists(uid) - duration = AudioSegment.from_wav(file_path).duration_seconds + 5 if file_path else 0 - speechmatics_socket = await process_audio_speechmatics( stream_transcript, speech_profile_stream_id, language, preseconds=duration ) if duration: - # speechmatics_socket2 = await process_audio_speechmatics( - # stream_transcript, speech_profile_stream_id, language, preseconds=duration - # ) await send_initial_file_path(file_path, speechmatics_socket) print('speech_profile speechmatics duration', duration) @@ -249,23 +238,14 @@ async def receive_audio(dg_socket1, dg_socket2, soniox_socket, speechmatics_sock while websocket_active: data = await websocket.receive_bytes() # audio_buffer.extend(data) + if codec == 'opus' and sample_rate == 16000: + data = decoder.decode(bytes(data), frame_size=160) if soniox_socket is not None: - decoded_opus = decoder.decode(bytes(data), frame_size=160) - await soniox_socket.send(decoded_opus) + await soniox_socket.send(data) if speechmatics_socket1 is not None: - decoded_opus = decoder.decode(bytes(data), frame_size=160) - await speechmatics_socket1.send(decoded_opus) - - # elapsed_seconds = time.time() - timer_start - # if elapsed_seconds > duration or not dg_socket2: - # if speechmatics_socket2: - # print('Killing socket2 speechmatics') - # speechmatics_socket2.close() - # speechmatics_socket2 = None - # else: - # speechmatics_socket2.send(decoded_opus) + await speechmatics_socket1.send(data) if deepgram_socket is not None: elapsed_seconds = time.time() - timer_start @@ -298,8 +278,6 @@ async def receive_audio(dg_socket1, dg_socket2, soniox_socket, speechmatics_sock await soniox_socket.close() if speechmatics_socket: await speechmatics_socket.close() - if speechmatics_socket2: - await speechmatics_socket2.close() # heart beat async def send_heartbeat(): From 96d65e46cfbb84426a419773d2e272c70dba08ec Mon Sep 17 00:00:00 2001 From: Joan Cabezas Date: Sat, 21 Sep 2024 20:47:56 -0700 Subject: [PATCH 64/88] further script for computing transcripts with each model + WER initial results --- .../stt/k_compare_transcripts_performance.py | 111 ++++++++++++------ 1 file changed, 76 insertions(+), 35 deletions(-) diff --git a/backend/scripts/stt/k_compare_transcripts_performance.py b/backend/scripts/stt/k_compare_transcripts_performance.py index bdb57c2d8..5ffeabb35 100644 --- a/backend/scripts/stt/k_compare_transcripts_performance.py +++ b/backend/scripts/stt/k_compare_transcripts_performance.py @@ -19,11 +19,14 @@ load_dotenv('../../.dev.env') os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '../../' + os.getenv('GOOGLE_APPLICATION_CREDENTIALS') - firebase_admin.initialize_app() from models.transcript_segment import TranscriptSegment from utils.stt.streaming import process_audio_dg, process_audio_soniox, process_audio_speechmatics +from groq import Groq + +from utils.other.storage import upload_postprocessing_audio +from utils.stt.pre_recorded import fal_whisperx, fal_postprocessing def store_model_result(memory_id: str, model: str, result: List[Dict]): @@ -43,25 +46,24 @@ def store_model_result(memory_id: str, model: str, result: List[Dict]): json.dump(results, f) -def add_model_result_segments(memory_id: str, model: str, result: List[Dict]): - file_path = 'results.json' - if os.path.exists(file_path): - with open(file_path, 'r') as f: - results = 
json.load(f) - else: - results = {} - - if memory_id not in results: - results[memory_id] = {} - - if model not in results[memory_id]: - results[memory_id][model] = [] +def add_model_result_segments(model: str, new_segments: List[Dict], result: Dict): + segments = [TranscriptSegment(**s) for s in result[model]] + new_segments = [TranscriptSegment(**s) for s in new_segments] + segments = TranscriptSegment.combine_segments(segments, new_segments) + result[model] = [s.dict() for s in segments] - segments = [TranscriptSegment(**s) for s in results[memory_id][model]] - new_segments = [TranscriptSegment(**s) for s in result] - segments = TranscriptSegment.combine_segments(segments, new_segments) - store_model_result(memory_id, model, [s.dict() for s in segments]) +def execute_groq(file_path: str): + client = Groq(api_key=os.getenv('GROQ_API_KEY')) + with open(file_path, "rb") as file: + transcription = client.audio.transcriptions.create( + file=(file_path, file.read()), + model="whisper-large-v3", + response_format="text", + language="en", + temperature=0.0 + ) + return str(transcription) async def process_memories_audio_files(): @@ -78,45 +80,84 @@ async def process_memories_audio_files(): # memories_data = get_memories_by_id(uid, memories_id) for file_path in memories: + aseg = AudioSegment.from_wav(file_path) memory_id = file_path.split('/')[-1].split('.')[0] - print(memory_id) + + if os.path.exists(f'results/{memory_id}.json'): + print('Already processed', memory_id) + continue + print('Started processing', memory_id, 'duration', aseg.duration_seconds) + result = { + 'deepgram': [], + 'soniox': [], + 'speechmatics': [] + } def stream_transcript_deepgram(new_segments, _): - print(new_segments) - add_model_result_segments(memory_id, 'deepgram', new_segments) + print('stream_transcript_deepgram', new_segments) + add_model_result_segments('deepgram', new_segments, result) def stream_transcript_soniox(new_segments, _): - print(new_segments) - add_model_result_segments(memory_id, 'soniox', new_segments) + print('stream_transcript_soniox', new_segments) + add_model_result_segments('soniox', new_segments, result) def stream_transcript_speechmatics(new_segments, _): - print(new_segments) - add_model_result_segments(memory_id, 'speechmatics', new_segments) + print('stream_transcript_speechmatics', new_segments) + add_model_result_segments('speechmatics', new_segments, result) + groq_result: str = execute_groq(file_path) # source of truth + result['whisper-large-v3'] = groq_result + + # whisperx + signed_url = upload_postprocessing_audio(file_path) + words = fal_whisperx(signed_url) + fal_segments = fal_postprocessing(words, aseg.duration_seconds) + result['fal_whisperx'] = [s.dict() for s in fal_segments] + + # streaming models socket = await process_audio_dg(stream_transcript_deepgram, '1', 'en', 16000, 'pcm16', 1, 0) socket_soniox = await process_audio_soniox(stream_transcript_soniox, '1', 'en', None) socket_speechmatics = await process_audio_speechmatics(stream_transcript_speechmatics, '1', 'en', 0) duration = AudioSegment.from_wav(file_path).duration_seconds print('duration', duration) with open(file_path, "rb") as file: - while True: chunk = file.read(320) if not chunk: break - # print('Uploading', len(chunk)) - # TODO: Race conditions here? 
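# Editor's note (added for clarity; not part of the original patch): each 320-byte
# read is 160 samples of 16 kHz, 16-bit, mono PCM, i.e. 10 ms of audio
# (16000 samples/s * 2 bytes * 0.010 s = 320), so this loop feeds the three
# streaming sockets in 10 ms chunks while the tiny sleep only yields to the event loop.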
socket.send(bytes(chunk)) await socket_soniox.send(bytes(chunk)) await socket_speechmatics.send(bytes(chunk)) - await asyncio.sleep(0.0001) - print('File sent') - # - call for whisper-x - # - store in a json file and cache - await asyncio.sleep(duration) # TODO: await duration - break + await asyncio.sleep(0.001) + await asyncio.sleep(duration) + + os.makedirs('results', exist_ok=True) + with open(f'results/{memory_id}.json', 'w') as f: + json.dump(result, f, indent=2) break +from jiwer import wer + + +def compute_wer(): + dir = 'results/' + for file in os.listdir(dir): + if not file.endswith('.json'): + continue + with open(f'{dir}{file}', 'r') as f: + result = json.load(f) + source = str(result['whisper-large-v3']).strip().lower().replace(' ', ' ') + print(file) + for model, segments in result.items(): + if model == 'whisper-large-v3': # TODO: words vs each other + continue + segments_str = ' '.join([s['text'] for s in segments]).strip().lower().replace(' ', ' ') + value = wer(source, segments_str) + print(f'{model} WER: {value}') + print('-----------------------------------------') + + if __name__ == '__main__': - asyncio.run(process_memories_audio_files()) + # asyncio.run(process_memories_audio_files()) + compute_wer() From 752ef4883cc6e9e8096b399ac908c7422bad4f59 Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Sun, 22 Sep 2024 11:38:51 +0530 Subject: [PATCH 65/88] improve chat clear logic and make suggested changes --- backend/database/chat.py | 30 +++++++++++++++++++++++------- backend/routers/chat.py | 12 ++++++++---- backend/utils/chat/chat.py | 6 ++++++ 3 files changed, 37 insertions(+), 11 deletions(-) create mode 100644 backend/utils/chat/chat.py diff --git a/backend/database/chat.py b/backend/database/chat.py index 309e7ba11..0df869f85 100644 --- a/backend/database/chat.py +++ b/backend/database/chat.py @@ -89,27 +89,43 @@ def get_messages(uid: str, limit: int = 20, offset: int = 0, include_memories: b return messages -async def batch_delete_messages(parent_doc_ref, batch_size=450): - # batch size is 450 because firebase can perform upto 500 operations in a batch +def batch_delete_messages(parent_doc_ref, batch_size=450): messages_ref = parent_doc_ref.collection('messages') + last_doc = None # For pagination + while True: - docs = messages_ref.limit(batch_size).stream() + if last_doc: + docs = messages_ref.limit(batch_size).start_after(last_doc).stream() + else: + docs = messages_ref.limit(batch_size).stream() + docs_list = list(docs) + if not docs_list: print("No more messages to delete") break + batch = db.batch() + for doc in docs_list: batch.update(doc.reference, {'deleted': True}) + batch.commit() + if len(docs_list) < batch_size: + print("Processed all messages") + break + + last_doc = docs_list[-1] + -async def clear_chat(uid: str): +def clear_chat(uid: str): try: user_ref = db.collection('users').document(uid) print(f"Deleting messages for user: {uid}") if not user_ref.get().exists: - raise HTTPException(status_code=404, detail="User not found") - await batch_delete_messages(user_ref) + return {"message": "User not found"} + batch_delete_messages(user_ref) + return None except Exception as e: - raise HTTPException(status_code=500, detail=f"Error deleting messages: {str(e)}") \ No newline at end of file + return {"message": str(e)} \ No newline at end of file diff --git a/backend/routers/chat.py b/backend/routers/chat.py index 2836f1c19..089d560f2 100644 --- a/backend/routers/chat.py +++ b/backend/routers/chat.py @@ 
-2,10 +2,11 @@ from datetime import datetime, timezone from typing import List, Optional -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, HTTPException import database.chat as chat_db from models.chat import Message, SendMessageRequest, MessageSender +from utils.chat.chat import clear_user_chat_message from utils.llm import qa_rag, initial_chat_message from utils.other import endpoints as auth from utils.plugins import get_plugin_by_id @@ -29,7 +30,8 @@ def filter_messages(messages, plugin_id): def send_message( data: SendMessageRequest, plugin_id: Optional[str] = None, uid: str = Depends(auth.get_current_user_uid) ): - message = Message(id=str(uuid.uuid4()), text=data.text, created_at=datetime.now(timezone.utc), sender='human', type='text') + message = Message(id=str(uuid.uuid4()), text=data.text, created_at=datetime.now(timezone.utc), sender='human', + type='text') chat_db.add_message(uid, message.dict()) plugin = get_plugin_by_id(plugin_id) @@ -55,13 +57,15 @@ def send_message( ai_message.memories = memories if len(memories) < 5 else memories[:5] return ai_message + @router.delete('/v1/clear-chat', tags=['chat'], response_model=Message) def clear_chat(uid: str = Depends(auth.get_current_user_uid)): - chat_db.clear_chat(uid) + err = clear_user_chat_message(uid) + if err: + raise HTTPException(status_code=500, detail='Failed to clear chat') return initial_message_util(uid) - def initial_message_util(uid: str, plugin_id: Optional[str] = None): plugin = get_plugin_by_id(plugin_id) text = initial_chat_message(uid, plugin) diff --git a/backend/utils/chat/chat.py b/backend/utils/chat/chat.py new file mode 100644 index 000000000..8be3036e7 --- /dev/null +++ b/backend/utils/chat/chat.py @@ -0,0 +1,6 @@ +import database.chat as chat_db + + +def clear_user_chat_message(uid: str): + err = chat_db.clear_chat(uid) + return err From 646a2aecc5d6461c7f8be9367f546b90c0eb364f Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Sun, 22 Sep 2024 11:39:12 +0530 Subject: [PATCH 66/88] minor app chats refresh improvement --- app/lib/providers/message_provider.dart | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/app/lib/providers/message_provider.dart b/app/lib/providers/message_provider.dart index 838288ce4..ebb1c8d10 100644 --- a/app/lib/providers/message_provider.dart +++ b/app/lib/providers/message_provider.dart @@ -36,11 +36,15 @@ class MessageProvider extends ChangeNotifier { Future refreshMessages() async { setLoadingMessages(true); + if (SharedPreferencesUtil().cachedMessages.isNotEmpty) { + setHasCachedMessages(true); + } messages = await getMessagesFromServer(); if (messages.isEmpty) { messages = SharedPreferencesUtil().cachedMessages; } else { SharedPreferencesUtil().cachedMessages = messages; + setHasCachedMessages(true); } setLoadingMessages(false); notifyListeners(); From c942059fa249d3be99caef64bdffaea1b040c7f8 Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Sun, 22 Sep 2024 11:42:13 +0530 Subject: [PATCH 67/88] remove exception import --- backend/database/chat.py | 1 - 1 file changed, 1 deletion(-) diff --git a/backend/database/chat.py b/backend/database/chat.py index 0df869f85..dcf55bffe 100644 --- a/backend/database/chat.py +++ b/backend/database/chat.py @@ -2,7 +2,6 @@ from datetime import datetime, timezone from typing import Optional -from fastapi import HTTPException from google.cloud import firestore from models.chat import Message From 
2229df9d95079b76591f9ae69d6da0dc46582952 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Sun, 22 Sep 2024 12:04:04 +0700 Subject: [PATCH 68/88] Align start seconds for transcript segment --- backend/routers/transcribe.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 8e7f20fdd..9047867e4 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -139,10 +139,21 @@ def stream_transcript(segments, stream_id): nonlocal processing_memory nonlocal processing_memory_synced nonlocal memory_transcript_segements + nonlocal segment_start if not segments or len(segments) == 0: return + # Align the start, end segment + if len(memory_transcript_segements) == 0 and len(segments) > 0: + start = segments[0]["start"] + if not segment_start or segment_start > start: + segment_start = start + for i, segment in enumerate(segments): + segment["start"] -= segment_start + segment["end"] -= segment_start + segments[i] = segment + asyncio.run_coroutine_threadsafe(websocket.send_json(segments), loop) threading.Thread(target=process_segments, args=(uid, segments)).start() @@ -178,6 +189,8 @@ def stream_audio(audio_buffer): websocket_active = True websocket_close_code = 1001 # Going Away, don't close with good from backend timer_start = None + segment_start = None + segment_end = None # audio_buffer = None duration = 0 try: @@ -377,6 +390,7 @@ async def _post_process_memory(memory: Memory): nonlocal processing_audio_frame_synced # Create wav + # TODO: remove audio frames [start, end] processing_audio_frame_synced = len(processing_audio_frames) file_path = f"_temp/{memory.id}_{uuid.uuid4()}_be" create_wav_from_bytes(file_path=file_path, frames=processing_audio_frames[:processing_audio_frame_synced], From 60376d943063ebdc73a2f14697d730484a991eb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Sun, 22 Sep 2024 13:33:02 +0700 Subject: [PATCH 69/88] Trim the audio frames regarding the start, end --- backend/routers/transcribe.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 9047867e4..29136022c 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -1,4 +1,5 @@ import threading +import math import asyncio import time from typing import List @@ -140,15 +141,23 @@ def stream_transcript(segments, stream_id): nonlocal processing_memory_synced nonlocal memory_transcript_segements nonlocal segment_start + nonlocal segment_end if not segments or len(segments) == 0: return # Align the start, end segment - if len(memory_transcript_segements) == 0 and len(segments) > 0: - start = segments[0]["start"] - if not segment_start or segment_start > start: + if len(segments) > 0: + # start + if not segment_start: + start = segments[0]["start"] segment_start = start + + # end + end = segments[-1]["end"] + if not segment_end or segment_end < end: + segment_end = end + for i, segment in enumerate(segments): segment["start"] -= segment_start segment["end"] -= segment_start @@ -388,12 +397,23 @@ async def _post_process_memory(memory: Memory): nonlocal processing_memory nonlocal processing_audio_frames nonlocal processing_audio_frame_synced + nonlocal segment_start + nonlocal segment_end # Create wav - # TODO: remove audio frames [start, end] processing_audio_frame_synced = len(processing_audio_frames) + + # Remove audio frames [start, end] + frames_per_sec = 100 
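# Editor's note (illustrative, not part of the original patch): with 10 ms audio
# frames (frames_per_sec = 100), a transcript spanning 3.25 s to 95.5 s would trim
# the buffered frames to indices [floor(3.25 * 100), ceil(95.5 * 100)) = [325, 9550),
# with the right edge further clamped to processing_audio_frame_synced below.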
+ left = 0 + if segment_start: + left = max(left, math.floor(segment_start * frames_per_sec)) + right = processing_audio_frame_synced + if segment_end: + right = min(math.ceil(segment_end * frames_per_sec), right) + file_path = f"_temp/{memory.id}_{uuid.uuid4()}_be" - create_wav_from_bytes(file_path=file_path, frames=processing_audio_frames[:processing_audio_frame_synced], + create_wav_from_bytes(file_path=file_path, frames=processing_audio_frames[left:right], frame_rate=sample_rate, channels=channels, codec=codec, ) # Try merge new audio with the previous From 8010780ad11e8d985ce8e9d5f37f865e8c90a278 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Sun, 22 Sep 2024 13:40:45 +0700 Subject: [PATCH 70/88] Create memory with delta segment start --- backend/utils/processing_memories.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/backend/utils/processing_memories.py b/backend/utils/processing_memories.py index 4cff65121..f22f00dca 100644 --- a/backend/utils/processing_memories.py +++ b/backend/utils/processing_memories.py @@ -24,10 +24,11 @@ async def create_memory_by_processing_memory(uid: str, processing_memory_id: str print("Transcript segments is invalid") return timer_start = processing_memory.timer_start + segment_start = transcript_segments[0].start segment_end = transcript_segments[-1].end new_memory = CreateMemory( - started_at=datetime.fromtimestamp(timer_start, timezone.utc), - finished_at=datetime.fromtimestamp(timer_start + segment_end, timezone.utc), + started_at=datetime.fromtimestamp(timer_start + segment_start, timezone.utc), + finished_at=datetime.fromtimestamp(timer_start + segment_start + segment_end, timezone.utc), language=processing_memory.language, transcript_segments=transcript_segments, ) From a306ce5a925bc2c1467976b3dbb6be424df0dddb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Sun, 22 Sep 2024 13:44:19 +0700 Subject: [PATCH 71/88] Remove redundant len check on new segments --- backend/routers/transcribe.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 29136022c..acdff7c4b 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -147,16 +147,14 @@ def stream_transcript(segments, stream_id): return # Align the start, end segment - if len(segments) > 0: - # start - if not segment_start: - start = segments[0]["start"] - segment_start = start + if not segment_start: + start = segments[0]["start"] + segment_start = start - # end - end = segments[-1]["end"] - if not segment_end or segment_end < end: - segment_end = end + # end + end = segments[-1]["end"] + if not segment_end or segment_end < end: + segment_end = end for i, segment in enumerate(segments): segment["start"] -= segment_start @@ -200,6 +198,7 @@ def stream_audio(audio_buffer): timer_start = None segment_start = None segment_end = None + audio_frames_per_sec = 100 # audio_buffer = None duration = 0 try: @@ -404,13 +403,12 @@ async def _post_process_memory(memory: Memory): processing_audio_frame_synced = len(processing_audio_frames) # Remove audio frames [start, end] - frames_per_sec = 100 left = 0 if segment_start: - left = max(left, math.floor(segment_start * frames_per_sec)) + left = max(left, math.floor(segment_start * audio_frames_per_sec)) right = processing_audio_frame_synced if segment_end: - right = min(math.ceil(segment_end * frames_per_sec), right) + right = min(math.ceil(segment_end * audio_frames_per_sec), 
right) file_path = f"_temp/{memory.id}_{uuid.uuid4()}_be" create_wav_from_bytes(file_path=file_path, frames=processing_audio_frames[left:right], From 978957d6e3cf3e634ed5604ad4ec82745b59ba34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Sun, 22 Sep 2024 14:07:06 +0700 Subject: [PATCH 72/88] Use transcript timer segment start instead of timer start(audio) for started at, finished at memory --- backend/models/processing_memory.py | 1 + backend/routers/transcribe.py | 23 +++++++++++------------ backend/utils/processing_memories.py | 7 +++---- 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/backend/models/processing_memory.py b/backend/models/processing_memory.py index c5ac29f08..11d812068 100644 --- a/backend/models/processing_memory.py +++ b/backend/models/processing_memory.py @@ -14,6 +14,7 @@ class ProcessingMemory(BaseModel): audio_url: Optional[str] = None created_at: datetime timer_start: float + timer_segment_start: Optional[float] = None timer_starts: List[float] = [] language: Optional[str] = None # applies only to Friend # TODO: once released migrate db to default 'en' transcript_segments: List[TranscriptSegment] = [] diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index acdff7c4b..a16bc595d 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -352,10 +352,11 @@ async def _create_processing_memory(): last_processing_memory_data = processing_memories_db.get_last(uid) if last_processing_memory_data: last_processing_memory = ProcessingMemory(**last_processing_memory_data) - segment_end = 0 + last_segment_end = 0 for segment in last_processing_memory.transcript_segments: - segment_end = max(segment_end, segment.end) - if last_processing_memory.timer_start + segment_end + min_seconds_limit > time.time(): + last_segment_end = max(last_segment_end, segment.end) + timer_segment_start = last_processing_memory.timer_segment_start if last_processing_memory.timer_segment_start else last_processing_memory.timer_start + if timer_segment_start + last_segment_end + min_seconds_limit > time.time(): processing_memory = last_processing_memory # Or create new @@ -364,6 +365,7 @@ async def _create_processing_memory(): id=str(uuid.uuid4()), created_at=datetime.now(timezone.utc), timer_start=timer_start, + timer_segment_start=timer_start+segment_start, language=language, ) @@ -557,6 +559,8 @@ async def _try_flush_new_memory_with_lock(time_validate: bool = True): async def _try_flush_new_memory(time_validate: bool = True): nonlocal memory_transcript_segements nonlocal timer_start + nonlocal segment_start + nonlocal segment_end nonlocal processing_memory nonlocal processing_memory_synced nonlocal processing_audio_frames @@ -567,13 +571,8 @@ async def _try_flush_new_memory(time_validate: bool = True): return # Validate last segment - last_segment = None - if len(memory_transcript_segements) > 0: - last_segment = memory_transcript_segements[-1] - if not last_segment: + if not segment_end: print("Not last segment or last segment invalid") - if last_segment: - print(f"{last_segment.dict()}") return # First chunk, create processing memory @@ -584,11 +583,11 @@ async def _try_flush_new_memory(time_validate: bool = True): # Validate transcript # Longer 120s - segment_end = last_segment.end now = time.time() should_create_memory_time = True if time_validate: - should_create_memory_time = timer_start + segment_end + min_seconds_limit < now + timer_segment_start = timer_start + segment_start + should_create_memory_time = 
timer_segment_start + segment_end + min_seconds_limit < now # 1 words at least should_create_memory_time_words = min_words_limit == 0 @@ -602,7 +601,7 @@ async def _try_flush_new_memory(time_validate: bool = True): should_create_memory = should_create_memory_time and should_create_memory_time_words print( - f"Should create memory {should_create_memory} - {timer_start} {segment_end} {min_seconds_limit} {now} - {time_validate}, session {session_id}") + f"Should create memory {should_create_memory} - {timer_segment_start} {segment_end} {min_seconds_limit} {now} - {time_validate}, session {session_id}") if should_create_memory: memory = await _create_memory() if not memory: diff --git a/backend/utils/processing_memories.py b/backend/utils/processing_memories.py index f22f00dca..eeaabb9b2 100644 --- a/backend/utils/processing_memories.py +++ b/backend/utils/processing_memories.py @@ -23,12 +23,11 @@ async def create_memory_by_processing_memory(uid: str, processing_memory_id: str if not transcript_segments or len(transcript_segments) == 0: print("Transcript segments is invalid") return - timer_start = processing_memory.timer_start - segment_start = transcript_segments[0].start + timer_segment_start = processing_memory.timer_segment_start segment_end = transcript_segments[-1].end new_memory = CreateMemory( - started_at=datetime.fromtimestamp(timer_start + segment_start, timezone.utc), - finished_at=datetime.fromtimestamp(timer_start + segment_start + segment_end, timezone.utc), + started_at=datetime.fromtimestamp(timer_segment_start, timezone.utc), + finished_at=datetime.fromtimestamp(timer_segment_start + segment_end, timezone.utc), language=processing_memory.language, transcript_segments=transcript_segments, ) From bd95410f8958c9c1965fde10582cda88f15aaae3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Sun, 22 Sep 2024 14:10:44 +0700 Subject: [PATCH 73/88] Backward compatbile with the old ProcessingMemory model --- backend/utils/processing_memories.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/utils/processing_memories.py b/backend/utils/processing_memories.py index eeaabb9b2..057cbd0c6 100644 --- a/backend/utils/processing_memories.py +++ b/backend/utils/processing_memories.py @@ -23,7 +23,7 @@ async def create_memory_by_processing_memory(uid: str, processing_memory_id: str if not transcript_segments or len(transcript_segments) == 0: print("Transcript segments is invalid") return - timer_segment_start = processing_memory.timer_segment_start + timer_segment_start = processing_memory.timer_segment_start if processing_memory.timer_segment_start else processing_memory.timer_start segment_end = transcript_segments[-1].end new_memory = CreateMemory( started_at=datetime.fromtimestamp(timer_segment_start, timezone.utc), From 8465de643dd45c500f0f448fabbd4e5a28beadc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Sun, 22 Sep 2024 16:59:21 +0700 Subject: [PATCH 74/88] Trim audio frames --- backend/routers/transcribe.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index a16bc595d..62bb895cd 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -260,7 +260,8 @@ async def receive_audio(dg_socket1, dg_socket2, soniox_socket, speechmatics_sock # audio_file = open(path, "a") try: while websocket_active: - data = await websocket.receive_bytes() + raw_data = await websocket.receive_bytes() + data = raw_data[:] # 
audio_buffer.extend(data) if codec == 'opus' and sample_rate == 16000: data = decoder.decode(bytes(data), frame_size=160) @@ -283,7 +284,7 @@ async def receive_audio(dg_socket1, dg_socket2, soniox_socket, speechmatics_sock dg_socket2.send(data) # stream - stream_audio(data) + stream_audio(raw_data) # audio_buffer = bytearray() @@ -407,10 +408,10 @@ async def _post_process_memory(memory: Memory): # Remove audio frames [start, end] left = 0 if segment_start: - left = max(left, math.floor(segment_start * audio_frames_per_sec)) + left = max(left, math.floor(segment_start) * audio_frames_per_sec) right = processing_audio_frame_synced if segment_end: - right = min(math.ceil(segment_end * audio_frames_per_sec), right) + right = min(math.ceil(segment_end) * audio_frames_per_sec, right) file_path = f"_temp/{memory.id}_{uuid.uuid4()}_be" create_wav_from_bytes(file_path=file_path, frames=processing_audio_frames[left:right], From 60b33c4be9a3cc93da7af270a634e34b66e884ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Sun, 22 Sep 2024 17:09:43 +0700 Subject: [PATCH 75/88] Use timer start to validate create new memory --- backend/routers/transcribe.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 62bb895cd..a1ac62bfa 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -586,9 +586,8 @@ async def _try_flush_new_memory(time_validate: bool = True): # Longer 120s now = time.time() should_create_memory_time = True - timer_segment_start = timer_start + segment_start if time_validate: - should_create_memory_time = timer_segment_start + segment_end + min_seconds_limit < now + should_create_memory_time = timer_start + segment_end + min_seconds_limit < now # 1 words at least should_create_memory_time_words = min_words_limit == 0 @@ -602,7 +601,7 @@ async def _try_flush_new_memory(time_validate: bool = True): should_create_memory = should_create_memory_time and should_create_memory_time_words print( f"Should create memory {should_create_memory} - {timer_start} {segment_end} {min_seconds_limit} {now} - {time_validate}, session {session_id}") if should_create_memory: memory = await _create_memory() if not memory: From 35027e28239f199e03ff0eec6b61594be0197534 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Sun, 22 Sep 2024 17:35:08 +0700 Subject: [PATCH 76/88] Update finished at for combined memory --- backend/database/memories.py | 5 +++++ backend/routers/transcribe.py | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/backend/database/memories.py b/backend/database/memories.py index 299cb6eb2..7bc30f07e 100644 --- a/backend/database/memories.py +++ b/backend/database/memories.py @@ -138,6 +138,11 @@ def update_memory_events(uid: str, memory_id: str, events: List[dict]): memory_ref = user_ref.collection('memories').document(memory_id) memory_ref.update({'structured.events': events}) +def update_memory_finished_at(uid: str, memory_id: str, 
finished_at: datetime): + user_ref = db.collection('users').document(uid) + memory_ref = user_ref.collection('memories').document(memory_id) + memory_ref.update({'finished_at': finished_at}) + # VISBILITY diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index a1ac62bfa..49a9463ad 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -524,6 +524,10 @@ async def _create_memory(): memories_db.update_memory_segments(uid, memory.id, [segment.dict() for segment in memory.transcript_segments]) + # Update finished at + memory.finished_at = datetime.fromtimestamp(memory.started_at.timestamp() + processing_memory.transcript_segments[-1].end, timezone.utc) + memories_db.update_memory_finished_at(uid, memory.id, memory.finished_at) + # Process memory = process_memory(uid, memory.language, memory, force_process=True) From 461d062bf1abc688ab921331f9d7fab2a368290e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Sun, 22 Sep 2024 17:59:18 +0700 Subject: [PATCH 77/88] Add silent seconds between 2 combining audio --- backend/routers/transcribe.py | 3 ++- backend/utils/audio.py | 8 +++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 49a9463ad..26a9267f3 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -428,7 +428,8 @@ async def _post_process_memory(memory: Memory): # merge merge_file_path = f"_temp/{memory.id}_{uuid.uuid4()}_be" - merge_wav_files(merge_file_path, [previous_file_path, file_path]) + nearest_timer_start = processing_memory.timer_starts[-2] + merge_wav_files(merge_file_path, [previous_file_path, file_path], [math.ceil(timer_start-nearest_timer_start), 0]) # clean os.remove(previous_file_path) diff --git a/backend/utils/audio.py b/backend/utils/audio.py index 069a2ebdd..bf4663c73 100644 --- a/backend/utils/audio.py +++ b/backend/utils/audio.py @@ -3,14 +3,16 @@ from pyogg import OpusDecoder from pydub import AudioSegment -def merge_wav_files(dest_file_path: str, source_files: [str]): +def merge_wav_files(dest_file_path: str, source_files: [str], silent_seconds: [int]): if len(source_files) == 0 or not dest_file_path: return combined_sounds = AudioSegment.empty() - for file_path in source_files: + for i in range(len(source_files)): + file_path = source_files[i] sound = AudioSegment.from_wav(file_path) - combined_sounds = combined_sounds + sound + silent_sec = silent_seconds[i] + combined_sounds = combined_sounds + sound + AudioSegment.silent(duration=silent_sec) combined_sounds.export(dest_file_path, format="wav") From 632403c2682bc8b0cc5f7f54f4f1333d3807de41 Mon Sep 17 00:00:00 2001 From: Joan Cabezas Date: Sun, 22 Sep 2024 11:08:04 -0700 Subject: [PATCH 78/88] modal reduced time --- backend/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/main.py b/backend/main.py index 1b2362dcd..ecbc421ff 100644 --- a/backend/main.py +++ b/backend/main.py @@ -54,7 +54,7 @@ memory=(512, 1024), cpu=2, allow_concurrent_inputs=10, - timeout=60 * 35, + timeout=60 * 5, ) @asgi_app() def api(): From bff33877a3c1eb53c503a76de1297a81a0b720c9 Mon Sep 17 00:00:00 2001 From: Joan Cabezas Date: Sun, 22 Sep 2024 11:31:24 -0700 Subject: [PATCH 79/88] script WER comparison --- .../stt/k_compare_transcripts_performance.py | 443 ++++++++++++++---- 1 file changed, 342 insertions(+), 101 deletions(-) diff --git a/backend/scripts/stt/k_compare_transcripts_performance.py 
b/backend/scripts/stt/k_compare_transcripts_performance.py index 5ffeabb35..b65d0f074 100644 --- a/backend/scripts/stt/k_compare_transcripts_performance.py +++ b/backend/scripts/stt/k_compare_transcripts_performance.py @@ -11,11 +11,15 @@ # - Include speechmatics to the game import json import os +import re +from collections import defaultdict +from itertools import islice from typing import Dict, List import firebase_admin from dotenv import load_dotenv from pydub import AudioSegment +from tabulate import tabulate load_dotenv('../../.dev.env') os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '../../' + os.getenv('GOOGLE_APPLICATION_CREDENTIALS') @@ -29,23 +33,6 @@ from utils.stt.pre_recorded import fal_whisperx, fal_postprocessing -def store_model_result(memory_id: str, model: str, result: List[Dict]): - file_path = 'results.json' - if os.path.exists(file_path): - with open(file_path, 'r') as f: - results = json.load(f) - else: - results = {} - - if memory_id not in results: - results[memory_id] = {} - - results[memory_id][model] = result - # save it - with open(file_path, 'w') as f: - json.dump(results, f) - - def add_model_result_segments(model: str, new_segments: List[Dict], result: Dict): segments = [TranscriptSegment(**s) for s in result[model]] new_segments = [TranscriptSegment(**s) for s in new_segments] @@ -54,16 +41,118 @@ def add_model_result_segments(model: str, new_segments: List[Dict], result: Dict def execute_groq(file_path: str): + file_size = os.path.getsize(file_path) + print('execute_groq file_size', file_size / 1024 / 1024, 'MB') + split_files = [] + if file_size / 1024 / 1024 > 25: + # split file + aseg = AudioSegment.from_wav(file_path) + # split every 10 minutes + split_duration = 10 * 60 * 1000 + for i in range(0, len(aseg), split_duration): + split_file_path = f'{file_path}_{i}.wav' + split_files.append(split_file_path) + aseg[i:i + split_duration].export(split_file_path, format="wav") + else: + split_files.append(file_path) + client = Groq(api_key=os.getenv('GROQ_API_KEY')) + result = '' + for file_path in split_files: + with open(file_path, "rb") as file: + transcription = client.audio.transcriptions.create( + file=(file_path, file.read()), + model="whisper-large-v3", + response_format="verbose_json", + language="en", + temperature=0.0 + ) + result += ' ' + str(transcription) + return result.strip().lower().replace(' ', ' ') + + +async def _execute_single(file_path: str): + aseg = AudioSegment.from_wav(file_path) + duration = aseg.duration_seconds + memory_id = file_path.split('/')[-1].split('.')[0] + + if os.path.exists(f'results/{memory_id}.json'): + print('Already processed', memory_id) + return + if aseg.duration_seconds < 5: + print('Skipping', memory_id, 'duration', aseg.duration_seconds) + return + + print('Started processing', memory_id, 'duration', aseg.duration_seconds) + result = { + 'deepgram': [], + 'soniox': [], + 'speechmatics': [] + } + + def stream_transcript_deepgram(new_segments, _): + print('stream_transcript_deepgram', new_segments) + add_model_result_segments('deepgram', new_segments, result) + + def stream_transcript_soniox(new_segments, _): + print('stream_transcript_soniox', new_segments) + add_model_result_segments('soniox', new_segments, result) + + def stream_transcript_speechmatics(new_segments, _): + print('stream_transcript_speechmatics', new_segments) + add_model_result_segments('speechmatics', new_segments, result) + + # streaming models + socket = await process_audio_dg(stream_transcript_deepgram, '1', 'en', 16000, 'pcm16', 1, 0) + 
socket_soniox = await process_audio_soniox(stream_transcript_soniox, '1', 'en', None) + socket_speechmatics = await process_audio_speechmatics(stream_transcript_speechmatics, '1', 'en', 0) + print('duration', duration) with open(file_path, "rb") as file: - transcription = client.audio.transcriptions.create( - file=(file_path, file.read()), - model="whisper-large-v3", - response_format="text", - language="en", - temperature=0.0 - ) - return str(transcription) + while True: + chunk = file.read(320) + if not chunk: + break + socket.send(bytes(chunk)) + await socket_soniox.send(bytes(chunk)) + await socket_speechmatics.send(bytes(chunk)) + await asyncio.sleep(0.005) + + print('Finished sending audio') + groq_result: str = execute_groq(file_path) # source of truth + result['whisper-large-v3'] = groq_result + + # whisperx + try: + signed_url = upload_postprocessing_audio(file_path) + words = fal_whisperx(signed_url) + fal_segments = fal_postprocessing(words, duration) + result['fal_whisperx'] = [s.dict() for s in fal_segments] + except Exception as e: + print('fal_whisperx', e) + result['fal_whisperx'] = [] + + print('Waiting for sockets to finish', min(60, duration), 'seconds') + await asyncio.sleep(min(30, duration)) + + os.makedirs('results', exist_ok=True) + with open(f'results/{memory_id}.json', 'w') as f: + json.dump(result, f, indent=2) + + socket.finish() + await socket_soniox.close() + await socket_speechmatics.close() + + +def batched(iterable, n): + """ + Generator that yields lists of size 'n' from 'iterable'. + """ + it = iter(iterable) + while True: + batch = list(islice(it, n)) + if not batch: + break + yield batch async def process_memories_audio_files(): @@ -71,93 +160,245 @@ async def process_memories_audio_files(): for uid in uids: memories = os.listdir(f'_temp2/{uid}') memories = [f'_temp2/{uid}/{memory}' for memory in memories] - # memories_id = [] - # for file_path in memories: - # if AudioSegment.from_wav(file_path).frame_rate != 16000: - # continue - # memory_id = file_path.split('.')[0] - # memories_id.append(memory_id) - - # memories_data = get_memories_by_id(uid, memories_id) - for file_path in memories: - aseg = AudioSegment.from_wav(file_path) - memory_id = file_path.split('/')[-1].split('.')[0] - - if os.path.exists(f'results/{memory_id}.json'): - print('Already processed', memory_id) - continue - print('Started processing', memory_id, 'duration', aseg.duration_seconds) - result = { - 'deepgram': [], - 'soniox': [], - 'speechmatics': [] - } - - def stream_transcript_deepgram(new_segments, _): - print('stream_transcript_deepgram', new_segments) - add_model_result_segments('deepgram', new_segments, result) - - def stream_transcript_soniox(new_segments, _): - print('stream_transcript_soniox', new_segments) - add_model_result_segments('soniox', new_segments, result) - - def stream_transcript_speechmatics(new_segments, _): - print('stream_transcript_speechmatics', new_segments) - add_model_result_segments('speechmatics', new_segments, result) - - groq_result: str = execute_groq(file_path) # source of truth - result['whisper-large-v3'] = groq_result - - # whisperx - signed_url = upload_postprocessing_audio(file_path) - words = fal_whisperx(signed_url) - fal_segments = fal_postprocessing(words, aseg.duration_seconds) - result['fal_whisperx'] = [s.dict() for s in fal_segments] - - # streaming models - socket = await process_audio_dg(stream_transcript_deepgram, '1', 'en', 16000, 'pcm16', 1, 0) - socket_soniox = await process_audio_soniox(stream_transcript_soniox, '1', 'en', 
None) - socket_speechmatics = await process_audio_speechmatics(stream_transcript_speechmatics, '1', 'en', 0) - duration = AudioSegment.from_wav(file_path).duration_seconds - print('duration', duration) - with open(file_path, "rb") as file: - while True: - chunk = file.read(320) - if not chunk: - break - socket.send(bytes(chunk)) - await socket_soniox.send(bytes(chunk)) - await socket_speechmatics.send(bytes(chunk)) - await asyncio.sleep(0.001) - await asyncio.sleep(duration) - - os.makedirs('results', exist_ok=True) - with open(f'results/{memory_id}.json', 'w') as f: - json.dump(result, f, indent=2) - break + # batch_size = 5 + for memory in memories: + await _execute_single(memory) + # for batch_num, batch in enumerate(batched(memories, batch_size), start=1): + # tasks = [asyncio.create_task(_execute_single(file_path)) for file_path in batch] + # await asyncio.gather(*tasks) + # print(f'Batch {batch_num} processed') from jiwer import wer def compute_wer(): - dir = 'results/' - for file in os.listdir(dir): + """ + Computes the Word Error Rate (WER) for each transcription model against a reference model + across all JSON files in the specified directory. Outputs detailed results and overall rankings. + """ + dir_path = 'results/' # Directory containing JSON files + reference_model = 'whisper-large-v3' # Reference model key + table_data = [] # List to hold detailed table rows + wer_accumulator = defaultdict(list) # To accumulate WERs per model + points_counter = defaultdict(int) # To count points per model based on WER rankings + + # Define detailed table headers + detailed_headers = [ + "File", + "Model", + "WER", + "Source Words", + "Model Words", + "Source Characters", + "Model Characters", + "Transcript" + ] + + # Check if the directory exists + if not os.path.isdir(dir_path): + print(f"Directory '{dir_path}' does not exist.") + return + + # Iterate through all files in the specified directory + for file in os.listdir(dir_path): if not file.endswith('.json'): - continue - with open(f'{dir}{file}', 'r') as f: - result = json.load(f) - source = str(result['whisper-large-v3']).strip().lower().replace(' ', ' ') - print(file) + continue # Skip non-JSON files + + file_path = os.path.join(dir_path, file) + with open(file_path, 'r', encoding='utf-8') as f: + try: + result = json.load(f) + except json.JSONDecodeError: + print(f"Error decoding JSON in file: {file}") + continue # Skip files with invalid JSON + + # Check if the reference model exists in the JSON + if reference_model not in result: + print(f"Reference model '{reference_model}' not found in file: {file}") + continue # Skip files without the reference model + + # Assemble the reference transcript + reference_text = regex_fix(result.get(reference_model, '')) + if isinstance(reference_text, list): + # If reference_text is a list of segments + reference_text = ' '.join([segment.get('text', '') for segment in reference_text]).strip().lower() + else: + # If reference_text is a single string + reference_text = str(reference_text).strip().lower() + reference_text = ' '.join(reference_text.split()) # Normalize whitespace + + # Calculate source words and characters + source_words = len(reference_text.split()) + source_characters = len(reference_text) + + print(f"Processing file: {file}") + + # Temporary storage for current file's model WERs to determine ranking points + current_file_wer = {} + + # Iterate through each model in the JSON for model, segments in result.items(): - if model == 'whisper-large-v3': # TODO: words vs each other - 
continue - segments_str = ' '.join([s['text'] for s in segments]).strip().lower().replace(' ', ' ') - value = wer(source, segments_str) - print(f'{model} WER: {value}') + if model == reference_model: + model_text = reference_text # Reference model's transcript + else: + if isinstance(segments, list): + # Assemble the model's transcript from segments + model_text = ' '.join([segment.get('text', '') for segment in segments]).strip().lower() + else: + # If segments is a single string + model_text = str(segments).strip().lower() + model_text = ' '.join(model_text.split()) # Normalize whitespace + + # Calculate model words and characters + model_words = len(model_text.split()) + model_characters = len(model_text) + + # Compute WER + current_wer = wer(reference_text, model_text) + + # Accumulate WER for overall statistics (exclude reference model) + if model != reference_model: + wer_accumulator[model].append(current_wer) + + # Store WER for current file's ranking + if model != reference_model: + current_file_wer[model] = current_wer + + # Append the data to the detailed table + table_data.append([ + file, + model, + f"{current_wer:.2%}", + source_words, + model_words, + source_characters, + model_characters, + model_text + ]) + + # Determine which model(s) had the lowest WER in the current file + if current_file_wer: + min_wer = min(current_file_wer.values()) + best_models = [model for model, w in current_file_wer.items() if w == min_wer] + for model in best_models: + points_counter[model] += 1 # Assign 1 point to each best model + print('-----------------------------------------') + # Generate the detailed WER table using tabulate + if table_data: + print("\nDetailed WER Results:") + detailed_table = tabulate(table_data, headers=detailed_headers, tablefmt="grid", stralign="left") + with open('results/detailed_wer.txt', 'w') as f: + f.write(detailed_table) + else: + print("No data to display.") + + # Compute overall WER per model (average) + overall_wer = {} + for model, wer_list in wer_accumulator.items(): + if wer_list: + overall_wer[model] = sum(wer_list) / len(wer_list) + + # Create a list for overall WER table + overall_wer_table = [] + for model, avg_wer in overall_wer.items(): + overall_wer_table.append([ + model, + f"{avg_wer:.2%}" + ]) + + # Sort the overall WER table by average WER ascending (lower is better) + overall_wer_table_sorted = sorted(overall_wer_table, key=lambda x: x[1]) + + # Define overall WER table headers + overall_wer_headers = ["Model", "Average WER"] + + # Generate the overall WER table + if overall_wer_table_sorted: + print("\nOverall WER per Model:") + overall_wer_formatted = tabulate(overall_wer_table_sorted, headers=overall_wer_headers, tablefmt="grid", + stralign="left") + print(overall_wer_formatted) + with open('results/wer.txt', 'w') as f: + f.write(overall_wer_formatted) + else: + print("No overall WER data to display.") + + # Create a ranking table based on points + ranking_table = [] + for model, points in points_counter.items(): + ranking_table.append([ + model, + points + ]) + + # Sort the ranking table by points descending (more points are better) + ranking_table_sorted = sorted(ranking_table, key=lambda x: x[1], reverse=True) + + # Assign rankings + ranking_table_with_rank = [] + current_rank = 1 + previous_points = None + for idx, (model, points) in enumerate(ranking_table_sorted): + if points != previous_points: + rank = current_rank + else: + rank = current_rank - 1 # Same rank as previous + ranking_table_with_rank.append([ + rank, + model, + 
points + ]) + previous_points = points + current_rank += 1 + + # Define ranking table headers + ranking_headers = ["Rank", "Model", "Points"] + + # Generate the ranking table + if ranking_table_with_rank: + print("\nModel Rankings Based on WER Performance:") + ranking_table_formatted = tabulate(ranking_table_with_rank, headers=ranking_headers, tablefmt="grid", + stralign="left") + print(ranking_table_formatted) + with open('results/ranking.txt', 'w') as f: + f.write(ranking_table_formatted) + else: + print("No ranking data to display.") + + +def regex_fix(text: str): + # Define the regular expression + pattern = r'(?<=transcription\(text=["\'])(.*?)(?=["\'],\s*task=)' + + # Search for the pattern in the data + match = re.search(pattern, text) + + # If a match is found, extract and print the text + if match: + extracted_text = match.group(0) + return extracted_text + else: + print("No match found.") + return text + if __name__ == '__main__': # asyncio.run(process_memories_audio_files()) compute_wer() + # client = Groq(api_key=os.getenv('GROQ_API_KEY')) + # file_path = '_temp2/DX8n89KAmUaG9O7Qvj8xTi81Zu12/0bce5547-675b-4dea-b9fe-cfb69740100b.wav' + + # with open(file_path, "rb") as file: + # transcription = client.audio.transcriptions.create( + # file=(file_path, file.read()), + # model="whisper-large-v3", + # response_format="verbose_json", + # language="en", + # temperature=0.0 + # ) + # print(transcription) + # for segment in transcription.segments: + # print(segment['start'], segment['end'], segment['text']) From 1e65fd9098224c9d49156c1d92d3539b12d61b5e Mon Sep 17 00:00:00 2001 From: Joan Cabezas Date: Sun, 22 Sep 2024 14:17:58 -0700 Subject: [PATCH 80/88] gitignore update + Diarization error rate compute --- .gitignore | 7 +- backend/main.py | 32 ++- .../stt/k_compare_transcripts_performance.py | 219 ++++++++++++++++-- 3 files changed, 233 insertions(+), 25 deletions(-) diff --git a/.gitignore b/.gitignore index e49c2b381..0d1cd6f1e 100644 --- a/.gitignore +++ b/.gitignore @@ -153,4 +153,9 @@ app/lib/env/prod_env.g.dart /backend/scripts/research/*.md /backend/scripts/research/*.json /backend/scripts/research/*.csv -/backend/scripts/research/users \ No newline at end of file +/backend/scripts/research/users +/backend/scripts/stt/_temp +/backend/scripts/stt/_temp2 +/backend/scripts/stt/pretrained_models +/backend/scripts/stt/results +/backend/scripts/stt/diarization.json diff --git a/backend/main.py b/backend/main.py index ecbc421ff..f5f2c57f0 100644 --- a/backend/main.py +++ b/backend/main.py @@ -6,8 +6,7 @@ from modal import Image, App, asgi_app, Secret, Cron from routers import workflow, chat, firmware, plugins, memories, transcribe, notifications, speech_profile, \ - agents, facts, users, postprocessing, processing_memories, trends,sdcard - + agents, facts, users, postprocessing, processing_memories, trends, sdcard from utils.other.notifications import start_cron_job if os.environ.get('SERVICE_ACCOUNT_JSON'): @@ -54,7 +53,7 @@ memory=(512, 1024), cpu=2, allow_concurrent_inputs=10, - timeout=60 * 5, + timeout=60 * 19, ) @asgi_app() def api(): @@ -70,3 +69,30 @@ def api(): @modal_app.function(image=image, schedule=Cron('* * * * *')) async def notifications_cronjob(): await start_cron_job() + + +@app.post('/webhook') +async def webhook(data: dict): + diarization = data['output']['diarization'] + joined = [] + for speaker in diarization: + if not joined: + joined.append(speaker) + else: + if speaker['speaker'] == joined[-1]['speaker']: + joined[-1]['end'] = speaker['end'] + else: + 
joined.append(speaker) + + print(data['jobId'], json.dumps(joined)) + # openn scripts/stt/diarization.json, get jobId=memoryId, delete but get memoryId, and save memoryId=joined + with open('scripts/stt/diarization.json', 'r') as f: + diarization_data = json.loads(f.read()) + + memory_id = diarization_data.get(data['jobId']) + if memory_id: + diarization_data[memory_id] = joined + del diarization_data[data['jobId']] + with open('scripts/stt/diarization.json', 'w') as f: + json.dump(diarization_data, f, indent=2) + return 'ok' diff --git a/backend/scripts/stt/k_compare_transcripts_performance.py b/backend/scripts/stt/k_compare_transcripts_performance.py index b65d0f074..3f193cc97 100644 --- a/backend/scripts/stt/k_compare_transcripts_performance.py +++ b/backend/scripts/stt/k_compare_transcripts_performance.py @@ -17,6 +17,7 @@ from typing import Dict, List import firebase_admin +import requests from dotenv import load_dotenv from pydub import AudioSegment from tabulate import tabulate @@ -63,7 +64,7 @@ def execute_groq(file_path: str): transcription = client.audio.transcriptions.create( file=(file_path, file.read()), model="whisper-large-v3", - response_format="verbose_json", + response_format="text", language="en", temperature=0.0 ) @@ -370,13 +371,9 @@ def compute_wer(): def regex_fix(text: str): - # Define the regular expression + """Fix some of the stored JSON in results/$id.json from the Groq API.""" pattern = r'(?<=transcription\(text=["\'])(.*?)(?=["\'],\s*task=)' - - # Search for the pattern in the data match = re.search(pattern, text) - - # If a match is found, extract and print the text if match: extracted_text = match.group(0) return extracted_text @@ -385,20 +382,200 @@ def regex_fix(text: str): return text +def pyannote_diarize(file_path: str): + memory_id = file_path.split('/')[-1].split('.')[0] + with open('diarization.json', 'r') as f: + results = json.loads(f.read()) + + if memory_id in results: + print('Already diarized', memory_id) + return + + url = "https://api.pyannote.ai/v1/diarize" + headers = {"Authorization": f"Bearer {os.getenv('PYANNOTE_API_KEY')}"} + webhook = 'https://camel-lucky-reliably.ngrok-free.app/webhook' + signed_url = upload_postprocessing_audio(file_path) + data = {'webhook': webhook, 'url': signed_url} + response = requests.post(url, headers=headers, json=data) + print(memory_id, response.json()['jobId']) + # update diarization.json, and set jobId=memoryId + with open('diarization.json', 'r') as f: + diarization = json.loads(f.read()) + + diarization[response.json()['jobId']] = memory_id + with open('diarization.json', 'w') as f: + json.dump(diarization, f, indent=2) + + +def generate_diarizations(): + uids = os.listdir('_temp2') + for uid in uids: + memories = os.listdir(f'_temp2/{uid}') + memories = [f'_temp2/{uid}/{memory}' for memory in memories] + for memory in memories: + memory_id = memory.split('/')[-1].split('.')[0] + if os.path.exists(f'results/{memory_id}.json'): + pyannote_diarize(memory) + else: + print('Skipping', memory_id) + + +from pyannote.metrics.diarization import DiarizationErrorRate +from pyannote.core import Annotation, Segment + +der_metric = DiarizationErrorRate() + + +def compute_der(): + """ + Computes the Diarization Error Rate (DER) for each model across all JSON files in the 'results/' directory. + Outputs a summary table and rankings to 'der_report.txt'. 
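+    Reference speaker turns are read from 'diarization.json' (the pyannote.ai webhook output keyed by
+    memory id); hypothesis turns come from each model's segments in the per-memory result files.
+    DER follows the standard pyannote.metrics definition:
+        DER = (false alarm + missed detection + speaker confusion) / total reference speech duration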
+ """ + dir_path = 'results/' # Directory containing result JSON files and 'diarization.json' + output_file = os.path.join(dir_path, 'der_report.txt') # Output report file + excluded_model = 'whisper-large-v3' # Model to exclude from analysis + + # Initialize DER metric + der_metric = DiarizationErrorRate() + + # Check if the directory exists + if not os.path.isdir(dir_path): + print(f"Directory '{dir_path}' does not exist.") + return + + # Path to 'diarization.json' + diarization_path = 'diarization.json' + + # Load reference diarization data + with open(diarization_path, 'r', encoding='utf-8') as f: + try: + diarization = json.load(f) + except json.JSONDecodeError: + print(f"Error decoding JSON in 'diarization.json'.") + return + + # Prepare to collect DER results + der_results = [] # List to store [Memory ID, Model, DER] + model_der_accumulator = defaultdict(list) # To calculate average DER per model + + # Iterate through all JSON files in 'results/' directory + for file in os.listdir(dir_path): + if not file.endswith('.json') or file == 'diarization.json': + continue # Skip non-JSON files and 'diarization.json' itself + + memory_id = file.split('.')[0] # Extract memory ID from filename + + # Check if memory_id exists in 'diarization.json' + if memory_id not in diarization: + print(f"Memory ID '{memory_id}' not found in 'diarization.json'. Skipping file: {file}") + continue + + # Load reference segments for the current memory_id + ref_segments = diarization[memory_id] + ref_annotation = Annotation() + for seg in ref_segments: + speaker, start, end = seg['speaker'], seg['start'], seg['end'] + ref_annotation[Segment(start, end)] = speaker + + # Load hypothesis segments from the result JSON file + file_path = os.path.join(dir_path, file) + with open(file_path, 'r', encoding='utf-8') as f: + try: + data = json.load(f) + except json.JSONDecodeError: + print(f"Error decoding JSON in file: {file}. Skipping.") + continue + + # Iterate through each model's segments in the result + for model, segments in data.items(): + if model == excluded_model: + continue # Skip the excluded model + + hyp_annotation = Annotation() + for seg in segments: + speaker, start, end = seg['speaker'], seg['start'], seg['end'] + # Optional: Normalize speaker labels if necessary + if speaker == 'SPEAKER_0': + speaker = 'SPEAKER_00' + elif speaker == 'SPEAKER_1': + speaker = 'SPEAKER_01' + elif speaker == 'SPEAKER_2': + speaker = 'SPEAKER_02' + elif speaker == 'SPEAKER_3': + speaker = 'SPEAKER_03' + hyp_annotation[Segment(start, end)] = speaker + + # Compute DER between reference and hypothesis + der = der_metric(ref_annotation, hyp_annotation) + + # Store the result + der_results.append([memory_id, model, f"{der:.2%}"]) + model_der_accumulator[model].append(der) + + # Generate the detailed DER table + der_table = tabulate(der_results, headers=["Memory ID", "Model", "DER"], tablefmt="grid", stralign="left") + + # Calculate average DER per model + average_der = [] + for model, ders in model_der_accumulator.items(): + avg = sum(ders) / len(ders) + average_der.append([model, f"{avg:.2%}"]) + + # Sort models by average DER ascending (lower is better) + average_der_sorted = sorted(average_der, key=lambda x: float(x[1].strip('%'))) + + # Determine the winner (model with the lowest average DER) + winner = average_der_sorted[0][0] if average_der_sorted else "N/A" + + # Prepare rankings (1st, 2nd, etc.) 
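+    # Ties (equal average DER) are meant to share a rank; lower average DER always ranks first.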
+ rankings = [] + rank = 1 + previous_der = None + for model, avg in average_der_sorted: + current_der = float(avg.strip('%')) + if previous_der is None or current_der < previous_der: + current_rank = rank + else: + current_rank = rank - 1 # Same rank as previous if DER is equal + rankings.append([current_rank, model, avg]) + previous_der = current_der + rank += 1 + + # Generate the rankings table + ranking_table = tabulate(rankings, headers=["Rank", "Model", "Average DER"], tablefmt="grid", stralign="left") + + # Write all results to the output file + with open(output_file, 'w', encoding='utf-8') as out_f: + out_f.write("Diarization Error Rate (DER) Analysis Report\n") + out_f.write("=" * 50 + "\n\n") + out_f.write("Detailed DER Results:\n") + out_f.write(der_table + "\n\n") + out_f.write("Average DER per Model:\n") + out_f.write( + tabulate(average_der_sorted, headers=["Model", "Average DER"], tablefmt="grid", stralign="left") + "\n\n") + out_f.write("Model Rankings Based on Average DER:\n") + out_f.write(ranking_table + "\n\n") + out_f.write(f"Winner: {winner}\n") + + # Print a confirmation message + print(f"Diarization Error Rate (DER) analysis completed. Report saved to '{output_file}'.") + + # Optionally, print the tables to the console as well + if der_results: + print("\nDetailed DER Results:") + print(der_table) + if average_der_sorted: + print("\nAverage DER per Model:") + print(tabulate(average_der_sorted, headers=["Model", "Average DER"], tablefmt="grid", stralign="left")) + if rankings: + print("\nModel Rankings Based on Average DER:") + print(ranking_table) + print(f"\nWinner: {winner}") + + if __name__ == '__main__': # asyncio.run(process_memories_audio_files()) - compute_wer() - # client = Groq(api_key=os.getenv('GROQ_API_KEY')) - # file_path = '_temp2/DX8n89KAmUaG9O7Qvj8xTi81Zu12/0bce5547-675b-4dea-b9fe-cfb69740100b.wav' - - # with open(file_path, "rb") as file: - # transcription = client.audio.transcriptions.create( - # file=(file_path, file.read()), - # model="whisper-large-v3", - # response_format="verbose_json", - # language="en", - # temperature=0.0 - # ) - # print(transcription) - # for segment in transcription.segments: - # print(segment['start'], segment['end'], segment['text']) + # generate_diarizations() + + # compute_wer() + compute_der() From 0f704e6efcad49d30391ae68b034928b4291f436 Mon Sep 17 00:00:00 2001 From: Joan Cabezas Date: Sun, 22 Sep 2024 15:04:07 -0700 Subject: [PATCH 81/88] speechmatics defaulted to deepgram --- backend/routers/postprocessing.py | 2 + backend/routers/transcribe.py | 45 +++++++++---------- .../stt/k_compare_transcripts_performance.py | 4 +- backend/utils/stt/streaming.py | 31 +++++-------- 4 files changed, 36 insertions(+), 46 deletions(-) diff --git a/backend/routers/postprocessing.py b/backend/routers/postprocessing.py index ef236fc07..55eac26e7 100644 --- a/backend/routers/postprocessing.py +++ b/backend/routers/postprocessing.py @@ -27,6 +27,8 @@ def postprocess_memory( TODO: post llm process here would be great, sometimes whisper x outputs without punctuation """ + # TODO: this pipeline vs groq+pyannote diarization 3.1, probably the latter is better. 
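+    # Rough sketch of that alternative (not wired up here; the helpers live in
+    # backend/scripts/stt/k_compare_transcripts_performance.py and are not imported in this router):
+    #   execute_groq(file_path)       # whisper-large-v3 via Groq for the raw transcript
+    #   pyannote_diarize(file_path)   # pyannote.ai job; speaker turns come back through the /webhook route
+    #   then merge the transcript with the speaker turns to rebuild the memory's TranscriptSegments.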
+ # Save file file_path = f"_temp/{memory_id}_{file.filename}" with open(file_path, 'wb') as f: diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 26a9267f3..bd6bd9c51 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -1,8 +1,5 @@ -import threading import math -import asyncio -import time -from typing import List +import threading import uuid from datetime import datetime, timezone from enum import Enum @@ -16,7 +13,6 @@ import database.memories as memories_db import database.processing_memories as processing_memories_db -from database.redis_db import get_user_speech_profile from models.memory import Memory, TranscriptSegment from models.message_event import NewMemoryCreated, MessageEvent, NewProcessingMemoryCreated from models.processing_memory import ProcessingMemory @@ -96,15 +92,12 @@ async def _websocket_util( print('websocket_endpoint', uid, language, sample_rate, codec, channels, include_speech_profile, new_memory_watch, stt_service) - if stt_service == STTService.soniox and ( - sample_rate != 16000 or codec != 'opus' or language not in soniox_valid_languages): - stt_service = STTService.deepgram - if stt_service == STTService.speechmatics and (sample_rate != 16000 or codec != 'opus'): - stt_service = STTService.deepgram + if stt_service == STTService.soniox and language not in soniox_valid_languages: + stt_service = STTService.deepgram # defaults to deepgram - # At some point try running all the models together to easily compare + if stt_service == STTService.speechmatics: # defaults to deepgram (no credits + 10 connections max limit) + stt_service = STTService.deepgram - # Check: Why do we need try-catch around websocket.accept? try: await websocket.accept() except RuntimeError as e: @@ -203,30 +196,32 @@ def stream_audio(audio_buffer): duration = 0 try: file_path, duration = None, 0 + # TODO: how bee does for recognizing other languages speech profile if language == 'en' and (codec == 'opus' or codec == 'pcm16') and include_speech_profile: file_path = get_profile_audio_if_exists(uid) duration = AudioSegment.from_wav(file_path).duration_seconds + 5 if file_path else 0 - # Deepgram + # DEEPGRAM if stt_service == STTService.deepgram: - deepgram_codec_value = 'pcm16' if codec == 'opus' else codec deepgram_socket = await process_audio_dg( - stream_transcript, memory_stream_id, language, sample_rate, deepgram_codec_value, channels, - preseconds=duration + stream_transcript, memory_stream_id, language, sample_rate, channels, preseconds=duration ) if duration: deepgram_socket2 = await process_audio_dg( - stream_transcript, speech_profile_stream_id, language, sample_rate, deepgram_codec_value, channels + stream_transcript, speech_profile_stream_id, language, sample_rate, channels ) await send_initial_file_path(file_path, deepgram_socket) + # SONIOX elif stt_service == STTService.soniox: soniox_socket = await process_audio_soniox( - stream_transcript, speech_profile_stream_id, language, uid if include_speech_profile else None + stream_transcript, speech_profile_stream_id, sample_rate, language, + uid if include_speech_profile else None ) + # SPEECHMATICS elif stt_service == STTService.speechmatics: speechmatics_socket = await process_audio_speechmatics( - stream_transcript, speech_profile_stream_id, language, preseconds=duration + stream_transcript, speech_profile_stream_id, sample_rate, language, preseconds=duration ) if duration: await send_initial_file_path(file_path, speechmatics_socket) @@ -366,7 +361,7 @@ async def 
_create_processing_memory(): id=str(uuid.uuid4()), created_at=datetime.now(timezone.utc), timer_start=timer_start, - timer_segment_start=timer_start+segment_start, + timer_segment_start=timer_start + segment_start, language=language, ) @@ -429,7 +424,8 @@ async def _post_process_memory(memory: Memory): # merge merge_file_path = f"_temp/{memory.id}_{uuid.uuid4()}_be" nearest_timer_start = processing_memory.timer_starts[-2] - merge_wav_files(merge_file_path, [previous_file_path, file_path], [math.ceil(timer_start-nearest_timer_start), 0]) + merge_wav_files(merge_file_path, [previous_file_path, file_path], + [math.ceil(timer_start - nearest_timer_start), 0]) # clean os.remove(previous_file_path) @@ -498,8 +494,8 @@ async def _create_memory(): memory = None messages = [] if not processing_memory.memory_id: - (new_memory, new_messages, updated_processing_memory) = await create_memory_by_processing_memory(uid, - processing_memory.id) + new_memory, new_messages, updated_processing_memory = await create_memory_by_processing_memory( + uid, processing_memory.id) if not new_memory: print("Can not create new memory") @@ -526,7 +522,8 @@ async def _create_memory(): [segment.dict() for segment in memory.transcript_segments]) # Update finished at - memory.finished_at = datetime.fromtimestamp(memory.started_at.timestamp() + processing_memory.transcript_segments[-1].end, timezone.utc) + memory.finished_at = datetime.fromtimestamp( + memory.started_at.timestamp() + processing_memory.transcript_segments[-1].end, timezone.utc) memories_db.update_memory_finished_at(uid, memory.id, memory.finished_at) # Process diff --git a/backend/scripts/stt/k_compare_transcripts_performance.py b/backend/scripts/stt/k_compare_transcripts_performance.py index 3f193cc97..61c06888a 100644 --- a/backend/scripts/stt/k_compare_transcripts_performance.py +++ b/backend/scripts/stt/k_compare_transcripts_performance.py @@ -105,8 +105,8 @@ def stream_transcript_speechmatics(new_segments, _): # streaming models socket = await process_audio_dg(stream_transcript_deepgram, '1', 'en', 16000, 'pcm16', 1, 0) - socket_soniox = await process_audio_soniox(stream_transcript_soniox, '1', 'en', None) - socket_speechmatics = await process_audio_speechmatics(stream_transcript_speechmatics, '1', 'en', 0) + socket_soniox = await process_audio_soniox(stream_transcript_soniox, '1', 16000, 'en', None) + socket_speechmatics = await process_audio_speechmatics(stream_transcript_speechmatics, '1', 16000, 'en', 0) print('duration', duration) with open(file_path, "rb") as file: while True: diff --git a/backend/utils/stt/streaming.py b/backend/utils/stt/streaming.py index e111c825c..be14b8c82 100644 --- a/backend/utils/stt/streaming.py +++ b/backend/utils/stt/streaming.py @@ -94,10 +94,10 @@ async def send_initial_file(data: List[List[int]], transcript_socket): async def process_audio_dg( - stream_transcript, stream_id: int, language: str, sample_rate: int, codec: str, channels: int, + stream_transcript, stream_id: int, language: str, sample_rate: int, channels: int, preseconds: int = 0, ): - print('process_audio_dg', language, sample_rate, codec, channels, preseconds) + print('process_audio_dg', language, sample_rate, channels, preseconds) def on_message(self, result, **kwargs): # print(f"Received message from Deepgram") # Log when message is received @@ -143,7 +143,7 @@ def on_error(self, error, **kwargs): print(f"Error: {error}") print("Connecting to Deepgram") # Log before connection attempt - return connect_to_deepgram(on_message, on_error, language, 
sample_rate, codec, channels) + return connect_to_deepgram(on_message, on_error, language, sample_rate, channels) def process_segments(uid: str, segments: list[dict]): @@ -151,7 +151,7 @@ def process_segments(uid: str, segments: list[dict]): trigger_realtime_integrations(uid, token, segments) -def connect_to_deepgram(on_message, on_error, language: str, sample_rate: int, codec: str, channels: int): +def connect_to_deepgram(on_message, on_error, language: str, sample_rate: int, channels: int): # 'wss://api.deepgram.com/v1/listen?encoding=linear16&sample_rate=8000&language=$recordingsLanguage&model=nova-2-general&no_delay=true&endpointing=100&interim_results=false&smart_format=true&diarize=true' try: dg_connection = deepgram.listen.live.v("1") @@ -171,7 +171,7 @@ def connect_to_deepgram(on_message, on_error, language: str, sample_rate: int, c multichannel=channels > 1, model='nova-2-general', sample_rate=sample_rate, - encoding='linear16' if codec == 'pcm8' or codec == 'pcm16' else 'opus' + encoding='linear16' ) result = dg_connection.start(options) print('Deepgram connection started:', result) @@ -183,10 +183,7 @@ def connect_to_deepgram(on_message, on_error, language: str, sample_rate: int, c soniox_valid_languages = ['en'] -# soniox_valid_languages = ['en', 'es', 'fr', 'ko', 'zh', 'it', 'pt', 'de'] - - -async def process_audio_soniox(stream_transcript, stream_id: int, language: str, uid: str): +async def process_audio_soniox(stream_transcript, stream_id: int, sample_rate: int, language: str, uid: str): # Fuck, soniox doesn't even support diarization in languages != english api_key = os.getenv('SONIOX_API_KEY') if not api_key: @@ -198,12 +195,12 @@ async def process_audio_soniox(stream_transcript, stream_id: int, language: str, if language not in soniox_valid_languages: raise ValueError(f"Unsupported language '{language}'. 
Supported languages are: {soniox_valid_languages}") - has_speech_profile = create_user_speech_profile(uid) if uid else False # only english too + has_speech_profile = create_user_speech_profile(uid) if uid and sample_rate == 16000 else False # only english too # Construct the initial request with all required and optional parameters request = { 'api_key': api_key, - 'sample_rate_hertz': 16000, + 'sample_rate_hertz': sample_rate, 'include_nonfinal': True, 'enable_endpoint_detection': True, 'enable_streaming_speaker_diarization': True, @@ -300,12 +297,9 @@ async def on_message(): CONNECTION_URL = f"wss://eu2.rt.speechmatics.com/v2" -async def process_audio_speechmatics(stream_transcript, stream_id: int, language: str, preseconds: int = 0): - # Create a transcription client +async def process_audio_speechmatics(stream_transcript, stream_id: int, sample_rate: int, language: str, preseconds: int = 0): api_key = os.getenv('SPEECHMATICS_API_KEY') uri = 'wss://eu2.rt.speechmatics.com/v2' - # Validate the language and construct the model name - # has_speech_profile = create_user_speech_profile(uid) # only english too request = { "message": "StartRecognition", @@ -319,7 +313,7 @@ async def process_audio_speechmatics(stream_transcript, stream_id: int, language "enable_entities": True, "speaker_diarization_config": {"max_speakers": 4} }, - "audio_format": {"type": "raw", "encoding": "pcm_s16le", "sample_rate": 16000}, + "audio_format": {"type": "raw", "encoding": "pcm_s16le", "sample_rate": sample_rate}, # "audio_events_config": { # "types": [ # "laughter", @@ -329,16 +323,13 @@ async def process_audio_speechmatics(stream_transcript, stream_id: int, language # } } try: - # Connect to Soniox WebSocket print("Connecting to Speechmatics WebSocket...") socket = await websockets.connect(uri, extra_headers={"Authorization": f"Bearer {api_key}"}) print("Connected to Speechmatics WebSocket.") - # Send the initial request await socket.send(json.dumps(request)) print(f"Sent initial request: {request}") - # Start listening for messages from Soniox async def on_message(): try: async for message in socket: @@ -370,7 +361,7 @@ async def on_message(): is_user = True if r_speaker == '1' and preseconds > 0 else False if r_start < preseconds: - print('Skipping word', r_start, r_content) + # print('Skipping word', r_start, r_content) continue # print(r_content, r_speaker, [r_start, r_end]) if not segments: From 3d3538e438fb075d5c3636d0510567424ee7fc01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Mon, 23 Sep 2024 06:07:25 +0700 Subject: [PATCH 82/88] Updates modal timeout to 10m --- backend/main.py | 2 +- backend/routers/transcribe.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/main.py b/backend/main.py index 1b2362dcd..ab21cdb0a 100644 --- a/backend/main.py +++ b/backend/main.py @@ -54,7 +54,7 @@ memory=(512, 1024), cpu=2, allow_concurrent_inputs=10, - timeout=60 * 35, + timeout=60 * 10, ) @asgi_app() def api(): diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 26a9267f3..b001bcf58 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -132,7 +132,7 @@ async def _websocket_util( loop = asyncio.get_event_loop() # Soft timeout, should < MODAL_TIME_OUT - 3m - timeout_seconds = 1800 # 30m + timeout_seconds = 420 # 7m started_at = time.time() def stream_transcript(segments, stream_id): From a4f8f3bdda869cc7eeafeb815c33e2e14d27a21a Mon Sep 17 00:00:00 2001 From: Joan Cabezas Date: Sun, 22 Sep 2024 16:13:44 -0700 
Subject: [PATCH 83/88] set modal container to 10 min max --- backend/main.py | 2 +- backend/routers/transcribe.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/backend/main.py b/backend/main.py index f5f2c57f0..c3fa83b06 100644 --- a/backend/main.py +++ b/backend/main.py @@ -53,7 +53,7 @@ memory=(512, 1024), cpu=2, allow_concurrent_inputs=10, - timeout=60 * 19, + timeout=60 * 10, ) @asgi_app() def api(): diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index bd6bd9c51..55b85c7a6 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -74,6 +74,8 @@ class STTService(str, Enum): soniox = "soniox" speechmatics = "speechmatics" + # auto = "auto" + @staticmethod def get_model_name(value): if value == STTService.deepgram: @@ -98,6 +100,9 @@ async def _websocket_util( if stt_service == STTService.speechmatics: # defaults to deepgram (no credits + 10 connections max limit) stt_service = STTService.deepgram + # TODO: if language english, use soniox + # TODO: else deepgram, if speechmatics credits, prob this for both? + try: await websocket.accept() except RuntimeError as e: From 5ccdf90e87d39ed02ea73cb60efd2a0a6e87584a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Mon, 23 Sep 2024 09:20:34 +0700 Subject: [PATCH 84/88] Dump logs for the deepgram issue detection --- backend/routers/transcribe.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index b001bcf58..dd238e938 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -205,6 +205,7 @@ def stream_audio(audio_buffer): file_path, duration = None, 0 if language == 'en' and (codec == 'opus' or codec == 'pcm16') and include_speech_profile: file_path = get_profile_audio_if_exists(uid) + print(f'deepgram-obns3: file_path {file_path}') duration = AudioSegment.from_wav(file_path).duration_seconds + 5 if file_path else 0 # Deepgram @@ -219,6 +220,7 @@ def stream_audio(audio_buffer): stream_transcript, speech_profile_stream_id, language, sample_rate, deepgram_codec_value, channels ) + print(f'deepgram-obns3: send_initial_file_path > deepgram_socket {deepgram_socket}') await send_initial_file_path(file_path, deepgram_socket) elif stt_service == STTService.soniox: soniox_socket = await process_audio_soniox( @@ -292,6 +294,8 @@ async def receive_audio(dg_socket1, dg_socket2, soniox_socket, speechmatics_sock print("WebSocket disconnected") except Exception as e: print(f'Could not process audio: error {e}') + print(f'deepgram-obns3: receive_audio > dg_socket1 {dg_socket1}') + print(f'deepgram-obns3: receive_audio > dg_socket2 {dg_socket2}') websocket_close_code = 1011 finally: websocket_active = False From c51616feb1a30aff3560485459c4ea379da65494 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Mon, 23 Sep 2024 10:28:05 +0700 Subject: [PATCH 85/88] Use socket send async func to wrap the various socket send function when send initial file. 
Update deepgram connect func, use .websocket instead of live --- backend/routers/transcribe.py | 6 ++++-- backend/utils/stt/streaming.py | 38 ++++++++++++++++++++++++++++++---- 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index dd238e938..f4e05bbde 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -221,7 +221,9 @@ def stream_audio(audio_buffer): ) print(f'deepgram-obns3: send_initial_file_path > deepgram_socket {deepgram_socket}') - await send_initial_file_path(file_path, deepgram_socket) + async def deepgram_socket_send(data): + return deepgram_socket.send(data) + await send_initial_file_path(file_path, deepgram_socket_send) elif stt_service == STTService.soniox: soniox_socket = await process_audio_soniox( stream_transcript, speech_profile_stream_id, language, uid if include_speech_profile else None @@ -231,7 +233,7 @@ def stream_audio(audio_buffer): stream_transcript, speech_profile_stream_id, language, preseconds=duration ) if duration: - await send_initial_file_path(file_path, speechmatics_socket) + await send_initial_file_path(file_path, speechmatics_socket.send) print('speech_profile speechmatics duration', duration) except Exception as e: diff --git a/backend/utils/stt/streaming.py b/backend/utils/stt/streaming.py index e111c825c..8911f9174 100644 --- a/backend/utils/stt/streaming.py +++ b/backend/utils/stt/streaming.py @@ -4,7 +4,7 @@ from typing import List import websockets -from deepgram import DeepgramClient, DeepgramClientOptions, LiveTranscriptionEvents +from deepgram import DeepgramClient, DeepgramClientOptions, LiveTranscriptionEvents, ListenWebSocketClient from deepgram.clients.live.v1 import LiveOptions import database.notifications as notification_db @@ -61,7 +61,7 @@ # return segments -async def send_initial_file_path(file_path: str, transcript_socket): +async def send_initial_file_path(file_path: str, transcript_socket_async_send): print('send_initial_file_path') start = time.time() # Reading and sending in chunks @@ -71,7 +71,7 @@ async def send_initial_file_path(file_path: str, transcript_socket): if not chunk: break # print('Uploading', len(chunk)) - await transcript_socket.send(bytes(chunk)) + await transcript_socket_async_send(bytes(chunk)) await asyncio.sleep(0.0001) # if it takes too long to transcribe print('send_initial_file_path', time.time() - start) @@ -154,9 +154,39 @@ def process_segments(uid: str, segments: list[dict]): def connect_to_deepgram(on_message, on_error, language: str, sample_rate: int, codec: str, channels: int): # 'wss://api.deepgram.com/v1/listen?encoding=linear16&sample_rate=8000&language=$recordingsLanguage&model=nova-2-general&no_delay=true&endpointing=100&interim_results=false&smart_format=true&diarize=true' try: - dg_connection = deepgram.listen.live.v("1") + dg_connection = deepgram.listen.websocket.v("1") dg_connection.on(LiveTranscriptionEvents.Transcript, on_message) dg_connection.on(LiveTranscriptionEvents.Error, on_error) + + def on_open(self, open, **kwargs): + print("Connection Open") + + def on_metadata(self, metadata, **kwargs): + print(f"Metadata: {metadata}") + + def on_speech_started(self, speech_started, **kwargs): + print("Speech Started") + + def on_utterance_end(self, utterance_end, **kwargs): + print("Utterance End") + global is_finals + if len(is_finals) > 0: + utterance = " ".join(is_finals) + print(f"Utterance End: {utterance}") + is_finals = [] + + def on_close(self, close, **kwargs): + print("Connection 
Closed") + + def on_unhandled(self, unhandled, **kwargs): + print(f"Unhandled Websocket Message: {unhandled}") + + dg_connection.on(LiveTranscriptionEvents.Open, on_open) + dg_connection.on(LiveTranscriptionEvents.Metadata, on_metadata) + dg_connection.on(LiveTranscriptionEvents.SpeechStarted, on_speech_started) + dg_connection.on(LiveTranscriptionEvents.UtteranceEnd, on_utterance_end) + dg_connection.on(LiveTranscriptionEvents.Close, on_close) + dg_connection.on(LiveTranscriptionEvents.Unhandled, on_unhandled) options = LiveOptions( punctuate=True, no_delay=True, From 32dd85dea7f2b50463114d48d976d7d5c4080d2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Mon, 23 Sep 2024 11:54:01 +0700 Subject: [PATCH 86/88] Bump version to 1.0.39+137 --- app/pubspec.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/pubspec.yaml b/app/pubspec.yaml index d10b6cedb..63512597c 100644 --- a/app/pubspec.yaml +++ b/app/pubspec.yaml @@ -3,7 +3,7 @@ description: A new Flutter project. publish_to: 'none' # Remove this line if you wish to publish to pub.dev -version: 1.0.39+136 +version: 1.0.39+137 environment: sdk: ">=3.0.0 <4.0.0" From b0bf2bd01e1b0a8783242e781a9c374dec677383 Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Mon, 23 Sep 2024 16:40:48 +0530 Subject: [PATCH 87/88] add scopes to GoogleSignIn --- app/lib/backend/auth.dart | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/lib/backend/auth.dart b/app/lib/backend/auth.dart index d1ed00dab..837b97df2 100644 --- a/app/lib/backend/auth.dart +++ b/app/lib/backend/auth.dart @@ -110,7 +110,7 @@ Future signInWithGoogle() async { try { print('Signing in with Google'); // Trigger the authentication flow - final GoogleSignInAccount? googleUser = await GoogleSignIn().signIn(); + final GoogleSignInAccount? googleUser = await GoogleSignIn(scopes: ['profile', 'email']).signIn(); print('Google User: $googleUser'); // Obtain the auth details from the request final GoogleSignInAuthentication? googleAuth = await googleUser?.authentication; From f0ea1bfeacafe8a6b05771b2838a542d6762698c Mon Sep 17 00:00:00 2001 From: Mohammed Mohsin <59914433+mdmohsin7@users.noreply.github.com> Date: Mon, 23 Sep 2024 16:43:10 +0530 Subject: [PATCH 88/88] bump version to 1.0.39+138 for release --- app/pubspec.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/pubspec.yaml b/app/pubspec.yaml index 63512597c..4c568deb1 100644 --- a/app/pubspec.yaml +++ b/app/pubspec.yaml @@ -3,7 +3,7 @@ description: A new Flutter project. publish_to: 'none' # Remove this line if you wish to publish to pub.dev -version: 1.0.39+137 +version: 1.0.39+138 environment: sdk: ">=3.0.0 <4.0.0"