From cc06786d90893fda03830dffb6dfadc523f3132e Mon Sep 17 00:00:00 2001 From: fchollet Date: Sat, 1 Oct 2016 00:04:07 -0700 Subject: [PATCH] Add music tagger model --- README.md | 6 +- audio_conv_utils.py | 86 +++++++++++++++++++++++ music_tagger_crnn.py | 162 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 251 insertions(+), 3 deletions(-) create mode 100644 audio_conv_utils.py create mode 100644 music_tagger_crnn.py diff --git a/README.md b/README.md index ad89ce8..464377e 100644 --- a/README.md +++ b/README.md @@ -6,12 +6,11 @@ This repository contains code for the following Keras models: - VGG19 - ResNet50 - Inception v3 +- CRNN for music tagging All architectures are compatible with both TensorFlow and Theano, and upon instantiation the models will be built according to the image dimension ordering set in your Keras configuration file at `~/.keras/keras.json`. For instance, if you have set `image_dim_ordering=tf`, then any model loaded from this repository will get built according to the TensorFlow dimension ordering convention, "Width-Height-Depth". -Weights can be automatically loaded upon instantiation (`weights='imagenet'` argument in model constructor). Weights are automatically downloaded if necessary, and cached locally in `~/.keras/models/`. - -**Note that using these models requires the latest version of Keras (from the Github repo, not PyPI).** +Pre-trained weights can be automatically loaded upon instantiation (`weights='imagenet'` argument in model constructor for all image models, `weights='msd'` for the music tagging model). Weights are automatically downloaded if necessary, and cached locally in `~/.keras/models/`. ## Examples @@ -78,6 +77,7 @@ block4_pool_features = model.predict(x) - [Very Deep Convolutional Networks for Large-Scale Image Recognition](https://arxiv.org/abs/1409.1556) - please cite this paper if you use the VGG models in your work. - [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) - please cite this paper if you use the ResNet model in your work. - [Rethinking the Inception Architecture for Computer Vision](http://arxiv.org/abs/1512.00567) - please cite this paper if you use the Inception v3 model in your work. +- [Music-auto_tagging-keras](https://github.com/keunwoochoi/music-auto_tagging-keras) Additionally, don't forget to [cite Keras](https://keras.io/getting-started/faq/#how-should-i-cite-keras) if you use these models. diff --git a/audio_conv_utils.py b/audio_conv_utils.py new file mode 100644 index 0000000..2e4d93d --- /dev/null +++ b/audio_conv_utils.py @@ -0,0 +1,86 @@ +import numpy as np +from keras import backend as K + + +TAGS = ['rock', 'pop', 'alternative', 'indie', 'electronic', + 'female vocalists', 'dance', '00s', 'alternative rock', 'jazz', + 'beautiful', 'metal', 'chillout', 'male vocalists', + 'classic rock', 'soul', 'indie rock', 'Mellow', 'electronica', + '80s', 'folk', '90s', 'chill', 'instrumental', 'punk', + 'oldies', 'blues', 'hard rock', 'ambient', 'acoustic', + 'experimental', 'female vocalist', 'guitar', 'Hip-Hop', + '70s', 'party', 'country', 'easy listening', + 'sexy', 'catchy', 'funk', 'electro', 'heavy metal', + 'Progressive rock', '60s', 'rnb', 'indie pop', + 'sad', 'House', 'happy'] + + +def librosa_exists(): + try: + __import__('librosa') + except ImportError: + return False + else: + return True + + +def preprocess_input(audio_path, dim_ordering='default'): + '''Reads an audio file and outputs a Mel-spectrogram. + ''' + if dim_ordering == 'default': + dim_ordering = K.image_dim_ordering() + assert dim_ordering in {'tf', 'th'} + + if librosa_exists(): + import librosa + else: + raise RuntimeError('Librosa is required to process audio files.\n' + + 'Install it via `pip install librosa` \nor visit ' + + 'http://librosa.github.io/librosa/ for details.') + + # mel-spectrogram parameters + SR = 12000 + N_FFT = 512 + N_MELS = 96 + HOP_LEN = 256 + DURA = 29.12 + + src, sr = librosa.load(audio_path, sr=SR) + n_sample = src.shape[0] + n_sample_wanted = int(DURA * SR) + + # trim the signal at the center + if n_sample < n_sample_wanted: # if too short + src = np.hstack((src, np.zeros((int(DURA * SR) - n_sample,)))) + elif n_sample > n_sample_wanted: # if too long + src = src[(n_sample - n_sample_wanted) / 2: + (n_sample + n_sample_wanted) / 2] + + logam = librosa.logamplitude + melgram = librosa.feature.melspectrogram + x = logam(melgram(y=src, sr=SR, hop_length=HOP_LEN, + n_fft=N_FFT, n_mels=N_MELS) ** 2, + ref_power=1.0) + + if dim_ordering == 'th': + x = np.expand_dims(x, axis=0) + elif dim_ordering == 'tf': + x = np.expand_dims(x, axis=3) + return x + + +def decode_predictions(preds, top_n=5): + '''Decode the output of a music tagger model. + + # Arguments + preds: 2-dimensional numpy array + top_n: integer in [0, 50], number of items to show + + ''' + assert len(preds.shape) == 2 and preds.shape[1] == 50 + results = [] + for pred in preds: + result = zip(TAGS, pred) + result = sorted(result, key=lambda x: x[1], reverse=True) + results.append(result[:top_n]) + return results diff --git a/music_tagger_crnn.py b/music_tagger_crnn.py new file mode 100644 index 0000000..3f465e2 --- /dev/null +++ b/music_tagger_crnn.py @@ -0,0 +1,162 @@ +# -*- coding: utf-8 -*- +'''MusicTaggerCRNN model for Keras. + +Code by github.com/keunwoochoi. + +# Reference: + +- [Music-auto_tagging-keras](https://github.com/keunwoochoi/music-auto_tagging-keras) + +''' +from __future__ import print_function +from __future__ import absolute_import + +import numpy as np +from keras import backend as K +from keras.layers import Input, Dense +from keras.models import Model +from keras.layers import Dense, Dropout, Reshape, Permute +from keras.layers.convolutional import Convolution2D +from keras.layers.convolutional import MaxPooling2D, ZeroPadding2D +from keras.layers.normalization import BatchNormalization +from keras.layers.advanced_activations import ELU +from keras.layers.recurrent import GRU +from keras.utils.data_utils import get_file +from keras.utils.layer_utils import convert_all_kernels_in_model +from audio_conv_utils import decode_predictions, preprocess_input + +TH_WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.3/music_tagger_crnn_weights_tf_kernels_th_dim_ordering.h5' +TF_WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.3/music_tagger_crnn_weights_tf_kernels_tf_dim_ordering.h5' + + +def MusicTaggerCRNN(weights='msd', input_tensor=None, + include_top=True): + '''Instantiate the MusicTaggerCRNN architecture, + optionally loading weights pre-trained + on Million Song Dataset. Note that when using TensorFlow, + for best performance you should set + `image_dim_ordering="tf"` in your Keras config + at ~/.keras/keras.json. + + The model and the weights are compatible with both + TensorFlow and Theano. The dimension ordering + convention used by the model is the one + specified in your Keras config file. + + For preparing mel-spectrogram input, see + `audio_conv_utils.py` in [applications](https://github.com/fchollet/keras/tree/master/keras/applications). + You will need to install [Librosa](http://librosa.github.io/librosa/) + to use it. + + # Arguments + weights: one of `None` (random initialization) + or "msd" (pre-training on ImageNet). + input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) + to use as image input for the model. + include_top: whether to include the 1 fully-connected + layer (output layer) at the top of the network. + If False, the network outputs 32-dim features. + + + # Returns + A Keras model instance. + ''' + if weights not in {'msd', None}: + raise ValueError('The `weights` argument should be either ' + '`None` (random initialization) or `msd` ' + '(pre-training on Million Song Dataset).') + + # Determine proper input shape + if K.image_dim_ordering() == 'th': + input_shape = (1, 96, 1366) + else: + input_shape = (96, 1366, 1) + + if input_tensor is None: + melgram_input = Input(shape=input_shape) + else: + if not K.is_keras_tensor(input_tensor): + melgram_input = Input(tensor=input_tensor, shape=input_shape) + else: + melgram_input = input_tensor + + # Determine input axis + if K.image_dim_ordering() == 'th': + channel_axis = 1 + freq_axis = 2 + time_axis = 3 + else: + channel_axis = 3 + freq_axis = 1 + time_axis = 2 + + # Input block + x = ZeroPadding2D(padding=(0, 37))(melgram_input) + x = BatchNormalization(axis=time_axis, name='bn_0_freq')(x) + + # Conv block 1 + x = Convolution2D(64, 3, 3, border_mode='same', name='conv1')(x) + x = BatchNormalization(axis=channel_axis, mode=0, name='bn1')(x) + x = ELU()(x) + x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='pool1')(x) + + # Conv block 2 + x = Convolution2D(128, 3, 3, border_mode='same', name='conv2')(x) + x = BatchNormalization(axis=channel_axis, mode=0, name='bn2')(x) + x = ELU()(x) + x = MaxPooling2D(pool_size=(3, 3), strides=(3, 3), name='pool2')(x) + + # Conv block 3 + x = Convolution2D(128, 3, 3, border_mode='same', name='conv3')(x) + x = BatchNormalization(axis=channel_axis, mode=0, name='bn3')(x) + x = ELU()(x) + x = MaxPooling2D(pool_size=(4, 4), strides=(4, 4), name='pool3')(x) + + # Conv block 4 + x = Convolution2D(128, 3, 3, border_mode='same', name='conv4')(x) + x = BatchNormalization(axis=channel_axis, mode=0, name='bn4')(x) + x = ELU()(x) + x = MaxPooling2D(pool_size=(4, 4), strides=(4, 4), name='pool4')(x) + + # reshaping + if K.image_dim_ordering() == 'th': + x = Permute((3, 1, 2))(x) + x = Reshape((15, 128))(x) + + # GRU block 1, 2, output + x = GRU(32, return_sequences=True, name='gru1')(x) + x = GRU(32, return_sequences=False, name='gru2')(x) + + if include_top: + x = Dense(50, activation='sigmoid', name='output')(x) + + # Create model + model = Model(melgram_input, x) + if weights is None: + return model + else: + # Load weights + if K.image_dim_ordering() == 'tf': + weights_path = get_file('music_tagger_crnn_weights_tf_kernels_tf_dim_ordering.h5', + TF_WEIGHTS_PATH, + cache_subdir='models') + else: + weights_path = get_file('music_tagger_crnn_weights_tf_kernels_th_dim_ordering.h5', + TH_WEIGHTS_PATH, + cache_subdir='models') + model.load_weights(weights_path, by_name=True) + if K.backend() == 'theano': + convert_all_kernels_in_model(model) + return model + + +if __name__ == '__main__': + model = MusicTaggerCRNN(weights='msd') + + audio_path = 'audio_file.mp3' + melgram = preprocess_input(audio_path) + melgrams = np.expand_dims(melgram, axis=0) + + preds = model.predict(melgrams) + print('Predicted:') + print(decode_predictions(preds))