Add music tagger model
fchollet committed Oct 1, 2016
1 parent 170595b commit cc06786
Showing 3 changed files with 251 additions and 3 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -6,12 +6,11 @@ This repository contains code for the following Keras models:
- VGG19
- ResNet50
- Inception v3
- CRNN for music tagging

All architectures are compatible with both TensorFlow and Theano, and upon instantiation the models will be built according to the image dimension ordering set in your Keras configuration file at `~/.keras/keras.json`. For instance, if you have set `image_dim_ordering=tf`, then any model loaded from this repository will get built according to the TensorFlow dimension ordering convention, "Width-Height-Depth".
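
For instance, a minimal `~/.keras/keras.json` selecting the TensorFlow ordering might look like the following sketch (the other field values shown are the Keras 1.x defaults):

```json
{
    "image_dim_ordering": "tf",
    "epsilon": 1e-07,
    "floatx": "float32",
    "backend": "tensorflow"
}
```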

Weights can be automatically loaded upon instantiation (`weights='imagenet'` argument in model constructor). Weights are automatically downloaded if necessary, and cached locally in `~/.keras/models/`.

**Note that using these models requires the latest version of Keras (from the Github repo, not PyPI).**
Pre-trained weights can be automatically loaded upon instantiation (`weights='imagenet'` argument in model constructor for all image models, `weights='msd'` for the music tagging model). Weights are automatically downloaded if necessary, and cached locally in `~/.keras/models/`.
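
For example, the following sketch loads the music tagger with its pre-trained weights (assuming `music_tagger_crnn.py` from this commit is importable):

```python
from music_tagger_crnn import MusicTaggerCRNN

model = MusicTaggerCRNN(weights='msd')  # weights are downloaded to ~/.keras/models/ on first use
```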

## Examples

@@ -78,6 +77,7 @@ block4_pool_features = model.predict(x)
- [Very Deep Convolutional Networks for Large-Scale Image Recognition](https://arxiv.org/abs/1409.1556) - please cite this paper if you use the VGG models in your work.
- [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) - please cite this paper if you use the ResNet model in your work.
- [Rethinking the Inception Architecture for Computer Vision](http://arxiv.org/abs/1512.00567) - please cite this paper if you use the Inception v3 model in your work.
- [Music-auto_tagging-keras](https://github.com/keunwoochoi/music-auto_tagging-keras)

Additionally, don't forget to [cite Keras](https://keras.io/getting-started/faq/#how-should-i-cite-keras) if you use these models.

86 changes: 86 additions & 0 deletions audio_conv_utils.py
@@ -0,0 +1,86 @@
import numpy as np
from keras import backend as K


TAGS = ['rock', 'pop', 'alternative', 'indie', 'electronic',
        'female vocalists', 'dance', '00s', 'alternative rock', 'jazz',
        'beautiful', 'metal', 'chillout', 'male vocalists',
        'classic rock', 'soul', 'indie rock', 'Mellow', 'electronica',
        '80s', 'folk', '90s', 'chill', 'instrumental', 'punk',
        'oldies', 'blues', 'hard rock', 'ambient', 'acoustic',
        'experimental', 'female vocalist', 'guitar', 'Hip-Hop',
        '70s', 'party', 'country', 'easy listening',
        'sexy', 'catchy', 'funk', 'electro', 'heavy metal',
        'Progressive rock', '60s', 'rnb', 'indie pop',
        'sad', 'House', 'happy']


def librosa_exists():
    try:
        __import__('librosa')
    except ImportError:
        return False
    else:
        return True


def preprocess_input(audio_path, dim_ordering='default'):
    '''Reads an audio file and outputs a Mel-spectrogram.
    '''
    if dim_ordering == 'default':
        dim_ordering = K.image_dim_ordering()
    assert dim_ordering in {'tf', 'th'}

    if librosa_exists():
        import librosa
    else:
        raise RuntimeError('Librosa is required to process audio files.\n' +
                           'Install it via `pip install librosa` \nor visit ' +
                           'http://librosa.github.io/librosa/ for details.')

    # mel-spectrogram parameters
    SR = 12000
    N_FFT = 512
    N_MELS = 96
    HOP_LEN = 256
    DURA = 29.12  # duration in seconds; 29.12 s at SR=12000 with HOP_LEN=256 yields 1366 frames

    src, sr = librosa.load(audio_path, sr=SR)
    n_sample = src.shape[0]
    n_sample_wanted = int(DURA * SR)

    # trim the signal at the center
    if n_sample < n_sample_wanted:  # if too short, zero-pad at the end
        src = np.hstack((src, np.zeros((int(DURA * SR) - n_sample,))))
    elif n_sample > n_sample_wanted:  # if too long, crop the center
        src = src[(n_sample - n_sample_wanted) // 2:
                  (n_sample + n_sample_wanted) // 2]

    logam = librosa.logamplitude
    melgram = librosa.feature.melspectrogram
    x = logam(melgram(y=src, sr=SR, hop_length=HOP_LEN,
                      n_fft=N_FFT, n_mels=N_MELS) ** 2,
              ref_power=1.0)

    # add a channel axis matching the dimension ordering
    if dim_ordering == 'th':
        x = np.expand_dims(x, axis=0)   # (1, 96, 1366)
    elif dim_ordering == 'tf':
        x = np.expand_dims(x, axis=-1)  # (96, 1366, 1)
    return x


def decode_predictions(preds, top_n=5):
    '''Decode the output of a music tagger model.

    # Arguments
        preds: 2-dimensional numpy array
        top_n: integer in [0, 50], number of items to show
    '''
    assert len(preds.shape) == 2 and preds.shape[1] == 50
    results = []
    for pred in preds:
        result = zip(TAGS, pred)
        result = sorted(result, key=lambda x: x[1], reverse=True)
        results.append(result[:top_n])
    return results
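
For reference, a minimal sketch of how `decode_predictions` consumes tagger output; the random scores below are illustrative stand-ins, not real model predictions:

```python
import numpy as np
from audio_conv_utils import decode_predictions

preds = np.random.rand(1, 50)  # stand-in for model.predict() output of shape (batch, 50)
for tag, score in decode_predictions(preds, top_n=5)[0]:
    print('%s: %.3f' % (tag, score))
```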
162 changes: 162 additions & 0 deletions music_tagger_crnn.py
@@ -0,0 +1,162 @@
# -*- coding: utf-8 -*-
'''MusicTaggerCRNN model for Keras.

Code by github.com/keunwoochoi.

# Reference:

- [Music-auto_tagging-keras](https://github.com/keunwoochoi/music-auto_tagging-keras)

'''
from __future__ import print_function
from __future__ import absolute_import

import numpy as np
from keras import backend as K
from keras.layers import Input, Dense, Dropout, Reshape, Permute
from keras.models import Model
from keras.layers.convolutional import Convolution2D
from keras.layers.convolutional import MaxPooling2D, ZeroPadding2D
from keras.layers.normalization import BatchNormalization
from keras.layers.advanced_activations import ELU
from keras.layers.recurrent import GRU
from keras.utils.data_utils import get_file
from keras.utils.layer_utils import convert_all_kernels_in_model
from audio_conv_utils import decode_predictions, preprocess_input

TH_WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.3/music_tagger_crnn_weights_tf_kernels_th_dim_ordering.h5'
TF_WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.3/music_tagger_crnn_weights_tf_kernels_tf_dim_ordering.h5'


def MusicTaggerCRNN(weights='msd', input_tensor=None,
                    include_top=True):
    '''Instantiate the MusicTaggerCRNN architecture,
    optionally loading weights pre-trained
    on the Million Song Dataset. Note that when using TensorFlow,
    for best performance you should set
    `image_dim_ordering="tf"` in your Keras config
    at ~/.keras/keras.json.

    The model and the weights are compatible with both
    TensorFlow and Theano. The dimension ordering
    convention used by the model is the one
    specified in your Keras config file.

    For preparing mel-spectrogram input, see
    `audio_conv_utils.py` in [applications](https://github.com/fchollet/keras/tree/master/keras/applications).
    You will need to install [Librosa](http://librosa.github.io/librosa/)
    to use it.

    # Arguments
        weights: one of `None` (random initialization)
            or "msd" (pre-training on the Million Song Dataset).
        input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
            to use as input for the model.
        include_top: whether to include the 1 fully-connected
            layer (output layer) at the top of the network.
            If False, the network outputs 32-dim features.

    # Returns
        A Keras model instance.
    '''
    if weights not in {'msd', None}:
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization) or `msd` '
                         '(pre-training on Million Song Dataset).')

    # Determine proper input shape
    if K.image_dim_ordering() == 'th':
        input_shape = (1, 96, 1366)
    else:
        input_shape = (96, 1366, 1)

    if input_tensor is None:
        melgram_input = Input(shape=input_shape)
    else:
        if not K.is_keras_tensor(input_tensor):
            melgram_input = Input(tensor=input_tensor, shape=input_shape)
        else:
            melgram_input = input_tensor

    # Determine input axis
    if K.image_dim_ordering() == 'th':
        channel_axis = 1
        freq_axis = 2
        time_axis = 3
    else:
        channel_axis = 3
        freq_axis = 1
        time_axis = 2

    # Input block
    x = ZeroPadding2D(padding=(0, 37))(melgram_input)
    x = BatchNormalization(axis=time_axis, name='bn_0_freq')(x)

    # Conv block 1
    x = Convolution2D(64, 3, 3, border_mode='same', name='conv1')(x)
    x = BatchNormalization(axis=channel_axis, mode=0, name='bn1')(x)
    x = ELU()(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='pool1')(x)

    # Conv block 2
    x = Convolution2D(128, 3, 3, border_mode='same', name='conv2')(x)
    x = BatchNormalization(axis=channel_axis, mode=0, name='bn2')(x)
    x = ELU()(x)
    x = MaxPooling2D(pool_size=(3, 3), strides=(3, 3), name='pool2')(x)

    # Conv block 3
    x = Convolution2D(128, 3, 3, border_mode='same', name='conv3')(x)
    x = BatchNormalization(axis=channel_axis, mode=0, name='bn3')(x)
    x = ELU()(x)
    x = MaxPooling2D(pool_size=(4, 4), strides=(4, 4), name='pool3')(x)

    # Conv block 4
    x = Convolution2D(128, 3, 3, border_mode='same', name='conv4')(x)
    x = BatchNormalization(axis=channel_axis, mode=0, name='bn4')(x)
    x = ELU()(x)
    x = MaxPooling2D(pool_size=(4, 4), strides=(4, 4), name='pool4')(x)

    # reshaping: collapse the pooled feature maps into a
    # (time, features) sequence for the recurrent layers
    if K.image_dim_ordering() == 'th':
        x = Permute((3, 1, 2))(x)
    x = Reshape((15, 128))(x)

    # GRU block 1, 2, output
    x = GRU(32, return_sequences=True, name='gru1')(x)
    x = GRU(32, return_sequences=False, name='gru2')(x)

    if include_top:
        x = Dense(50, activation='sigmoid', name='output')(x)

    # Create model
    model = Model(melgram_input, x)
    if weights is None:
        return model
    else:
        # Load weights
        if K.image_dim_ordering() == 'tf':
            weights_path = get_file('music_tagger_crnn_weights_tf_kernels_tf_dim_ordering.h5',
                                    TF_WEIGHTS_PATH,
                                    cache_subdir='models')
        else:
            weights_path = get_file('music_tagger_crnn_weights_tf_kernels_th_dim_ordering.h5',
                                    TH_WEIGHTS_PATH,
                                    cache_subdir='models')
        model.load_weights(weights_path, by_name=True)
        if K.backend() == 'theano':
            convert_all_kernels_in_model(model)
        return model


if __name__ == '__main__':
    model = MusicTaggerCRNN(weights='msd')

    audio_path = 'audio_file.mp3'
    melgram = preprocess_input(audio_path)
    melgrams = np.expand_dims(melgram, axis=0)

    preds = model.predict(melgrams)
    print('Predicted:')
    print(decode_predictions(preds))
