Skip to content

Commit

Permalink
feat: Audio narration (#673)
Browse files Browse the repository at this point in the history
* feat: Add audio narration feature while recording

* feat: Remove implicit scrubbing in display_event function and recursively convert reqd properties to str

* feat: Add transcribed text to dashboard visualisation

* feat: Use recording id as foreign key, and add interrupt signal handler in audio recording process

* feat: Check if the lock is stale when acquiring locks

* refactor: Convert database lock path to a constant in config file

---------

Co-authored-by: Richard Abrich <[email protected]>
  • Loading branch information
KIRA009 and abrichr authored Jun 4, 2024
1 parent 8b4d9ef commit 1e11906
Show file tree
Hide file tree
Showing 14 changed files with 514 additions and 117 deletions.
53 changes: 53 additions & 0 deletions openadapt/alembic/versions/98c8851a5321_add_audio_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""add_audio_info
Revision ID: 98c8851a5321
Revises: d714cc86fce8
Create Date: 2024-05-29 16:56:25.832333
"""
from alembic import op
import sqlalchemy as sa

import openadapt

# revision identifiers, used by Alembic.
revision = "98c8851a5321"
down_revision = "d714cc86fce8"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Create the ``audio_info`` table.

    The table gets a non-nullable integer primary key ``id`` and a nullable
    ``recording_id`` foreign key referencing ``recording.id``; every other
    column (``timestamp``, ``flac_data``, ``transcribed_text``,
    ``recording_timestamp``, ``sample_rate``, ``words_with_timestamps``) is
    nullable.
    """
    # Both float columns use the same ForceFloat settings; a separate type
    # instance is still created for each column.
    float_kwargs = dict(precision=10, scale=2, asdecimal=False)

    columns = [
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column(
            "timestamp",
            openadapt.models.ForceFloat(**float_kwargs),
            nullable=True,
        ),
        sa.Column("flac_data", sa.LargeBinary(), nullable=True),
        sa.Column("transcribed_text", sa.String(), nullable=True),
        sa.Column(
            "recording_timestamp",
            openadapt.models.ForceFloat(**float_kwargs),
            nullable=True,
        ),
        sa.Column("recording_id", sa.Integer(), nullable=True),
        sa.Column("sample_rate", sa.Integer(), nullable=True),
        sa.Column("words_with_timestamps", sa.Text(), nullable=True),
    ]

    # Constraint names are wrapped in op.f(), exactly as in the generated
    # migration.
    constraints = [
        sa.ForeignKeyConstraint(
            ["recording_id"],
            ["recording.id"],
            name=op.f("fk_audio_info_recording_id_recording"),
        ),
        sa.PrimaryKeyConstraint("id", name=op.f("pk_audio_info")),
    ]

    op.create_table("audio_info", *columns, *constraints)


def downgrade() -> None:
    """Drop the ``audio_info`` table."""
    table_name = "audio_info"
    op.drop_table(table_name)
31 changes: 30 additions & 1 deletion openadapt/app/dashboard/api/recordings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""API endpoints for recordings."""

import json

from fastapi import APIRouter, WebSocket
from loguru import logger

Expand Down Expand Up @@ -80,6 +82,22 @@ async def get_recording_detail(websocket: WebSocket, recording_id: int) -> None:
{"type": "num_events", "value": len(action_events)}
)

try:
# TODO: change to use recording_id once scrubbing PR is merged
audio_info = crud.get_audio_info(session, recording.timestamp)[0]
words_with_timestamps = json.loads(audio_info.words_with_timestamps)
words_with_timestamps = [
{
"word": word["word"],
"start": word["start"] + action_events[0].timestamp,
"end": word["end"] + action_events[0].timestamp,
}
for word in words_with_timestamps
]
except IndexError:
words_with_timestamps = []
word_index = 0

def convert_to_str(event_dict: dict) -> dict:
"""Convert the keys to strings."""
if "key" in event_dict:
Expand All @@ -104,7 +122,18 @@ def convert_to_str(event_dict: dict) -> dict:
width, height = 0, 0
event_dict["screenshot"] = image
event_dict["dimensions"] = {"width": width, "height": height}

words = []
# each entry in words_with_timestamps is a dict with "word", "start", "end";
# attach to this event every not-yet-consumed word whose start is
# before the event timestamp
while (
word_index < len(words_with_timestamps)
and words_with_timestamps[word_index]["start"]
< event_dict["timestamp"]
):
words.append(words_with_timestamps[word_index]["word"])
word_index += 1
event_dict["words"] = words
convert_to_str(event_dict)
await websocket.send_json({"type": "action_event", "value": event_dict})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ export const ActionEvent = ({
<TableCellWithBorder>{event.parent_id}</TableCellWithBorder>
</TableRowWithBorder>
)}
{event.words && event.words.length > 0 && (
<TableRowWithBorder>
<TableCellWithBorder>transcription</TableCellWithBorder>
<TableCellWithBorder>{event.words.join(' ')}</TableCellWithBorder>
</TableRowWithBorder>
)}
<TableRowWithBorder>
<TableCellWithBorder>children</TableCellWithBorder>
<TableCellWithBorder>{event.children?.length || 0}</TableCellWithBorder>
Expand Down
1 change: 1 addition & 0 deletions openadapt/app/dashboard/types/action-event.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,5 @@ export type ActionEvent = {
mask: string | null;
dimensions?: { width: number, height: number };
children?: ActionEvent[];
words?: string[];
}
4 changes: 0 additions & 4 deletions openadapt/app/tray.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,6 @@ def __init__(self) -> None:

self.app.setQuitOnLastWindowClosed(False)

# since the lock is a file, delete it when starting the app so that
# new instances can start even if the previous one crashed
crud.release_db_lock(raise_exception=False)

# currently required for pyqttoast
# TODO: remove once https://github.com/niklashenning/pyqt-toast/issues/9
# is addressed
Expand Down
1 change: 1 addition & 0 deletions openadapt/config.defaults.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"RECORD_READ_ACTIVE_ELEMENT_STATE": false,
"REPLAY_STRIP_ELEMENT_STATE": true,
"RECORD_VIDEO": true,
"RECORD_AUDIO": true,
"RECORD_FULL_VIDEO": false,
"RECORD_IMAGES": false,
"LOG_MEMORY": false,
Expand Down
2 changes: 2 additions & 0 deletions openadapt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
PERFORMANCE_PLOTS_DIR_PATH = (DATA_DIR_PATH / "performance").absolute()
CAPTURE_DIR_PATH = (DATA_DIR_PATH / "captures").absolute()
VIDEO_DIR_PATH = DATA_DIR_PATH / "videos"
DATABASE_LOCK_FILE_PATH = DATA_DIR_PATH / "openadapt.db.lock"

STOP_STRS = [
"oa.stop",
Expand Down Expand Up @@ -136,6 +137,7 @@ class SegmentationAdapter(str, Enum):
RECORD_WINDOW_DATA: bool = False
RECORD_READ_ACTIVE_ELEMENT_STATE: bool = False
RECORD_VIDEO: bool
RECORD_AUDIO: bool
# if false, only write video events corresponding to screenshots
RECORD_FULL_VIDEO: bool
RECORD_IMAGES: bool
Expand Down
70 changes: 64 additions & 6 deletions openadapt/db/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@

from loguru import logger
from sqlalchemy.orm import Session as SaSession
import psutil
import sqlalchemy as sa

from openadapt import utils
from openadapt.config import DATA_DIR_PATH, config
from openadapt.config import DATABASE_LOCK_FILE_PATH, config
from openadapt.db.db import Session, get_read_only_session_maker
from openadapt.models import (
ActionEvent,
AudioInfo,
MemoryStat,
PerformanceStat,
Recording,
Expand Down Expand Up @@ -618,6 +620,56 @@ def update_video_start_time(
)


def insert_audio_info(
    session: SaSession,
    audio_data: bytes,
    transcribed_text: str,
    recording: Recording,
    timestamp: float,
    sample_rate: int,
    word_list: list,
) -> None:
    """Store a new AudioInfo row for the given recording and commit it.

    Args:
        session (sa.orm.Session): The database session.
        audio_data (bytes): The audio data, stored in ``flac_data``.
        transcribed_text (str): The transcribed text.
        recording (Recording): The recording object; its ``timestamp`` and
            ``id`` are copied onto the new row.
        timestamp (float): The timestamp of the audio.
        sample_rate (int): The sample rate of the audio.
        word_list (list): A list of words with timestamps; saved as a JSON
            string in ``words_with_timestamps``.
    """
    row_fields = {
        "flac_data": audio_data,
        "transcribed_text": transcribed_text,
        "recording_timestamp": recording.timestamp,
        "recording_id": recording.id,
        "timestamp": timestamp,
        "sample_rate": sample_rate,
        # the column is text, so the word list is serialized before storage
        "words_with_timestamps": json.dumps(word_list),
    }
    session.add(AudioInfo(**row_fields))
    session.commit()


# TODO: change to use recording_id once scrubbing PR is merged
def get_audio_info(
    session: SaSession,
    recording_timestamp: float,
) -> list[AudioInfo]:
    """Fetch the audio info entries belonging to a recording.

    Args:
        session (sa.orm.Session): The database session.
        recording_timestamp (float): The timestamp of the recording.

    Returns:
        list[AudioInfo]: A list of audio info for the recording.
    """
    # Lookup is delegated to the module's shared _get helper.
    audio_infos = _get(session, AudioInfo, recording_timestamp)
    return audio_infos


def post_process_events(session: SaSession, recording: Recording) -> None:
"""Post-process events.
Expand Down Expand Up @@ -764,11 +816,17 @@ def acquire_db_lock(timeout: int = 60) -> bool:
if timeout > 0 and time.time() - start > timeout:
logger.error("Failed to acquire database lock.")
return False
if os.path.exists(DATA_DIR_PATH / "database.lock"):
logger.info("Database is locked. Waiting...")
time.sleep(1)
if os.path.exists(DATABASE_LOCK_FILE_PATH):
with open(DATABASE_LOCK_FILE_PATH, "r") as lock_file:
lock_info = json.load(lock_file)
# check if the process is still running
if psutil.pid_exists(lock_info["pid"]):
logger.info("Database is locked. Waiting...")
time.sleep(1)
else:
release_db_lock(raise_exception=False)
else:
with open(DATA_DIR_PATH / "database.lock", "w") as lock_file:
with open(DATABASE_LOCK_FILE_PATH, "w") as lock_file:
lock_file.write(json.dumps({"pid": os.getpid(), "time": time.time()}))
logger.info("Database lock acquired.")
break
Expand All @@ -778,7 +836,7 @@ def acquire_db_lock(timeout: int = 60) -> bool:
def release_db_lock(raise_exception: bool = True) -> None:
"""Release the database lock."""
try:
os.remove(DATA_DIR_PATH / "database.lock")
os.remove(DATABASE_LOCK_FILE_PATH)
except Exception as e:
if raise_exception:
logger.error("Failed to release database lock.")
Expand Down
18 changes: 18 additions & 0 deletions openadapt/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class Recording(db.Base):
"ScrubbedRecording",
back_populates="recording",
)
audio_info = sa.orm.relationship("AudioInfo", back_populates="recording")

