
Commit

Merge pull request #3 from rauldiaz/include-top
Include top
Alex Barnes authored Aug 2, 2019
2 parents 23c111e + 1f327a5 commit 91be6b5
Showing 2 changed files with 61 additions and 23 deletions.
78 changes: 58 additions & 20 deletions model.py
@@ -3,14 +3,20 @@

from keras import backend as K
from keras.layers import Input, Dropout, Flatten, Dense, MaxPooling2D, Dot, Lambda, \
-    Reshape, BatchNormalization, Activation, Conv1D
+    Reshape, BatchNormalization, Activation, Conv1D, AveragePooling2D
from keras.initializers import Constant
from keras.models import Model
from keras.regularizers import Regularizer
+import keras.utils as keras_utils

import numpy as np


+WEIGHTS_PATH = ('https://github.com/HPInc/pointnet-keras/'
+                'releases/download/v1.0/'
+                'pointnet_modelnet_weights_tf_dim_ordering_tf_kernels.h5')
+
+
class OrthogonalRegularizer(Regularizer):
    """
    Considering that input is flattened square matrix X, regularizer tries to ensure that matrix X
@@ -140,12 +146,26 @@ def pointnet_base(inputs):
    return net


-def pointnet_cls(input_shape, classes, activation=None):
+def pointnet_cls(include_top=True, weights=None, input_shape=(2048, 3), pooling=None, classes=40, activation=None):
    """
    PointNet model for object classification
+    :param include_top: whether to include the stack of fully connected layers
+    :param weights: one of `None` (random initialization),
+        'modelnet' (pre-training on ModelNet),
+        or the path to the weights file to be loaded.
    :param input_shape: shape of the input point clouds (NxK)
+    :param pooling: optional pooling mode for feature extraction
+        when `include_top` is `False`.
+        - `None` means that the output of the model will be
+        the 2D tensor output of the last convolutional block (Nx1024).
+        - `avg` means that global average pooling
+        will be applied to the output of the
+        last convolutional block, and thus
+        the output of the model will be a 1D tensor of size 1024.
+        - `max` means that global max pooling will
+        be applied.
    :param classes: number of classes in the classification problem; if dict, construct multiple disjoint top layers
-    :param activation: activation of the last layer
+    :param activation: activation of the last layer (default None).
    :return: Keras model of the classification network
    """

@@ -156,26 +176,44 @@ def pointnet_cls(input_shape, classes, activation=None):
    inputs = Input(input_shape, name='Input_cloud')
    net = pointnet_base(inputs)

-    # Symmetric function: max pooling
-    # Done in 2D since 1D is painfully slow
-    net = MaxPooling2D(pool_size=(num_point, 1), padding='valid', name='maxpool')(Lambda(K.expand_dims)(net))
-    net = Flatten()(net)
-
-    # Top layers
-    if isinstance(classes, dict):
-        # Fully connected layers
-        net = [dense_bn(net, units=512, scope=r + '_fc1', activation='relu') for r in classes]
-        net = [Dropout(0.3, name=r + '_dp1')(n) for r, n in zip(classes, net)]
-        net = [dense_bn(n, units=256, scope=r + '_fc2', activation='relu') for r, n in zip(classes, net)]
-        net = [Dropout(0.3, name=r + '_dp2')(n) for r, n in zip(classes, net)]
-        net = [Dense(units=classes[r], activation=activation, name=r)(n) for r, n in zip(classes, net)]
+    if include_top:
+        # Symmetric function: max pooling
+        # Done in 2D since 1D is painfully slow
+        net = MaxPooling2D(pool_size=(num_point, 1), padding='valid', name='maxpool')(Lambda(K.expand_dims)(net))
+        net = Flatten()(net)
+        if isinstance(classes, dict):
+            # Disjoint stacks of fc layers, one per key in dict
+            net = [dense_bn(net, units=512, scope=r + '_fc1', activation='relu') for r in classes]
+            net = [Dropout(0.3, name=r + '_dp1')(n) for r, n in zip(classes, net)]
+            net = [dense_bn(n, units=256, scope=r + '_fc2', activation='relu') for r, n in zip(classes, net)]
+            net = [Dropout(0.3, name=r + '_dp2')(n) for r, n in zip(classes, net)]
+            net = [Dense(units=classes[r], activation=activation, name=r)(n) for r, n in zip(classes, net)]
+        else:
+            # Fully connected layers for a single classification task
+            net = dense_bn(net, units=512, scope='fc1', activation='relu')
+            net = Dropout(0.3, name='dp1')(net)
+            net = dense_bn(net, units=256, scope='fc2', activation='relu')
+            net = Dropout(0.3, name='dp2')(net)
+            net = Dense(units=classes, name='fc3', activation=activation)(net)
    else:
-        net = dense_bn(net, units=512, scope='fc1', activation='relu')
-        net = Dropout(0.3, name='dp1')(net)
-        net = dense_bn(net, units=256, scope='fc2', activation='relu')
-        net = Dropout(0.3, name='dp2')(net)
-        net = Dense(units=classes, name='fc3', activation=activation)(net)
+        if pooling == 'avg':
+            net = AveragePooling2D(pool_size=(num_point, 1), padding='valid', name='avgpool')(Lambda(K.expand_dims)(net))
+        elif pooling == 'max':
+            net = MaxPooling2D(pool_size=(num_point, 1), padding='valid', name='maxpool')(Lambda(K.expand_dims)(net))

    model = Model(inputs, net, name='pointnet_cls')

+    # Load weights.
+    if weights == 'modelnet':
+        weights_path = keras_utils.get_file(
+            'pointnet_modelnet_weights_tf_dim_ordering_tf_kernels.h5',
+            WEIGHTS_PATH,
+            cache_subdir='models')
+        model.load_weights(weights_path, by_name=True)
+        if K.backend() == 'theano':
+            keras_utils.convert_all_kernels_in_model(model)
+    elif weights is not None:
+        model.load_weights(weights, by_name=True)

    return model
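
For reference, a minimal usage sketch of the new pointnet_cls signature (the call patterns follow the docstring above; the 'category'/'material' task names in the dict example are made up for illustration, and the 'modelnet' weights are downloaded from WEIGHTS_PATH on first use):

    from model import pointnet_cls

    # Full classifier, randomly initialized, as train.py builds it below.
    clf = pointnet_cls(input_shape=(2048, 3), classes=40, activation='softmax')

    # Full classifier initialized from the pretrained ModelNet weights.
    pretrained = pointnet_cls(weights='modelnet', input_shape=(2048, 3),
                              classes=40, activation='softmax')

    # Feature extractor: include_top=False drops the fully connected stack;
    # pooling='max' collapses the per-point features over the point axis
    # (max is a symmetric function, so the result is point-order invariant).
    backbone = pointnet_cls(include_top=False, pooling='max',
                            input_shape=(2048, 3))

    # Multi-task tops: a dict maps each output name to its class count and
    # builds one disjoint stack of fully connected layers per key.
    multi = pointnet_cls(input_shape=(2048, 3),
                         classes={'category': 40, 'material': 8},
                         activation='softmax')
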
6 changes: 3 additions & 3 deletions train.py
@@ -30,7 +30,7 @@
batch_size = 32
num_classes = 40

-model = pointnet_cls((input_size, 3), classes=num_classes, activation='softmax')
+model = pointnet_cls(input_shape=(input_size, 3), classes=num_classes, activation='softmax')
loss = 'sparse_categorical_crossentropy'
metric = ['sparse_categorical_accuracy']
monitor = 'val_loss'
@@ -39,7 +39,7 @@
callbacks = list()
callbacks.append(keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=0, write_graph=True))
callbacks.append(
-    keras.callbacks.ReduceLROnPlateau(monitor=monitor, factor=0.5, patience=3, verbose=1, min_lr=1e-10))
+    keras.callbacks.ReduceLROnPlateau(monitor=monitor, factor=0.5, patience=5, verbose=1, min_lr=1e-10))
callbacks.append(keras.callbacks.EarlyStopping(monitor=monitor, patience=10))
callbacks.append(keras.callbacks.ModelCheckpoint(weights_path, monitor=monitor, verbose=0, save_best_only=True,
                                                 save_weights_only=True, mode='auto', period=1))
@@ -53,7 +53,7 @@
val_generator = val_dataset.generate_samples(batch_size=batch_size, augmentation=False)
val_steps_per_epoch = (val_dataset.x.shape[0] // batch_size) + 1

-optimizer = adam(lr=1e-3)
+optimizer = adam(lr=3e-4)
model.compile(loss=loss, optimizer=optimizer, metrics=metric)

# train
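
A quick sanity check on the retuned schedule (assuming validation loss keeps plateauing): ReduceLROnPlateau now halves the rate after 5 stalled epochs while EarlyStopping waits 10, so the rate is cut at least once before training can stop. Starting from the new 3e-4:

    initial_lr, factor = 3e-4, 0.5
    for k in range(3):
        print('lr after %d reduction(s): %.1e' % (k, initial_lr * factor ** k))
    # lr after 0 reduction(s): 3.0e-04
    # lr after 1 reduction(s): 1.5e-04
    # lr after 2 reduction(s): 7.5e-05
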
