From 76433544a6a01739dec33fbac43d4c81c520aae5 Mon Sep 17 00:00:00 2001 From: Keunwoo Choi Date: Thu, 18 Mar 2021 18:13:11 -0400 Subject: [PATCH] Kapre 0.3.5 (#122) * Feat/tflite compatibility (#118) * update tensorflow requirements to have optional gpu version * Update README.md * Update README.md * update CI * tflite compatible STFT layer * tflite compatible STFT layer * magnitude layer works out shape for itself. * tflite phase support * Approximate phase and tests * undo some black modifications * undo some black modifications * undo some black modifications * undo some black modifications * remove batch index * right padding only for stft, improved unit test * update doc string * update doc string * remove extra requirement * remove extra requirement * add extra req * remove extra req * remove extra req * remove extra req * remove extra req * remove extra req * remove extra req * remove extra req * update readme * update readme * update docstring * update docstrings * tflite compatible layers as a subpackage * Wrap comment in """ * Revert to simple modules * Dont need to reinitialise * Fix formatting * Revert some `black` changes * Docstring update * docstring update Co-authored-by: Paul Kendrick * magnitude_to_decibel support for float16+float64 (#120) * magnitude_to_decibel support for float16+float64 The function kapre.backend.magnitude_to_decibel with the default arguments throws a TensorFlow type error if `x` is not `float32`. We can fix this by casting `amin` to `x`'s dtype. * Tests for magnitude_to_decibel dtype Floating point errors with `float16` require use to assert with a larger NumPy `rtol`, but otherwise it works. 
* add _tflite to the doc * update doc setup * add doc utility script * fix install requirement for docs * rtd * re-travis Co-authored-by: Paul Kendrick Co-authored-by: Paul Kendrick Co-authored-by: James Mishra --- README.md | 2 +- docs/_static/css/custom.css | 7 +- docs/conf.py | 21 +++ docs/index.rst | 1 + docs/release_note.rst | 4 + docs/requirements.txt | 3 +- docs/time_frequency_tflite.rst | 5 + kapre/__init__.py | 3 +- kapre/augmentation.py | 8 +- kapre/backend.py | 1 + kapre/composed.py | 6 +- kapre/signal.py | 44 ++++-- kapre/tflite_compatible_stft.py | 247 ++++++++++++++++++++++++++++++++ kapre/time_frequency.py | 57 ++++++-- kapre/time_frequency_tflite.py | 181 +++++++++++++++++++++++ setup.py | 4 +- tests/test_augmentation.py | 6 +- tests/test_backend.py | 13 +- tests/test_time_frequency.py | 135 ++++++++++++++++- tests/utils.py | 61 +++++++- 20 files changed, 760 insertions(+), 49 deletions(-) create mode 100644 docs/time_frequency_tflite.rst create mode 100644 kapre/tflite_compatible_stft.py create mode 100644 kapre/time_frequency_tflite.py diff --git a/README.md b/README.md index f91a874..e1b2ead 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Kapre Keras Audio Preprocessors - compute STFT, ISTFT, Melspectrogram, and others on GPU real-time. - + Tested on Python 3.6 and 3.7 ## Why Kapre? 
diff --git a/docs/_static/css/custom.css b/docs/_static/css/custom.css index b7b207a..ec8d772 100644 --- a/docs/_static/css/custom.css +++ b/docs/_static/css/custom.css @@ -1,3 +1,8 @@ div.wy-nav-content { - max-width: 1200px; + max-width: 1000px; +} + +code.literal { + color: #404040 !important; + background-color: #fbfbfb !important; } diff --git a/docs/conf.py b/docs/conf.py index cd17450..a6a50f2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -15,6 +15,7 @@ import sys sys.path.insert(0, os.path.abspath('../')) import sphinx_rtd_theme + autodoc_mock_imports = ['tensorflow', 'librosa', 'numpy'] autodoc_member_order = 'bysource' @@ -40,8 +41,18 @@ "sphinx.ext.napoleon", # "sphinx.ext.autosummary", "sphinx.ext.viewcode", # source linkage + "sphinxcontrib.inlinesyntaxhighlight" # inline code highlight ] +# https://stackoverflow.com/questions/21591107/sphinx-inline-code-highlight +# use language set by highlight directive if no language is set by role +inline_highlight_respect_highlight = True +# use language set by highlight directive if no role is set +inline_highlight_literals = True + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = "sphinx" + # autosummary_generate = True # autoapi_type = 'python' @@ -68,4 +79,14 @@ # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ['_static'] +html_css_files = [ + 'css/custom.css', +] + + +def setup(app): + app.add_stylesheet("css/custom.css") + + master_doc = 'index' + diff --git a/docs/index.rst b/docs/index.rst index 416921e..b6cf535 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -92,6 +92,7 @@ Visit `github.com/keunwoochoi/kapre `_ and signal composed backend + time_frequency_tflite .. 
toctree:: :hidden: diff --git a/docs/release_note.rst b/docs/release_note.rst index 7c1f876..a86cd94 100644 --- a/docs/release_note.rst +++ b/docs/release_note.rst @@ -1,6 +1,10 @@ Release Note ^^^^^^^^^^^^ +* 18 March 2021 + - 0.3.5 + - Add `kapre.time_frequency_tflite` which uses tflite for a faster CPU inference. + * 29 Sep 2020 - 0.3.4 - Fix a bug in `kapre.backend.get_window_fn()`. Previously, it only correctly worked with `None` input and an error was raised when non-default value was set for `window_name` in any layer. diff --git a/docs/requirements.txt b/docs/requirements.txt index 6883dc7..41a33da 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,3 +1,4 @@ sphinx!=1.3.1 sphinx_rtd_theme -sphinxcontrib-napoleon \ No newline at end of file +sphinxcontrib-napoleon +sphinxcontrib-inlinesyntaxhighlight diff --git a/docs/time_frequency_tflite.rst b/docs/time_frequency_tflite.rst new file mode 100644 index 0000000..7b665b2 --- /dev/null +++ b/docs/time_frequency_tflite.rst @@ -0,0 +1,5 @@ +time_frequency_tflite +^^^^^^^^^^^^^^^^^^^^^ + +.. automodule:: kapre.time_frequency_tflite + :members: diff --git a/kapre/__init__.py b/kapre/__init__.py index c0c9b81..14e71c9 100644 --- a/kapre/__init__.py +++ b/kapre/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.3.4' +__version__ = '0.3.5' VERSION = __version__ from . 
import composed @@ -6,3 +6,4 @@ from .signal import * from .time_frequency import * +from .time_frequency_tflite import * diff --git a/kapre/augmentation.py b/kapre/augmentation.py index 2edb7d2..09d0a1f 100644 --- a/kapre/augmentation.py +++ b/kapre/augmentation.py @@ -49,7 +49,9 @@ class ChannelSwap(Layer): """ def __init__( - self, data_format='default', **kwargs, + self, + data_format='default', + **kwargs, ): backend.validate_data_format_str(data_format) @@ -94,6 +96,8 @@ def call(self, x, training=None): def get_config(self): config = super(ChannelSwap, self).get_config() config.update( - {'data_format': self.data_format,} + { + 'data_format': self.data_format, + } ) return config diff --git a/kapre/backend.py b/kapre/backend.py index ab64344..2de0a5a 100644 --- a/kapre/backend.py +++ b/kapre/backend.py @@ -110,6 +110,7 @@ def _log10(x): if amin is None: amin = 1e-5 + amin = tf.cast(amin, dtype=x.dtype) log_spec = 10.0 * _log10(tf.math.maximum(x, amin)) log_spec = log_spec - 10.0 * _log10(tf.math.maximum(amin, ref_value)) diff --git a/kapre/composed.py b/kapre/composed.py index 5300a14..fa9848a 100644 --- a/kapre/composed.py +++ b/kapre/composed.py @@ -13,6 +13,9 @@ """ +from tensorflow import keras +from tensorflow.keras import Sequential, Model + from .time_frequency import ( STFT, InverseSTFT, @@ -23,9 +26,6 @@ ConcatenateFrequencyMap, ) from . import backend - -from tensorflow import keras -from tensorflow.keras import Sequential, Model from .backend import _CH_FIRST_STR, _CH_LAST_STR, _CH_DEFAULT_STR diff --git a/kapre/signal.py b/kapre/signal.py index 7a6566d..65a80e6 100644 --- a/kapre/signal.py +++ b/kapre/signal.py @@ -5,8 +5,9 @@ """ import tensorflow as tf from tensorflow.keras.layers import Layer -from . import backend from tensorflow.keras import backend as K + +from . 
import backend from .backend import _CH_FIRST_STR, _CH_LAST_STR, _CH_DEFAULT_STR @@ -23,7 +24,7 @@ class Frame(Layer): pad_end (bool): whether to pad at the end of the signal of there would be a otherwise-discarded partial frame pad_value (int or float): value to use in the padding data_format (str): `channels_first`, `channels_last`, or `default` - **kwargs: + **kwargs: optional keyword args for `tf.keras.layers.Layer()` Example: :: @@ -63,9 +64,8 @@ def call(self, x): x (`Tensor`): batch audio signal in the specified 1D format in initiation. Returns: - (`Tensor`): A framed tensor. The shape is - (batch, time (frames), frame_length, channel) if `channels_last` and - (batch, channel, time (frames), frame_length) if `channels_first`. + (`Tensor`): A framed tensor. The shape is (batch, time (frames), frame_length, channel) if `channels_last`, + or (batch, channel, time (frames), frame_length) if `channels_first`. """ return tf.signal.frame( x, @@ -104,7 +104,7 @@ class Energy(Layer): pad_end (bool): whether to pad at the end of the signal of there would be a otherwise-discarded partial frame pad_value (int or float): value to use in the padding data_format (str): `channels_first`, `channels_last`, or `default` - **kwargs: + **kwargs: optional keyword args for `tf.keras.layers.Layer()` Example: :: @@ -154,9 +154,8 @@ def call(self, x): x (`Tensor`): batch audio signal in the specified 1D format in initiation. Returns: - (`Tensor`): A framed tensor. The shape is - (batch, time (frames), channel) if `channels_last`, and - (batch, channel, time (frames)) if `channels_first`. + (`Tensor`): A framed tensor. The shape is (batch, time (frames), channel) if `channels_last`, or + (batch, channel, time (frames)) if `channels_first`. """ frames = tf.signal.frame( x, @@ -200,6 +199,7 @@ class MuLawEncoding(Layer): Args: quantization_channels (positive int): Number of channels. For 8-bit encoding, use 256. 
+ **kwargs: optional keyword args for `tf.keras.layers.Layer()` Note: Mu-law encoding was originally developed to increase signal-to-noise ratio of signal during transmission. @@ -219,7 +219,9 @@ class MuLawEncoding(Layer): """ def __init__( - self, quantization_channels, **kwargs, + self, + quantization_channels, + **kwargs, ): super(MuLawEncoding, self).__init__(**kwargs) self.quantization_channels = quantization_channels @@ -238,7 +240,9 @@ def call(self, x): def get_config(self): config = super(MuLawEncoding, self).get_config() config.update( - {'quantization_channels': self.quantization_channels,} + { + 'quantization_channels': self.quantization_channels, + } ) return config @@ -251,6 +255,7 @@ class MuLawDecoding(Layer): Args: quantization_channels (positive int): Number of channels. For 8-bit encoding, use 256. + **kwargs: optional keyword args for `tf.keras.layers.Layer()` Example: :: @@ -263,7 +268,9 @@ class MuLawDecoding(Layer): """ def __init__( - self, quantization_channels, **kwargs, + self, + quantization_channels, + **kwargs, ): super(MuLawDecoding, self).__init__(**kwargs) self.quantization_channels = quantization_channels @@ -282,7 +289,9 @@ def call(self, x): def get_config(self): config = super(MuLawDecoding, self).get_config() config.update( - {'quantization_channels': self.quantization_channels,} + { + 'quantization_channels': self.quantization_channels, + } ) return config @@ -303,6 +312,11 @@ class LogmelToMFCC(Layer): As long as all of your data in training / inference / deployment is consistent (i.e., do not mix librosa and kapre MFCC), it'll be fine! + Args: + n_mfccs (int): Number of MFCC + data_format (str): `channels_first`, `channels_last`, or `default` + **kwargs: optional keyword args for `tf.keras.layers.Layer()` + Example: :: @@ -336,8 +350,8 @@ def call(self, log_melgrams): and `(b, ch, time, mel)` if `channels_first`. Returns: - (float `Tensor`): MFCCs. 
`(batch, time, n_mfccs, ch)` if `channels_last` - and `(batch, ch, time, n_mfccs)` if `channels_first`. + (float `Tensor`): + MFCCs. `(batch, time, n_mfccs, ch)` if `channels_last`, `(batch, ch, time, n_mfccs)` if `channels_first`. """ if self.permutation is not None: # reshape so that last channel == mel log_melgrams = K.permute_dimensions(log_melgrams, pattern=self.permutation) diff --git a/kapre/tflite_compatible_stft.py b/kapre/tflite_compatible_stft.py new file mode 100644 index 0000000..7ab1000 --- /dev/null +++ b/kapre/tflite_compatible_stft.py @@ -0,0 +1,247 @@ +"""Workarounds for missing TFLite support for rfft and stft and tf.signal.frame. + +based on: +https://github.com/tensorflow/magenta/tree/master/magenta/music +as posted in https://github.com/tensorflow/tensorflow/issues/27303 +and https://gist.github.com/padoremu/8288b47ce76e9530eb288d4eec2e0b4d +""" +import math + +import tensorflow as tf +import numpy as np + + +def _rdft_matrix(dft_length): + """Return a precalculated DFT matrix for positive frequencies only + + Args: + dft_length (int) - DFT length + + Returns + rdft_mat (k x n tensor) - precalculated DFT matrix rows are frequencies + columns are samples, k dimension is dft_length // 2 +1 bins long + """ + # freq bins + k = np.arange(0, dft_length // 2 + 1) + # Samples + n = np.arange(0, dft_length) + # complex frequency vector (now normalised to 2 pi) + omega = -1j * 2.0 * np.pi / dft_length * k + # complex phase, compute a matrix of value for the complex phase for each sample + # location (n) and each freq bin (k) outer product If the two vectors have dimensions + # k and n, then their outer product is an k × n matrix + phase = np.outer(omega, n) + # return transposed ready for matrix multiplication + return np.exp(phase).astype(np.complex64).T + + +@tf.function +def _rdft(signal, dft_length): + """DFT for real signals. 
+ Calculates the onesided dft, assuming real signal implies complex conjugate symmetry, + hence only onesided DFT is returned. + + Args: + signal (tensor) signal to transform, assumes that the last dimension is the time dimension + signal can be framed, e.g. (1, 40, 1024) for a single batch of 40 frames of + length 1024 + dft_length (int) - DFT length + + Returns: + spectrogram (float32 tensor) - real and imaginary parts of the spectrogram stacked on a new last axis, + e.g. (1, 40, 513, 2) for a 1024 length dft + """ + # calculate the positive frequency atoms, and tell tensorflow this is a constant. + rdft_mat = _rdft_matrix(dft_length) + + # tflite doesn't support complex types so split into real and imaginary: + rdft_mat_real = tf.constant(np.real(rdft_mat)) + rdft_mat_imag = tf.constant(np.imag(rdft_mat)) + + frame_length = tf.shape(signal)[-1] + # Right-padding, in case the frame length and DFT length are different, + # pad the signal on the right hand side of the frame + pad_values = tf.concat( + [tf.zeros([tf.rank(signal) - 1, 2], tf.int32), [[0, dft_length - frame_length]]], axis=0 + ) + + signal_padded = tf.pad(signal, pad_values) + + # matrix multiplying real and imag separately is faster than using complex types. + spec_real = tf.matmul(signal_padded, rdft_mat_real) + spec_imag = tf.matmul(signal_padded, rdft_mat_imag) + spectrogram = tf.stack([spec_real, spec_imag], axis=-1) + + return spectrogram + + +def fixed_frame(signal, frame_length, frame_step): + """tflite-compatible tf.signal.frame for fixed-size input. + + Args: + signal: Tensor containing signal(s). + frame_length: Number of samples to put in each frame. + frame_step: Sample advance between successive frames. + + Returns: + A new tensor where the last axis (or first, if first_axis) of input + signal has been replaced by a (num_frames, frame_length) array of individual + frames where each frame is drawn frame_step samples after the previous one. 
+ + Raises: + ValueError: if signal has an undefined axis length. This routine only + supports framing of signals whose shape is fixed at graph-build time. + """ + signal_shape = list(signal.shape) + length_samples = signal_shape[-1] + + if length_samples <= 0: + raise ValueError("fixed framing requires predefined constant signal length") + # the number of whole frames + num_frames = max(0, 1 + (length_samples - frame_length) // frame_step) + + # define the output_shape, if we receive a None dimension, replace with 1 + outer_dimensions = [dim if dim else 1 for dim in signal_shape[:-1]] + # outer_dimensions = signal_shape[:-1] + output_shape = outer_dimensions + [num_frames, frame_length] + + # Currently tflite's gather only supports axis==0, but that may still + # work if we want the last of 1 axes. + gather_axis = len(outer_dimensions) + + # subframe length is the largest int that is a common divisor of the frame + # length and hop length. We will slice the signal up into these subframes + # in order to then construct the frames. + subframe_length = math.gcd(frame_length, frame_step) + subframes_per_frame = frame_length // subframe_length + subframes_per_hop = frame_step // subframe_length + num_subframes = length_samples // subframe_length + + # define the subframe shape and the trimmed audio length, removing any unused + # excess audio, so subframes fit exactly. + subframe_shape = outer_dimensions + [num_subframes, subframe_length] + trimmed_input_size = outer_dimensions + [num_subframes * subframe_length] + # slice up the audio into subframes + subframes = tf.reshape( + tf.slice(signal, begin=np.zeros(len(signal_shape), np.int32), size=trimmed_input_size), + subframe_shape, + ) + + # frame_selector is a [num_frames, subframes_per_frame] tensor + # that indexes into the appropriate frame in subframes. 
For example: + # [[0, 0, 0, 0], [2, 2, 2, 2], [4, 4, 4, 4]] + frame_selector = np.reshape(np.arange(num_frames) * subframes_per_hop, [num_frames, 1]) + + # subframe_selector is a [num_frames, subframes_per_frame] tensor + # that indexes into the appropriate subframe within a frame. For example: + # [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]] + subframe_selector = np.reshape(np.arange(subframes_per_frame), [1, subframes_per_frame]) + + # Adding the 2 selector tensors together produces a [num_frames, + # subframes_per_frame] tensor of indices to use with tf.gather to select + # subframes from subframes. We then reshape the inner-most subframes_per_frame + # dimension to stitch the subframes together into frames. For example: + # [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7]]. + selector = frame_selector + subframe_selector + frames = tf.reshape( + tf.gather(subframes, selector.astype(np.int32), axis=gather_axis), output_shape + ) + + return frames + + +def stft_tflite(signal, frame_length, frame_step, fft_length, window_fn, pad_end): + """tflite-compatible implementation of tf.signal.stft. + Compute the short-time Fourier transform of a 1D input while avoiding tf ops + that are not currently supported in tflite (Rfft, Range, SplitV). + fft_length must be fixed. A Hann window is of frame_length is always + applied. + Since fixed (precomputed) framing must be used, signal.shape[-1] must be a + specific value (so "?"/None is not supported). + + Args: + signal: 1D tensor containing the time-domain waveform to be transformed. + frame_length: int, the number of points in each Fourier frame. + frame_step: int, the number of samples to advance between successive frames. + fft_length: int, the size of the Fourier transform to apply. 
+ window_fn: tf.signal.window, the return of backend.get_window_fn(window_name) + pad_end: bool, if true pads the end with zeros so that signal contains + an integer number of frames + + Returns: + spectrogram: Two (num_frames, fft_length) tensors containing the real and + imaginary parts of the short-time Fourier transform of the input signal. + """ + signal = tf.cast(signal, tf.float32) + if pad_end: + # the number of whole frames + length_samples = signal.shape[-1] + num_steps_round_up = tf.math.ceil(length_samples / frame_step) + pad_amount = int((num_steps_round_up * frame_step) - length_samples) + signal = tf.pad(signal, tf.constant([[0, 0], [0, 0], [0, pad_amount]])) + # Make the window be shape (1, frame_length) instead of just frame_length + # in an effort to help the tflite broadcast logic. + window = tf.reshape(window_fn(frame_length), [1, frame_length]) + + framed_signal = fixed_frame(signal, frame_length, frame_step) + framed_signal *= window + + spectrogram = _rdft(framed_signal, fft_length) + + return spectrogram + + +@tf.function +def continued_fraction_arctan(x, n=100, dtype=tf.float32): + """Continued fraction Approximation to the arctan function + + Approximate solution to arctan(x), atan is not a natively supported tflite + op (or a flex op). n is the number of iterations, the high the more accurate. + Accuracy is poor when the argument is large. 
+ https://functions.wolfram.com/ElementaryFunctions/ArcTan/10/ + + Args: + x (tensor) - argument tensor to calculate arctan of + n (int) - The number of iterations, large means arctan is more accurate + dtype (tf.dtype) - tf.float32, or tf.float64 + + Returns + arctan(x) (tensor) - approx value of arctan(x) + """ + x = tf.cast(x, dtype) + x2 = x * x + d = tf.zeros(tf.shape(x), dtype) + tf.cast(n * 2 + 1, dtype) + for k in tf.range(n, 0.0, -1.0, dtype): + f = k * 2.0 - 1.0 + d = f + k * k * x2 / d + return x / d + + +def atan2_tflite(y, x, n=100, dtype=tf.float32): + """Approximation to the atan2 function + + atan is not a tflite supported op or flex op, thus this uses an Approximation + Poor accuracy when either x is very small or y is very large. + https://en.wikipedia.org/wiki/Atan2 + + Args: + y (tensor) - vertical component of tangent (or imaginary part of number for phase) + x (tensor) - horizontal component of tangent (or real part of number for phase) + n (int) - The number of iterations to use for atan approximations, + larger means arctan is more accurate + dtype (tf.dtype) - tf.float32, or tf.float64 + + Returns + atan2(x) (tensor) - approx value of atan2(x) + """ + pi = tf.zeros(tf.shape(x), dtype) + tf.cast(np.pi, dtype) + zeros = tf.zeros(tf.shape(x), dtype) + atan2 = continued_fraction_arctan(y / x, n, dtype) + atan2 = tf.where(x > 0, atan2, atan2) # implicit + atan2 = tf.where(tf.logical_and(x < 0.0, y >= 0.0), atan2 + pi, atan2) + atan2 = tf.where(tf.logical_and(x < 0.0, y < 0.0), atan2 - pi, atan2) + atan2 = tf.where(tf.logical_and(tf.equal(x, 0.0), y > 0.0), pi, atan2) + atan2 = tf.where(tf.logical_and(tf.equal(x, 0.0), y < 0.0), -pi, atan2) + # undefined (return 0) + atan2 = tf.where(tf.logical_and(tf.equal(x, 0.0), tf.equal(y, 0.0)), zeros, atan2) + return atan2 diff --git a/kapre/time_frequency.py b/kapre/time_frequency.py index 2d1c4a7..bec6a81 100644 --- a/kapre/time_frequency.py +++ b/kapre/time_frequency.py @@ -19,11 +19,12 @@ """ import 
tensorflow as tf -from tensorflow.keras.layers import Layer, Conv2D -from . import backend from tensorflow.keras import backend as K -from .backend import _CH_FIRST_STR, _CH_LAST_STR, _CH_DEFAULT_STR +from tensorflow.keras.layers import Layer +from . import backend +from .backend import _CH_FIRST_STR, _CH_LAST_STR, _CH_DEFAULT_STR +from .tflite_compatible_stft import atan2_tflite __all__ = [ 'STFT', @@ -138,10 +139,10 @@ def call(self, x): Return: (complex `Tensor`): A STFT representation of x in a 2D batch shape. - `complex64` if `x` is `float32`. `complex128` if `x` is `float64`. + `complex64` if `x` is `float32`, `complex128` if `x` is `float64`. Its shape is (batch, time, freq, ch) or (batch. ch, time, freq) depending on `output_data_format` and - `time` is the number of frames, which is `((len_src + (win_length - hop_length) / hop_length) // win_length )` if `pad_end` is `True`. - `freq` is the number of fft unique bins, which is `n_fft // 2 + 1` (the unique components of the FFT). + `time` is the number of frames, which is `((len_src + (win_length - hop_length) / hop_length) // win_length )` if `pad_end` is `True`. + `freq` is the number of fft unique bins, which is `n_fft // 2 + 1` (the unique components of the FFT). """ waveforms = x # (batch, ch, time) if input_data_format == 'channels_first'. # (batch, time, ch) if input_data_format == 'channels_last'. @@ -335,6 +336,17 @@ def call(self, x): class Phase(Layer): """Compute the phase of the complex input in radian, resulting in a float tensor + Includes option to use approximate phase algorithm this will return the same + results as the PhaseTflite layer (the tflite compatible layer). + + Args: + approx_atan_accuracy (`int`): if `None` will use tf.math.angle() to + calculate the phase accurately. If an `int` this is the number of + iterations to calculate the approximate atan() using a tflite compatible + method. the higher the number the more accurate e.g. + approx_atan_accuracy=29000. 
You may want to experiment with adjusting + this number: trading off accuracy with inference speed. + Example: :: @@ -343,9 +355,12 @@ class Phase(Layer): model.add(kapre.STFT(n_fft=1024, hop_length=512, input_shape=input_shape)) model.add(Phase()) # now the shape is (batch, n_frame=3, n_freq=513, ch=1) and dtype is float - """ + def __init__(self, approx_atan_accuracy=None, **kwargs): + super(Phase, self).__init__(**kwargs) + self.approx_atan_accuracy = approx_atan_accuracy + def call(self, x): """ Args: @@ -354,8 +369,20 @@ def call(self, x): Returns: (float `Tensor`): phase of `x` (Radian) """ + if self.approx_atan_accuracy: + return atan2_tflite(tf.math.imag(x), tf.math.real(x), n=self.approx_atan_accuracy) + return tf.math.angle(x) + def get_config(self): + config = super(Phase, self).get_config() + config.update( + { + 'tflite_phase_accuracy': self.approx_atan_accuracy, + } + ) + return config + class MagnitudeToDecibel(Layer): """A class that wraps `backend.magnitude_to_decibel` to compute decibel of the input magnitude. 
@@ -401,7 +428,11 @@ def call(self, x): def get_config(self): config = super(MagnitudeToDecibel, self).get_config() config.update( - {'amin': self.amin, 'dynamic_range': self.dynamic_range, 'ref_value': self.ref_value,} + { + 'amin': self.amin, + 'dynamic_range': self.dynamic_range, + 'ref_value': self.ref_value, + } ) return config @@ -439,7 +470,11 @@ class ApplyFilterbank(Layer): """ def __init__( - self, type, filterbank_kwargs, data_format='default', **kwargs, + self, + type, + filterbank_kwargs, + data_format='default', + **kwargs, ): backend.validate_data_format_str(data_format) @@ -659,6 +694,8 @@ def _concat_frequency_map(self, inputs): def get_config(self): config = super(ConcatenateFrequencyMap, self).get_config() config.update( - {'data_format': self.data_format,} + { + 'data_format': self.data_format, + } ) return config diff --git a/kapre/time_frequency_tflite.py b/kapre/time_frequency_tflite.py new file mode 100644 index 0000000..0011796 --- /dev/null +++ b/kapre/time_frequency_tflite.py @@ -0,0 +1,181 @@ +"""Tflite compatible versions of Kapre layers. + +`STFTTflite` is a tflite compatible version of `STFT`. Tflite does not support complex +types, thus real and imaginary parts are returned as an extra (last) dimension. +Ouput shape is now: `(batch, channel, time, re/im)` or `(batch, time, channel, re/im)`. + +Because of the change of dimension, Tflite compatible layers are provided to +process the resulting STFT; `MagnitudeTflite` and `PhaseTflite` are layers that +calculate the magnitude and phase respectively from the output of `STFTTflite`. +""" +import tensorflow as tf +from .backend import _CH_FIRST_STR, _CH_LAST_STR, _CH_DEFAULT_STR +from .tflite_compatible_stft import stft_tflite, atan2_tflite + +# import non-tflite compatible layers to inheret from. 
+from .time_frequency import STFT, InverseSTFT, Magnitude, Phase + + +__all__ = [ + 'STFTTflite', + # 'InverseSTFTTflite', # NOTE (PK): todo + 'MagnitudeTflite', + 'PhaseTflite', +] + + +class STFTTflite(STFT): + """ + A Short-time Fourier transform layer (tflite compatible). + + Uses `stft_tflite` from tflite_compatible_stft.py, this contains a tflite + compatible stft (using a rdft), and `fixed_frame()` to window the audio. + Tflite does not cope with complex types so real and imaginary parts are stored in extra dim. + Output shape is now: (batch, channel, time, freq, re/im) or (batch, time, freq, channel, re/im) + + Additionally, it reshapes the output to be a proper 2D batch. + + If `output_data_format == 'channels_last'`, the output shape is `(batch, time, freq, channel, re/imag)` + If `output_data_format == 'channels_first'`, the output shape is `(batch, channel, time, freq, re/imag)` + + Args: + n_fft (int): Number of FFTs. Defaults to `2048` + win_length (int or None): Window length in sample. Defaults to `n_fft`. + hop_length (int or None): Hop length in sample between analysis windows. Defaults to `n_fft // 4` following Librosa. + window_name (str or None): *Name* of `tf.signal` function that returns a 1D tensor window that is used in analysis. + Defaults to `hann_window` which uses `tf.signal.hann_window`. + Window availability depends on Tensorflow version. More details are at `kapre.backend.get_window()`. + pad_begin (bool): Whether to pad with zeros along time axis (length: win_length - hop_length). Defaults to `False`. + pad_end (bool): Whether to pad with zeros at the finishing end of the signal. + input_data_format (str): the audio data format of input waveform batch. + `'channels_last'` if it's `(batch, time, channels)` and + `'channels_first'` if it's `(batch, channels, time)`. + Defaults to the setting of your Keras configuration. (`tf.keras.backend.image_data_format()`) + output_data_format (str): The data format of output STFT. 
+ `'channels_last'` if you want `(batch, time, frequency, channels)` and + `'channels_first'` if you want `(batch, channels, time, frequency)` + Defaults to the setting of your Keras configuration. (`tf.keras.backend.image_data_format()`) + + **kwargs: Keyword args for the parent keras layer (e.g., `name`) + + Example: + :: + + input_shape = (2048, 1) # mono signal + model = Sequential() # tflite compatible model + model.add(kapre.STFTTflite(n_fft=1024, hop_length=512, input_shape=input_shape)) + # now the shape is (batch, n_frame=3, n_freq=513, ch=1, re/im=2) + # and the dtype is real + + """ + + def call(self, x): + """ + Compute STFT of the input signal. If the `time` axis is not the last axis of `x`, it should be transposed first. + + Args: + x (float `Tensor`): batch of audio signals, (batch, ch, time) or (batch, time, ch) based on input_data_format + + Return: + (real `Tensor`): A STFT representation of x in a 2D batch shape. The last dimension is size two and contains + the real and imaginary parts of the stft. + Its shape is (batch, time, freq, ch, 2) or (batch. ch, time, freq, 2) depending on `output_data_format` and + `time` is the number of frames, which is `((len_src + (win_length - hop_length) / hop_length) // win_length )` + if `pad_end` is `True`. `freq` is the number of fft unique bins, which is `n_fft // 2 + 1` (the unique components of the FFT). + """ + waveforms = x # (batch, ch, time) if input_data_format == 'channels_first'. + # (batch, time, ch) if input_data_format == 'channels_last'. + + # this is needed because tf.signal.stft lives in channels_first land. 
+ if self.input_data_format == _CH_LAST_STR: + waveforms = tf.transpose( + waveforms, perm=(0, 2, 1) + ) # always (batch, ch, time) from here + + if self.pad_begin: + waveforms = tf.pad( + waveforms, tf.constant([[0, 0], [0, 0], [int(self.n_fft - self.hop_length), 0]]) + ) + stfts = stft_tflite( + waveforms, + frame_length=self.win_length, + frame_step=self.hop_length, + fft_length=self.n_fft, + window_fn=self.window_fn, + pad_end=self.pad_end, + ) # (batch, ch, time, freq, re/imag) + + if self.output_data_format == _CH_LAST_STR: + # tflite compatible stft produces real and imag in last dim + stfts = tf.transpose(stfts, perm=(0, 2, 3, 1, 4)) # (batch, t, f, ch, re/im) + + return stfts + + +class MagnitudeTflite(Magnitude): + """Compute the magnitude of the input (tflite compatible). + + The input is a real tensor, the last dimension has a size of `2` + representing real and imaginary parts respectively. + + Example: + :: + + input_shape = (2048, 1) # mono signal + model = Sequential() + model.add(kapre.STFTTflite(n_fft=1024, hop_length=512, input_shape=input_shape)) + model.add(MagnitudeTflite()) + # now the shape is (batch, n_frame=3, n_freq=513, ch=1) and dtype is float + + """ + + def call(self, x): + """ + Args: + x (real or complex `Tensor`): input is real tensor whose last + dimension has a size of `2` and represents real and imaginary + parts + + Returns: + (float `Tensor`): magnitude of `x` + """ + return tf.norm(x, ord='euclidean', axis=-1) + + +class PhaseTflite(Phase): + """Compute the phase of the complex input in radian, resulting in a float tensor (tflite compatible). + + Note TF lite does not natively support atan, used in tf.math.angle, so an + approximation is provided. You may want to use this approximation if you + generate data using a non-tf-lite compatible STFT (faster) but want the same + approximations in the training data. + + Args: + approx_atan_accuracy (`int`): if `None` will use `tf.math.angle()` to + calculate the phase accurately. 
If an `int` this is the number of + iterations to calculate the approximate `atan()` using a tflite compatible + method. the higher the number the more accurate e.g. + `approx_atan_accuracy=29000`. You may want to experiment with adjusting + this number: trading off accuracy with inference speed. + + Example: + :: + + input_shape = (2048, 1) # mono signal + model = Sequential() + model.add(kapre.STFTTflite(n_fft=1024, hop_length=512, input_shape=input_shape)) + model.add(PhaseTflite(approx_atan_accuracy=5000)) + # now the shape is (batch, n_frame=3, n_freq=513, ch=1) and dtype is float + + """ + + def call(self, x): + """ + Args: + x (real): input is real tensor with five + dimensions (last dim is re/imag) + + Returns: + (float `Tensor`): phase of `x` (Radian) + """ + return atan2_tflite(x[:, :, :, :, 1], x[:, :, :, :, 0], n=self.approx_atan_accuracy) diff --git a/setup.py b/setup.py index 5aba284..c2acad3 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name='kapre', - version='0.3.4', + version='0.3.5', description='Kapre: Keras Audio Preprocessors. 
Tensorflow.Keras layers for audio pre-processing in deep learning', author='Keunwoo Choi', url='http://github.com/keunwoochoi/kapre/', @@ -18,7 +18,7 @@ install_requires=[ 'numpy >= 1.18.5', 'librosa >= 0.7.2', - 'tensorflow >= 2.0', + 'tensorflow >= 2.0.0' ], keywords='audio music speech sound deep learning keras tensorflow', zip_safe=False, diff --git a/tests/test_augmentation.py b/tests/test_augmentation.py index a4f033b..2fc7235 100644 --- a/tests/test_augmentation.py +++ b/tests/test_augmentation.py @@ -15,7 +15,11 @@ def test_channel_swap_correctness(n_ch, data_format, data_type): src_mono, batch_src, input_shape = get_audio(data_format=data_format, n_ch=n_ch, length=len_src) model = tf.keras.Sequential() - model.add(ChannelSwap(input_shape=input_shape,)) + model.add( + ChannelSwap( + input_shape=input_shape, + ) + ) # consistent during inference kapre_ref = model.predict(batch_src) for _ in range(100): diff --git a/tests/test_backend.py b/tests/test_backend.py index b3c027c..4967765 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -12,10 +12,13 @@ @pytest.mark.parametrize('dynamic_range', [80.0, 120.0]) -def test_magnitude_to_decibel(dynamic_range): +@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64']) +def test_magnitude_to_decibel(dynamic_range, dtype: str): """test for backend_keras.magnitude_to_decibel""" - x = np.array([[1e-20, 1e-5, 1e-3, 5e-2], [0.3, 1.0, 20.5, 9999]]) # random positive numbers + x = np.array( + [[1e-20, 1e-5, 1e-3, 5e-2], [0.3, 1.0, 20.5, 9999]], dtype=dtype + ) # random positive numbers amin = 1e-5 x_decibel_ref = np.stack( @@ -30,8 +33,10 @@ def test_magnitude_to_decibel(dynamic_range): x_decibel_kapre = magnitude_to_decibel( x_var, ref_value=1.0, amin=amin, dynamic_range=dynamic_range ) - - np.testing.assert_allclose(K.eval(x_decibel_kapre), x_decibel_ref, atol=TOL) + if dtype == 'float16': + np.testing.assert_allclose(K.eval(x_decibel_kapre), x_decibel_ref, rtol=1e-3, atol=TOL) + else: + 
np.testing.assert_allclose(K.eval(x_decibel_kapre), x_decibel_ref, atol=TOL) @pytest.mark.parametrize('sample_rate', [44100, 22050]) diff --git a/tests/test_time_frequency.py b/tests/test_time_frequency.py index 6c08c1a..7b1d6a4 100644 --- a/tests/test_time_frequency.py +++ b/tests/test_time_frequency.py @@ -11,6 +11,9 @@ InverseSTFT, ApplyFilterbank, ConcatenateFrequencyMap, + STFTTflite, + MagnitudeTflite, + PhaseTflite, ) from kapre.composed import ( get_melspectrogram_layer, @@ -21,7 +24,7 @@ get_frequency_aware_conv2d, ) -from utils import get_audio, save_load_compare +from utils import get_audio, save_load_compare, predict_using_tflite def _num_frame_valid(nsp_src, nsp_win, len_hop): @@ -46,6 +49,17 @@ def allclose_phase(a, b, atol=1e-3): np.testing.assert_allclose(np.cos(a), np.cos(b), atol=atol) +def assert_approx_phase(a, b, atol=1e-2, acceptable_fail_ratio=0.01): + """Testing approximate phase. + Tflite phase is approximate, some values will always have a large error + So makes more sense to count the number that are within tolerance + """ + count_failed = np.sum(np.abs(a - b) > atol) + assert ( + count_failed / a.size < acceptable_fail_ratio + ), "too many inaccuracte phase bins: {} bins out of {} incorrect".format(count_failed, a.size) + + def allclose_complex_numbers(a, b, atol=1e-3): np.testing.assert_equal(np.shape(a), np.shape(b)) np.testing.assert_allclose(np.abs(a), np.abs(b), rtol=1e-5, atol=atol) @@ -57,7 +71,8 @@ def allclose_complex_numbers(a, b, atol=1e-3): @pytest.mark.parametrize('hop_length', [None, 256]) @pytest.mark.parametrize('n_ch', [1, 2, 6]) @pytest.mark.parametrize('data_format', ['default', 'channels_first', 'channels_last']) -def test_spectrogram_correctness(n_fft, hop_length, n_ch, data_format): +@pytest.mark.parametrize('batch_size', [1, 10]) +def test_spectrogram_correctness(n_fft, hop_length, n_ch, data_format, batch_size): def _get_stft_model(following_layer=None): # compute with kapre stft_model = 
tensorflow.keras.models.Sequential() @@ -78,7 +93,9 @@ def _get_stft_model(following_layer=None): stft_model.add(following_layer) return stft_model - src_mono, batch_src, input_shape = get_audio(data_format=data_format, n_ch=n_ch) + src_mono, batch_src, input_shape = get_audio( + data_format=data_format, n_ch=n_ch, batch_size=batch_size + ) win_length = n_fft # test with x2 # compute with librosa S_ref = librosa.core.stft( @@ -246,6 +263,96 @@ def _get_melgram_model(return_decibel, amin, dynamic_range, input_shape=None): ) # decibel is evaluated with relative tolerance +@pytest.mark.parametrize('n_fft', [1000]) +@pytest.mark.parametrize('hop_length', [None, 256]) +@pytest.mark.parametrize('n_ch', [1, 2]) +@pytest.mark.parametrize('data_format', ['default', 'channels_first', 'channels_last']) +@pytest.mark.parametrize('batch_size', [1, 2]) +@pytest.mark.parametrize('win_length', [1000, 512]) +def test_spectrogram_tflite_correctness( + n_fft, hop_length, n_ch, data_format, batch_size, win_length +): + def _get_stft_model(following_layer=None, tflite_compatible=False): + # compute with kapre + stft_model = tensorflow.keras.models.Sequential() + if tflite_compatible: + stft_model.add( + STFTTflite( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_name=None, + pad_end=False, + input_data_format=data_format, + output_data_format=data_format, + input_shape=input_shape, + name='stft', + ) + ) + else: + stft_model.add( + STFT( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_name=None, + pad_end=False, + input_data_format=data_format, + output_data_format=data_format, + input_shape=input_shape, + name='stft', + ) + ) + if following_layer is not None: + stft_model.add(following_layer) + return stft_model + + src_mono, batch_src, input_shape = get_audio( + data_format=data_format, n_ch=n_ch, batch_size=batch_size + ) + # tflite requires a known batch size + batch_size = batch_src.shape[0] + + stft_model_tflite = 
_get_stft_model(tflite_compatible=True) + stft_model = _get_stft_model(tflite_compatible=False) + + # test STFT() + S_complex_tflite = predict_using_tflite(stft_model_tflite, batch_src) # predict using tflite + # (batch, time, freq, chan, re/imag) - convert to complex number: + S_complex_tflite = tf.complex( + S_complex_tflite[..., 0], S_complex_tflite[..., 1] + ) # (batch,time,freq,chan) + S_complex = stft_model.predict(batch_src) # predict using tf model + allclose_complex_numbers(S_complex, S_complex_tflite) + + # test Magnitude() + stft_mag_model_tflite = _get_stft_model(MagnitudeTflite(), tflite_compatible=True) + stft_mag_model = _get_stft_model(Magnitude(), tflite_compatible=False) + S_lite = predict_using_tflite(stft_mag_model_tflite, batch_src) # predict using tflite + S = stft_mag_model.predict(batch_src) # predict using tf model + np.testing.assert_allclose(S, S_lite, atol=1e-4) + + # # test approx Phase() same for tflite and non-tflite + stft_approx_phase_model_lite = _get_stft_model( + PhaseTflite(approx_atan_accuracy=500), tflite_compatible=True + ) + stft_approx_phase_model = _get_stft_model( + Phase(approx_atan_accuracy=500), tflite_compatible=False + ) + S_approx_phase_lite = predict_using_tflite( + stft_approx_phase_model_lite, batch_src + ) # predict using tflite + S_approx_phase = stft_approx_phase_model.predict( + batch_src, batch_size=batch_size + ) # predict using tf model + assert_approx_phase(S_approx_phase_lite, S_approx_phase, atol=1e-2, acceptable_fail_ratio=0.01) + + # # test accuracy of approx Phase() + stft_phase_model = _get_stft_model(Phase(), tflite_compatible=False) + S_phase = stft_phase_model.predict(batch_src, batch_size=batch_size) # predict using tf model + assert_approx_phase(S_approx_phase_lite, S_phase, atol=1e-2, acceptable_fail_ratio=0.01) + + @pytest.mark.parametrize('data_format', ['default', 'channels_first', 'channels_last']) def test_log_spectrogram_runnable(data_format): """test if log spectrogram layer works 
well""" @@ -298,7 +405,11 @@ def test_mag_phase(data_format): mag_phase_ref = np.stack( librosa.magphase( librosa.stft( - src_mono, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False, + src_mono, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + center=False, ).T ), axis=ch_axis, @@ -306,8 +417,20 @@ def test_mag_phase(data_format): np.testing.assert_equal(mag_phase_kapre.shape, mag_phase_ref.shape) # magnitude test np.testing.assert_allclose( - np.take(mag_phase_kapre, [0,], axis=ch_axis,), - np.take(mag_phase_ref, [0,], axis=ch_axis,), + np.take( + mag_phase_kapre, + [ + 0, + ], + axis=ch_axis, + ), + np.take( + mag_phase_ref, + [ + 0, + ], + axis=ch_axis, + ), atol=2e-4, ) # phase test - todo - yeah.. diff --git a/tests/utils.py b/tests/utils.py index f922a76..577d541 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,6 @@ import os +from pathlib import Path +import shutil import numpy as np import tensorflow as tf import tempfile @@ -7,7 +9,7 @@ SRC = np.load('tests/speech_test_file.npz')['audio_data'].astype(np.float32) -def get_audio(data_format, n_ch, length=8000): +def get_audio(data_format, n_ch, length=8000, batch_size=1): src = SRC src = src[:length] src_mono = src.copy() @@ -26,7 +28,8 @@ def get_audio(data_format, n_ch, length=8000): src = np.transpose(src) # (ch, time) input_shape = (n_ch, len_src) - batch_src = np.expand_dims(src, axis=0) # 3d batch input + # batch_src = np.expand_dims(src, axis=0) # 3d batch input + batch_src = np.repeat([src], batch_size, axis=0) return src_mono, batch_src, input_shape @@ -69,3 +72,57 @@ def save_load_compare( model_temp_dir.cleanup() return model + + +def predict_using_tflite(model, batch_src): + """Convert a keras model to tflite and infer on batch_src + + Attempts to convert a keras model to a tflite model, load the tflite model, + then infer on the data in batch_src + Args: + model (keras model) + batch_src (numpy array) - audio to test model + Returns: + 
pred_tflite (numpy array) - array of predictions. + """ + ############################################################################ + # TF lite conversion + converter = tf.lite.TFLiteConverter.from_keras_model(model) + converter.target_spec.supported_ops = [ + tf.lite.OpsSet.SELECT_TF_OPS, + tf.lite.OpsSet.TFLITE_BUILTINS, + ] + tflite_model = converter.convert() + model_name = 'test_tflite' + path = Path("/tmp/tflite_tests/") + # make a temporary location + if path.exists(): + shutil.rmtree(path) + os.makedirs(path) + tflite_file = path / Path(model_name + ".tflite") + open(tflite_file.as_posix(), "wb").write(tflite_model) + + ############################################################################ + # Make sure we can load and infer on the TFLITE model + interpreter = tf.lite.Interpreter(tflite_file.as_posix()) + # infer on each input separately and collect the predictions + pred_tflite = [] + + for x in batch_src: + + # set batch size for tflite + interpreter.allocate_tensors() + + # Get input and output tensors. + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + # apply input tensors, expand first dimension to create batch dimension + interpreter.set_tensor(input_details[0]["index"], np.expand_dims(x, 0)) + # infer + interpreter.invoke() + tflite_results = interpreter.get_tensor(output_details[0]["index"]) + + pred_tflite.append(tflite_results) + + return np.concatenate(pred_tflite, axis=0)