-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' of https://github.com/fchollet/deep-learning-mo…
- Loading branch information
Showing
3 changed files
with
251 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import numpy as np | ||
from keras import backend as K | ||
|
||
|
||
TAGS = ['rock', 'pop', 'alternative', 'indie', 'electronic', | ||
'female vocalists', 'dance', '00s', 'alternative rock', 'jazz', | ||
'beautiful', 'metal', 'chillout', 'male vocalists', | ||
'classic rock', 'soul', 'indie rock', 'Mellow', 'electronica', | ||
'80s', 'folk', '90s', 'chill', 'instrumental', 'punk', | ||
'oldies', 'blues', 'hard rock', 'ambient', 'acoustic', | ||
'experimental', 'female vocalist', 'guitar', 'Hip-Hop', | ||
'70s', 'party', 'country', 'easy listening', | ||
'sexy', 'catchy', 'funk', 'electro', 'heavy metal', | ||
'Progressive rock', '60s', 'rnb', 'indie pop', | ||
'sad', 'House', 'happy'] | ||
|
||
|
||
def librosa_exists(): | ||
try: | ||
__import__('librosa') | ||
except ImportError: | ||
return False | ||
else: | ||
return True | ||
|
||
|
||
def preprocess_input(audio_path, dim_ordering='default'): | ||
'''Reads an audio file and outputs a Mel-spectrogram. | ||
''' | ||
if dim_ordering == 'default': | ||
dim_ordering = K.image_dim_ordering() | ||
assert dim_ordering in {'tf', 'th'} | ||
|
||
if librosa_exists(): | ||
import librosa | ||
else: | ||
raise RuntimeError('Librosa is required to process audio files.\n' + | ||
'Install it via `pip install librosa` \nor visit ' + | ||
'http://librosa.github.io/librosa/ for details.') | ||
|
||
# mel-spectrogram parameters | ||
SR = 12000 | ||
N_FFT = 512 | ||
N_MELS = 96 | ||
HOP_LEN = 256 | ||
DURA = 29.12 | ||
|
||
src, sr = librosa.load(audio_path, sr=SR) | ||
n_sample = src.shape[0] | ||
n_sample_wanted = int(DURA * SR) | ||
|
||
# trim the signal at the center | ||
if n_sample < n_sample_wanted: # if too short | ||
src = np.hstack((src, np.zeros((int(DURA * SR) - n_sample,)))) | ||
elif n_sample > n_sample_wanted: # if too long | ||
src = src[(n_sample - n_sample_wanted) / 2: | ||
(n_sample + n_sample_wanted) / 2] | ||
|
||
logam = librosa.logamplitude | ||
melgram = librosa.feature.melspectrogram | ||
x = logam(melgram(y=src, sr=SR, hop_length=HOP_LEN, | ||
n_fft=N_FFT, n_mels=N_MELS) ** 2, | ||
ref_power=1.0) | ||
|
||
if dim_ordering == 'th': | ||
x = np.expand_dims(x, axis=0) | ||
elif dim_ordering == 'tf': | ||
x = np.expand_dims(x, axis=3) | ||
return x | ||
|
||
|
||
def decode_predictions(preds, top_n=5): | ||
'''Decode the output of a music tagger model. | ||
# Arguments | ||
preds: 2-dimensional numpy array | ||
top_n: integer in [0, 50], number of items to show | ||
''' | ||
assert len(preds.shape) == 2 and preds.shape[1] == 50 | ||
results = [] | ||
for pred in preds: | ||
result = zip(TAGS, pred) | ||
result = sorted(result, key=lambda x: x[1], reverse=True) | ||
results.append(result[:top_n]) | ||
return results |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
# -*- coding: utf-8 -*- | ||
'''MusicTaggerCRNN model for Keras. | ||
Code by github.com/keunwoochoi. | ||
# Reference: | ||
- [Music-auto_tagging-keras](https://github.com/keunwoochoi/music-auto_tagging-keras) | ||
''' | ||
from __future__ import print_function | ||
from __future__ import absolute_import | ||
|
||
import numpy as np | ||
from keras import backend as K | ||
from keras.layers import Input, Dense | ||
from keras.models import Model | ||
from keras.layers import Dense, Dropout, Reshape, Permute | ||
from keras.layers.convolutional import Convolution2D | ||
from keras.layers.convolutional import MaxPooling2D, ZeroPadding2D | ||
from keras.layers.normalization import BatchNormalization | ||
from keras.layers.advanced_activations import ELU | ||
from keras.layers.recurrent import GRU | ||
from keras.utils.data_utils import get_file | ||
from keras.utils.layer_utils import convert_all_kernels_in_model | ||
from audio_conv_utils import decode_predictions, preprocess_input | ||
|
||
TH_WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.3/music_tagger_crnn_weights_tf_kernels_th_dim_ordering.h5' | ||
TF_WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.3/music_tagger_crnn_weights_tf_kernels_tf_dim_ordering.h5' | ||
|
||
|
||
def MusicTaggerCRNN(weights='msd', input_tensor=None, | ||
include_top=True): | ||
'''Instantiate the MusicTaggerCRNN architecture, | ||
optionally loading weights pre-trained | ||
on Million Song Dataset. Note that when using TensorFlow, | ||
for best performance you should set | ||
`image_dim_ordering="tf"` in your Keras config | ||
at ~/.keras/keras.json. | ||
The model and the weights are compatible with both | ||
TensorFlow and Theano. The dimension ordering | ||
convention used by the model is the one | ||
specified in your Keras config file. | ||
For preparing mel-spectrogram input, see | ||
`audio_conv_utils.py` in [applications](https://github.com/fchollet/keras/tree/master/keras/applications). | ||
You will need to install [Librosa](http://librosa.github.io/librosa/) | ||
to use it. | ||
# Arguments | ||
weights: one of `None` (random initialization) | ||
or "msd" (pre-training on ImageNet). | ||
input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) | ||
to use as image input for the model. | ||
include_top: whether to include the 1 fully-connected | ||
layer (output layer) at the top of the network. | ||
If False, the network outputs 32-dim features. | ||
# Returns | ||
A Keras model instance. | ||
''' | ||
if weights not in {'msd', None}: | ||
raise ValueError('The `weights` argument should be either ' | ||
'`None` (random initialization) or `msd` ' | ||
'(pre-training on Million Song Dataset).') | ||
|
||
# Determine proper input shape | ||
if K.image_dim_ordering() == 'th': | ||
input_shape = (1, 96, 1366) | ||
else: | ||
input_shape = (96, 1366, 1) | ||
|
||
if input_tensor is None: | ||
melgram_input = Input(shape=input_shape) | ||
else: | ||
if not K.is_keras_tensor(input_tensor): | ||
melgram_input = Input(tensor=input_tensor, shape=input_shape) | ||
else: | ||
melgram_input = input_tensor | ||
|
||
# Determine input axis | ||
if K.image_dim_ordering() == 'th': | ||
channel_axis = 1 | ||
freq_axis = 2 | ||
time_axis = 3 | ||
else: | ||
channel_axis = 3 | ||
freq_axis = 1 | ||
time_axis = 2 | ||
|
||
# Input block | ||
x = ZeroPadding2D(padding=(0, 37))(melgram_input) | ||
x = BatchNormalization(axis=time_axis, name='bn_0_freq')(x) | ||
|
||
# Conv block 1 | ||
x = Convolution2D(64, 3, 3, border_mode='same', name='conv1')(x) | ||
x = BatchNormalization(axis=channel_axis, mode=0, name='bn1')(x) | ||
x = ELU()(x) | ||
x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='pool1')(x) | ||
|
||
# Conv block 2 | ||
x = Convolution2D(128, 3, 3, border_mode='same', name='conv2')(x) | ||
x = BatchNormalization(axis=channel_axis, mode=0, name='bn2')(x) | ||
x = ELU()(x) | ||
x = MaxPooling2D(pool_size=(3, 3), strides=(3, 3), name='pool2')(x) | ||
|
||
# Conv block 3 | ||
x = Convolution2D(128, 3, 3, border_mode='same', name='conv3')(x) | ||
x = BatchNormalization(axis=channel_axis, mode=0, name='bn3')(x) | ||
x = ELU()(x) | ||
x = MaxPooling2D(pool_size=(4, 4), strides=(4, 4), name='pool3')(x) | ||
|
||
# Conv block 4 | ||
x = Convolution2D(128, 3, 3, border_mode='same', name='conv4')(x) | ||
x = BatchNormalization(axis=channel_axis, mode=0, name='bn4')(x) | ||
x = ELU()(x) | ||
x = MaxPooling2D(pool_size=(4, 4), strides=(4, 4), name='pool4')(x) | ||
|
||
# reshaping | ||
if K.image_dim_ordering() == 'th': | ||
x = Permute((3, 1, 2))(x) | ||
x = Reshape((15, 128))(x) | ||
|
||
# GRU block 1, 2, output | ||
x = GRU(32, return_sequences=True, name='gru1')(x) | ||
x = GRU(32, return_sequences=False, name='gru2')(x) | ||
|
||
if include_top: | ||
x = Dense(50, activation='sigmoid', name='output')(x) | ||
|
||
# Create model | ||
model = Model(melgram_input, x) | ||
if weights is None: | ||
return model | ||
else: | ||
# Load weights | ||
if K.image_dim_ordering() == 'tf': | ||
weights_path = get_file('music_tagger_crnn_weights_tf_kernels_tf_dim_ordering.h5', | ||
TF_WEIGHTS_PATH, | ||
cache_subdir='models') | ||
else: | ||
weights_path = get_file('music_tagger_crnn_weights_tf_kernels_th_dim_ordering.h5', | ||
TH_WEIGHTS_PATH, | ||
cache_subdir='models') | ||
model.load_weights(weights_path, by_name=True) | ||
if K.backend() == 'theano': | ||
convert_all_kernels_in_model(model) | ||
return model | ||
|
||
|
||
if __name__ == '__main__': | ||
model = MusicTaggerCRNN(weights='msd') | ||
|
||
audio_path = 'audio_file.mp3' | ||
melgram = preprocess_input(audio_path) | ||
melgrams = np.expand_dims(melgram, axis=0) | ||
|
||
preds = model.predict(melgrams) | ||
print('Predicted:') | ||
print(decode_predictions(preds)) |