From a22ddd0fa99539a09d0e3206bcdc754c34fcef3f Mon Sep 17 00:00:00 2001
From: Tom Potter
Date: Wed, 10 Jul 2024 16:37:46 +0100
Subject: [PATCH] Corrected stutter detection output, improved documentation,
 and handled a processing error

---
 README.md                                |  23 +++--
 stutter_detection/MaxVQAVideoDetector.py |   5 +-
 stutter_detection/StutterDetection.py    | 118 ++++++++++++++---------
 3 files changed, 93 insertions(+), 53 deletions(-)

diff --git a/README.md b/README.md
index f00d0d7..5fa4626 100644
--- a/README.md
+++ b/README.md
@@ -165,13 +165,15 @@ wget https://github.com/VQAssessment/DOVER/releases/download/v0.1.0/DOVER.pth -P
 
 ## Running
 
-* Run inference on directory at **PATH**: `python StutterDetection.py PATH`
-* This will timestamps of the any stuttering found in the audio or video files.
+* Run inference on a directory or a single video/audio file at **PATH**: `python StutterDetection.py PATH`
+* This will output a plot of the "motion fluency" over the course of the video (low fluency may indicate stuttering events) and/or a plot of the audio stutter times detected in the waveform.
 
 ### General CLI
 
 ```
-usage: StutterDetection.py [-h] [-f FRAMES] [-e EPOCHS] [-c] [-na] [-nv] [-t] directory
+usage: StutterDetection.py [-h] [-na] [-nv] [-c] [-t] [-i] [-f FRAMES] [-e EPOCHS]
+                           [-d DEVICE]
+                           directory
 
 Run audio and video stutter detection algorithms over local AV segments.
 
@@ -180,10 +182,19 @@ positional arguments:
 
 options:
   -h, --help            show this help message and exit
-  -na, --no-audio
-  -nv, --no-video
-  -c, --clean-video
+  -na, --no-audio       Do not perform stutter detection on the audio track
+  -nv, --no-video       Do not perform stutter detection on the video track
+  -c, --clean-video     Test on clean, stutter-free videos (for experimentation)
   -t, --true-timestamps
+                        Plot known stutter times on the output graph, specified
+                        in 'true-stutter-timestamps.json'
+  -i, --time-indexed-files
+                        Label batch of detections over video segments with
+                        their time range (from filename)
   -f FRAMES, --frames FRAMES
+                        Number of frames to downsample video to
   -e EPOCHS, --epochs EPOCHS
+                        Number of times to repeat inference per video
+  -d DEVICE, --device DEVICE
+                        Specify processing hardware
 ```
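The new flags compose as expected. For reference, a couple of illustrative invocations (the `recordings/` directory is hypothetical, not part of this patch):

```
# Video-only detection on CUDA hardware, labelling detections with the
# time ranges encoded in the segment filenames
python StutterDetection.py recordings/ -na -d cuda -i

# Plot known stutter times from 'true-stutter-timestamps.json' over the output
python StutterDetection.py recordings/ -t
```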
diff --git a/stutter_detection/MaxVQAVideoDetector.py b/stutter_detection/MaxVQAVideoDetector.py
index 6cfd4b2..6083ecb 100644
--- a/stutter_detection/MaxVQAVideoDetector.py
+++ b/stutter_detection/MaxVQAVideoDetector.py
@@ -53,7 +53,10 @@ def encode_text_prompts(prompts, tokenizer, model, device='cpu'):
 def setup_models(text_prompts, opt, aesthetic_clip_len, technical_num_clips, device='cpu', use_aesthetic_features=False):
     # Initialize fast-vqa encoder
     fast_vqa_encoder = DOVER(**opt["model"]["args"]).to(device)
-    fast_vqa_encoder.load_state_dict(torch.load("ExplainableVQA/DOVER/pretrained_weights/DOVER.pth", map_location=device), strict=False)
+    fast_vqa_encoder.load_state_dict(
+        torch.load("ExplainableVQA/DOVER/pretrained_weights/DOVER.pth", map_location=device),
+        strict=False
+    )
 
     # Initialize CLIP model
     clip_model, _, _ = open_clip.create_model_and_transforms("RN50", pretrained="openai")
diff --git a/stutter_detection/StutterDetection.py b/stutter_detection/StutterDetection.py
index 534e4c3..7539cb4 100644
--- a/stutter_detection/StutterDetection.py
+++ b/stutter_detection/StutterDetection.py
@@ -3,6 +3,7 @@
 import json
 import math
 import glob
+import pathlib
 import argparse
 import numpy as np
 from scipy.io import wavfile
@@ -98,16 +99,17 @@ def process(self, directory_path, truth=None, audio_detection=True, video_detect
                 self.video_detection_results = np.append(self.video_detection_results, results[:, :math.ceil(results.shape[1] * 0.9)], axis=1)
                 self.video_segment_index += 1
 
-        # Plot global video detection results over all clips in timeline
-        global_start_time = datetime.strptime(video_segment_paths[0].split('/')[-1].replace('.mp4', '').split('_')[1], '%H:%M:%S.%f')
-        global_end_time = timestamps[-1]
-        print(f"Full timeline: {global_start_time.strftime('%H:%M:%S.%f')} => {global_end_time.strftime('%H:%M:%S.%f')}")
-        self.plot_local_vqa(
-            self.video_detection_results,
-            true_time_labels=truth,
-            startpoint=global_start_time, endpoint=global_end_time,
-            output_file="motion-timeline.png"
-        )
+        # If recording timed segments, plot global video detection results over all clips in timeline
+        if time_indexed_files:
+            global_start_time = datetime.strptime(video_segment_paths[0].split('/')[-1].replace('.mp4', '').split('_')[1], '%H:%M:%S.%f')
+            global_end_time = timestamps[-1]
+            print(f"Full timeline: {global_start_time.strftime('%H:%M:%S.%f')} => {global_end_time.strftime('%H:%M:%S.%f')}")
+            self.plot_local_vqa(
+                self.video_detection_results,
+                true_time_labels=truth,
+                startpoint=global_start_time, endpoint=global_end_time,
+                output_file="motion-timeline.png"
+            )
 
     def get_local_paths(self, audio_detection=True, video_detection=True, dir="./data/"):
         sort_by_index = lambda path: int(path.split('/')[-1].split('_')[0][3:])
@@ -182,7 +184,7 @@ def audio_detection(self, audio_content, time_indexed_audio=False, detect_gaps=T
         print()
         return {"gaps": detected_audio_gaps, "clicks": detected_audio_clicks}
 
-    def plot_audio(self, audio_content, gap_times, click_times, startpoint, endpoint):
+    def plot_audio(self, audio_content, gap_times, click_times, startpoint, endpoint, output_file=''):
         # Setup
         plt.rcParams['agg.path.chunksize'] = 1000
         fig, axs = plt.subplots(1, figsize=(20, 10), tight_layout=True)
@@ -224,7 +226,18 @@
         plt.ylabel("Audio Sample Amplitude", fontsize=14)
         plt.title(f"Audio Defect Detection: Segment {self.audio_segment_index} ({time_x[0].strftime('%H:%M:%S')} => {time_x[-1].strftime('%H:%M:%S')})) \n", fontsize=18)
         plt.legend(loc=1, fontsize=14)
-        fig.savefig(f"output/plots/audio-plot-{self.audio_segment_index}.png")
+
+        # Save plot to file
+        output_path = "output/plots/"
+        pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)
+
+        if output_file == '':
+            output_path = os.path.join(output_path, f"audio-plot-{self.audio_segment_index}.png")
+        else:
+            output_path = os.path.join(output_path, output_file)
+
+        print(f" * Audio plot generated : {output_path}")
+        fig.savefig(output_path)
         plt.close(fig)
 
     def video_detection(self, video_content, time_indexed_video=False, plot=False, start_time=0, end_time=0, epochs=1):
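One assumption worth making explicit for the `process()` and `get_local_paths()` changes above: segment filenames must carry a sortable index and a start timestamp. A minimal sketch of the parsing those methods perform, on a hypothetical filename (the three-character `seg` prefix and the `%H:%M:%S.%f` suffix are inferred from the code, not documented by the patch):

```python
from datetime import datetime

# Hypothetical segment name: 3-character prefix + index, underscore, start time
path = "data/stutter/seg0_16:37:46.000000.mp4"

name = path.split('/')[-1].replace('.mp4', '')
index = int(name.split('_')[0][3:])                           # sort key -> 0
start = datetime.strptime(name.split('_')[1], '%H:%M:%S.%f')  # timeline start

print(index, start.strftime('%H:%M:%S.%f'))
```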
@@ -262,24 +275,25 @@
         print()
         return output
 
-    def plot_local_vqa(self, vqa_values, true_time_labels=None, startpoint=0, endpoint=0, output_file=''):
+    def plot_local_vqa(self, vqa_values, true_time_labels=None, startpoint=0, endpoint=0, plot_motion_only=True, output_file=''):
         # Metrics & figure setup
-        # priority_metrics = [7, 9, 11, 13, 14]
-        # titles = {
-        #     "A": "Sharpness",
-        #     "B": "Noise",
-        #     "C": "Flicker",
-        #     "D": "Compression artefacts",
-        #     "E": "Motion fluency"
-        # }
-        # fig, axes = plt.subplot_mosaic("AB;CD;EE", sharex=True, sharey=True, figsize=(12, 9), tight_layout=True)
-
-        priority_metrics = [14]
-        titles = {
-            "A": "Motion fluency"
-        }
-        plot_values = vqa_values[priority_metrics]
-        fig, axes = plt.subplot_mosaic("A", sharex=True, sharey=True, figsize=(12, 6), tight_layout=True)
+        if plot_motion_only:
+            priority_metrics = [14]
+            titles = {
+                "A": "Motion fluency"
+            }
+            fig, axes = plt.subplot_mosaic("A", sharex=True, sharey=True, figsize=(12, 6), tight_layout=True)
+        else:
+            priority_metrics = [7, 9, 11, 13, 14]
+            titles = {
+                "A": "Sharpness",
+                "B": "Noise",
+                "C": "Flicker",
+                "D": "Compression artefacts",
+                "E": "Motion fluency"
+            }
+            fig, axes = plt.subplot_mosaic("AB;CD;EE", sharex=True, sharey=True, figsize=(12, 9), tight_layout=True)
+        plot_values = vqa_values[priority_metrics]
 
         colours = cycle(mcolors.TABLEAU_COLORS)
 
@@ -329,7 +343,7 @@ def plot_local_vqa(self, vqa_values, true_time_labels=None, startpoint=0, endpoi
                 labels=[t.strftime('%H:%M:%S') for t in time_x[::num_ticks]]
             )
         else:
-            fig.suptitle(f"MaxVQA Video Defect Detection{f': Segment {self.video_segment_index}' if output_file == '' else ''}", fontsize=16)
+            fig.suptitle("MaxVQA Video Defect Detection", fontsize=16)
             fig.supxlabel("Capture Frame")
             fig.supylabel("Absolute score (0-1, bad-good)")
 
@@ -343,13 +357,16 @@ def plot_local_vqa(self, vqa_values, true_time_labels=None, startpoint=0, endpoi
             ax.label_outer()
 
         # Save plot to file
+        output_path = "output/plots/"
+        pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)
+
         if output_file == '':
-            # fig.savefig(f"output/plots/video-plot-{self.video_segment_index}.png")
-            fig.savefig(f"output/plots/motion-plot-{self.video_segment_index}.png")
+            output_path = os.path.join(output_path, f"motion-plot-{self.video_segment_index}.png")
         else:
-            fig.savefig(f"output/plots/{output_file}")
+            output_path = os.path.join(output_path, output_file)
 
-        print(f" * Plot generated : {f'video-plot-{self.video_segment_index}.png' if output_file == '' else output_file}")
+        print(f" * Video plot generated : {output_path}")
+        fig.savefig(output_path)
         plt.close(fig)
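For readers unfamiliar with `plt.subplot_mosaic`, used by both branches of the new `plot_motion_only` switch above: each character in the mosaic string names one panel and semicolons separate rows, so "AB;CD;EE" gives two rows of two panels plus a bottom panel "E" (motion fluency in the full plot) spanning the whole width. A standalone sketch:

```python
import matplotlib.pyplot as plt

# "AB;CD;EE": rows are 'AB', 'CD', and 'EE'; the repeated 'E' makes that
# panel span the full width of the bottom row.
fig, axes = plt.subplot_mosaic("AB;CD;EE", figsize=(12, 9), tight_layout=True)
for label, ax in axes.items():
    ax.set_title(label)
fig.savefig("mosaic-demo.png")
```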
@@ -361,43 +378,52 @@ def plot_local_vqa(self, vqa_values, true_time_labels=None, startpoint=0, endpoi
     )
     parser.add_argument("directory")
-    parser.add_argument('-na', '--no-audio', action='store_false', default=True)
-    parser.add_argument('-nv', '--no-video', action='store_false', default=True)
-    parser.add_argument('-c', '--clean-video', action='store_true', default=False)
-    parser.add_argument('-t', '--true-timestamps', action='store_true', default=False)
-    parser.add_argument('-f', '--frames', type=int, default=256)
-    parser.add_argument('-e', '--epochs', type=int, default=3)
+    parser.add_argument('-na', '--no-audio', action='store_false', default=True, help="Do not perform stutter detection on the audio track")
+    parser.add_argument('-nv', '--no-video', action='store_false', default=True, help="Do not perform stutter detection on the video track")
+    parser.add_argument('-c', '--clean-video', action='store_true', default=False, help="Test on clean, stutter-free videos (for experimentation)")
+    parser.add_argument('-t', '--true-timestamps', action='store_true', default=False, help="Plot known stutter times on the output graph, specified in 'true-stutter-timestamps.json'")
+    parser.add_argument('-i', '--time-indexed-files', action='store_true', default=False, help="Label batch of detections over video segments with their time range (from filename)")
+    parser.add_argument('-f', '--frames', type=int, default=256, help="Number of frames to downsample video to")
+    parser.add_argument('-e', '--epochs', type=int, default=1, help="Number of times to repeat inference per video")
+    parser.add_argument('-d', '--device', type=str, default='cpu', help="Specify processing hardware")
 
     # Decode input parameters to toggle between cameras, microphones, and setup mode.
     args = parser.parse_args()
     path = args.directory
     frames = args.frames
     epochs = args.epochs
+    device = args.device
     stutter = not args.clean_video
     audio_on = args.no_audio
     video_on = args.no_video
     plot_true_timestamps = args.true_timestamps
+    index_by_file_timestamp = args.time_indexed_files
 
-    detector = StutterDetection(video_downsample_frames=frames, device='cpu')
+    detector = StutterDetection(video_downsample_frames=frames, device=device)
 
     if path.endswith(".mp4") or path.endswith(".wav"):
         detector.process(
             directory_path=path,
-            time_indexed_files=True,
+            time_indexed_files=index_by_file_timestamp,
             inference_epochs=epochs,
             audio_detection=audio_on,
             video_detection=video_on
         )
     else:
         if stutter and plot_true_timestamps:
-            with open(f"{path}/stutter/true-stutter-timestamps.json", 'r') as f:
+            timestamps_file = f"{path}/stutter/true-stutter-timestamps.json"
+            if not os.path.isfile(timestamps_file):
+                print(f"Error: no true-timestamps file found, but 'plot_true_timestamps' is enabled. Checked location: {timestamps_file}")
+                exit(1)
+
+            with open(timestamps_file, 'r') as f:
                 json_data = json.load(f)
                 true_timestamps_json = json_data["timestamps"]
 
             detector.process(
                 directory_path=f"{path}/stutter/",
                 truth=true_timestamps_json,
-                time_indexed_files=True,
+                time_indexed_files=index_by_file_timestamp,
                 inference_epochs=epochs,
                 audio_detection=audio_on,
                 video_detection=video_on
@@ -405,7 +431,7 @@ def plot_local_vqa(self, vqa_values, true_time_labels=None, startpoint=0, endpoi
         elif not stutter:
             detector.process(
                 directory_path=f"{path}/original/",
-                time_indexed_files=True,
+                time_indexed_files=index_by_file_timestamp,
                 inference_epochs=epochs,
                 audio_detection=audio_on,
                 video_detection=video_on
@@ -413,7 +439,7 @@ def plot_local_vqa(self, vqa_values, true_time_labels=None, startpoint=0, endpoi
         else:
             detector.process(
                 directory_path=path,
-                time_indexed_files=True,
+                time_indexed_files=index_by_file_timestamp,
                 inference_epochs=epochs,
                 audio_detection=audio_on,
                 video_detection=video_on
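Finally, for the `-t` error path handled above: the script only requires that `<directory>/stutter/true-stutter-timestamps.json` exists and contains a top-level `"timestamps"` key. The entry format is not shown in this patch, so the values below are purely illustrative:

```python
import json

# Illustrative truth file -- only the top-level "timestamps" key is read
# by StutterDetection.py; the entry format here is an assumption.
example = {"timestamps": ["00:00:12.500000", "00:00:14.250000"]}

with open("true-stutter-timestamps.json", "w") as f:
    json.dump(example, f, indent=2)

with open("true-stutter-timestamps.json", "r") as f:
    truth = json.load(f)["timestamps"]
print(truth)
```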