_processed_action_events = None

Expand Down Expand Up @@ -723,6 +724,23 @@ def convert_png_to_binary(self, image: Image.Image) -> bytes:
return buffer.getvalue()


class AudioInfo(db.Base):
    """Class representing the audio from a recording in the database."""

    __tablename__ = "audio_info"

    # Surrogate integer primary key.
    id = sa.Column(sa.Integer, primary_key=True)
    # Timestamp of the audio itself.
    timestamp = sa.Column(ForceFloat)
    # Raw audio bytes (FLAC, going by the column name -- TODO confirm encoding).
    flac_data = sa.Column(sa.LargeBinary)
    # Full transcription of the audio as a single string.
    transcribed_text = sa.Column(sa.String)
    # Timestamp of the owning recording, stored alongside recording_id.
    recording_timestamp = sa.Column(ForceFloat)
    # Foreign key to the owning row in the "recording" table.
    recording_id = sa.Column(sa.ForeignKey("recording.id"))
    # Sample rate of the audio (presumably in Hz -- TODO confirm).
    sample_rate = sa.Column(sa.Integer)
    # Text column holding the per-word timestamps; crud.insert_audio_info
    # writes it as a JSON string.
    words_with_timestamps = sa.Column(sa.Text)

    # Other side of Recording.audio_info (paired through back_populates).
    recording = sa.orm.relationship("Recording", back_populates="audio_info")


class PerformanceStat(db.Base):
"""Class representing a performance statistic in the database."""

Expand Down
Loading

0 comments on commit 1e11906

Please sign in to comment.