forked from titu1994/DenseNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
densenet_fast.py
145 lines (103 loc) · 4.98 KB
/
densenet_fast.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from keras.models import Model
from keras.layers.core import Dense, Dropout, Activation
from keras.layers.convolutional import Convolution2D
from keras.layers.pooling import AveragePooling2D
from keras.layers.pooling import GlobalAveragePooling2D
from keras.layers import Input, merge
from keras.layers.normalization import BatchNormalization
from keras.regularizers import l2
import keras.backend as K
'''
Based on the implementation here : https://github.com/Lasagne/Recipes/blob/master/papers/densenet/densenet_fast.py
'''
def conv_block(ip, nb_filter, dropout_rate=None, weight_decay=1E-4):
    ''' Apply Relu, 3x3 Conv2D and optional dropout
    No BatchNormalization layer is created inside this block.
    Args:
        ip: Input keras tensor
        nb_filter: number of filters of the 3x3 convolution
        dropout_rate: dropout rate; a falsy value (None or 0) adds no Dropout layer
        weight_decay: l2 factor applied to the convolution weights
    Returns: keras tensor with relu, convolution2d and (optionally) dropout added
    '''
    activated = Activation('relu')(ip)

    # Bias-free 3x3 convolution, "same" border mode, l2-regularized weights.
    conv_layer = Convolution2D(nb_filter, 3, 3,
                               init="he_uniform",
                               border_mode="same",
                               bias=False,
                               W_regularizer=l2(weight_decay))
    convolved = conv_layer(activated)

    # Guard clause: without a dropout rate the convolution output is the result.
    if not dropout_rate:
        return convolved
    return Dropout(dropout_rate)(convolved)
def transition_block(ip, nb_filter, dropout_rate=None, weight_decay=1E-4):
    ''' Apply 1x1 Conv2D, optional dropout, 2x2 AveragePooling2D and BatchNorm
    Args:
        ip: keras tensor
        nb_filter: number of filters of the 1x1 convolution
        dropout_rate: dropout rate; a falsy value (None or 0) adds no Dropout layer
        weight_decay: l2 factor for the convolution weights and the batch-norm gamma / beta
    Returns: keras tensor, after applying conv, dropout, average pooling and batch_norm (in that order)
    '''
    # Axis handed to BatchNormalization: 1 for "th" dim ordering, -1 otherwise.
    if K.image_dim_ordering() == "th":
        channel_axis = 1
    else:
        channel_axis = -1

    pointwise_conv = Convolution2D(nb_filter, 1, 1,
                                   init="he_uniform",
                                   border_mode="same",
                                   bias=False,
                                   W_regularizer=l2(weight_decay))
    out = pointwise_conv(ip)

    if dropout_rate:
        out = Dropout(dropout_rate)(out)

    # 2x2 average pooling with stride 2.
    out = AveragePooling2D((2, 2), strides=(2, 2))(out)

    batch_norm = BatchNormalization(mode=0,
                                    axis=channel_axis,
                                    gamma_regularizer=l2(weight_decay),
                                    beta_regularizer=l2(weight_decay))
    return batch_norm(out)
def dense_block(x, nb_layers, nb_filter, growth_rate, dropout_rate=None, weight_decay=1E-4):
    ''' Build a dense_block: every conv_block receives the concatenation of the
    block input and the outputs of all conv_blocks created before it
    Args:
        x: keras tensor
        nb_layers: the number of conv_block layers to append to the model.
        nb_filter: running filter count; increased by growth_rate once per layer
        growth_rate: number of filters produced by each conv_block
        dropout_rate: dropout rate, forwarded to conv_block
        weight_decay: weight decay factor, forwarded to conv_block
    Returns: tuple (keras tensor with nb_layers of conv_block appended, updated nb_filter)
    '''
    if K.image_dim_ordering() == "th":
        channel_axis = 1
    else:
        channel_axis = -1

    concatenated = x
    collected = [concatenated]

    for _ in range(nb_layers):
        fresh_maps = conv_block(concatenated, growth_rate, dropout_rate, weight_decay)
        collected.append(fresh_maps)
        # Re-concatenate everything gathered so far; this feeds the next layer.
        concatenated = merge(collected, mode='concat', concat_axis=channel_axis)
        nb_filter += growth_rate

    return concatenated, nb_filter
def create_dense_net(nb_classes, img_dim, depth=40, nb_dense_block=3, growth_rate=12, nb_filter=16, dropout_rate=None,
                     weight_decay=1E-4, verbose=True):
    ''' Build the DenseNet model
    Args:
        nb_classes: number of classes (units of the final softmax Dense layer)
        img_dim: tuple of shape (channels, rows, columns) or (rows, columns, channels)
        depth: number of layers; must satisfy depth = 3 N + 4
        nb_dense_block: number of dense blocks; all but the last are followed by a transition_block
        growth_rate: number of filters added by each conv_block
        nb_filter: number of filters of the initial convolution
        dropout_rate: dropout rate
        weight_decay: weight decay
        verbose: if truthy, print a one-line message once the model is built
    Returns: keras Model (named "create_dense_net") from the input tensor to the softmax output
    Raises:
        AssertionError: if depth is not of the form 3 N + 4
    '''
    # Validate before any layer is created. Raised explicitly instead of via
    # `assert` so the check is not stripped under `python -O`; the exception
    # type and message are kept for backward compatibility.
    if (depth - 4) % 3 != 0:
        raise AssertionError("Depth must be 3 N + 4")

    model_input = Input(shape=img_dim)

    concat_axis = 1 if K.image_dim_ordering() == "th" else -1

    # layers in each dense block
    # NOTE: always derived as (depth - 4) / 3, independently of nb_dense_block.
    nb_layers = int((depth - 4) // 3)

    # Initial convolution
    x = Convolution2D(nb_filter, 3, 3, init="he_uniform", border_mode="same", name="initial_conv2D", bias=False,
                      W_regularizer=l2(weight_decay))(model_input)

    x = BatchNormalization(mode=0, axis=concat_axis, gamma_regularizer=l2(weight_decay),
                           beta_regularizer=l2(weight_decay))(x)

    # Add dense blocks, each followed by a transition_block
    for _ in range(nb_dense_block - 1):
        x, nb_filter = dense_block(x, nb_layers, nb_filter, growth_rate, dropout_rate=dropout_rate,
                                   weight_decay=weight_decay)
        x = transition_block(x, nb_filter, dropout_rate=dropout_rate, weight_decay=weight_decay)

    # The last dense_block does not have a transition_block
    x, nb_filter = dense_block(x, nb_layers, nb_filter, growth_rate, dropout_rate=dropout_rate,
                               weight_decay=weight_decay)

    # Classification head: relu -> global average pooling -> softmax
    x = Activation('relu')(x)
    x = GlobalAveragePooling2D()(x)
    x = Dense(nb_classes, activation='softmax', W_regularizer=l2(weight_decay), b_regularizer=l2(weight_decay))(x)

    densenet = Model(input=model_input, output=x, name="create_dense_net")

    if verbose:
        print("DenseNet-%d-%d created." % (depth, growth_rate))

    return densenet