From dd163a52f632fa6fa3fa97cd965692145789a6cb Mon Sep 17 00:00:00 2001 From: namish800 Date: Thu, 1 Nov 2018 14:49:10 +0530 Subject: [PATCH] *Removed package.json and package-lock.json *Added info for supported frameworks. *Added Length layer *Added MaskCapsule layer *Added Capsule layer *Added model to the modelZoo.js --- README.md | 1 + example/keras/Capsnet.json | 414 +++++++++++++++++++++++ ide/static/js/data.js | 120 +++++++ ide/static/js/modelZoo.js | 5 +- ide/static/js/pane.js | 12 + keras_app/custom_layers/capsule_layer.py | 60 ++++ keras_app/custom_layers/config.py | 16 + keras_app/custom_layers/length.py | 20 ++ keras_app/custom_layers/mask_capsule.py | 29 ++ keras_app/custom_layers/squash.py | 24 ++ keras_app/views/export_json.py | 8 +- keras_app/views/import_json.py | 16 +- keras_app/views/layers_export.py | 33 ++ keras_app/views/layers_import.py | 25 ++ 14 files changed, 777 insertions(+), 6 deletions(-) create mode 100644 example/keras/Capsnet.json create mode 100644 keras_app/custom_layers/capsule_layer.py create mode 100644 keras_app/custom_layers/length.py create mode 100644 keras_app/custom_layers/mask_capsule.py create mode 100644 keras_app/custom_layers/squash.py diff --git a/README.md b/README.md index 50db680a2..8fc20e8b5 100644 --- a/README.md +++ b/README.md @@ -291,6 +291,7 @@ Models | Ca [Pix2Pix](https://github.com/phillipi/pix2pix) | √ | × | × | [VQA](https://github.com/iamaaditya/VQA_Demo) | √ | √ | √ | [Denoising Auto-Encoder](https://blog.keras.io/building-autoencoders-in-keras.html) | × | √ | √ | +[CapsNet](https://arxiv.org/abs/1710.09829) | × | √ | √ | Note: For models that use a custom LRN layer (Alexnet), Keras expects the custom layer to be passed when it is loaded from json. LRN.py is located in keras_app/custom_layers. 
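The same applies to the capsule layers added in this patch: when an exported model is loaded outside Fabrik, every custom class must be passed through `custom_objects`. A minimal sketch (file names and paths are illustrative; the layer modules can be downloaded from keras_app/custom_layers):

```python
# Sketch only: loading a Fabrik-exported CapsNet JSON in a standalone script.
# Assumes the custom layer files (served under /media/ by Fabrik) have been
# saved next to this script; 'Capsnet.json' is the exported model definition.
from keras.models import model_from_json

from capsule_layer import CapsuleLayer
from length import Length
from mask_capsule import MaskCapsule
from squash import Squash

with open('Capsnet.json') as f:
    json_string = f.read()

# Keras needs every custom class passed via custom_objects, exactly as
# keras_app/views/import_json.py does for LRN and the capsule layers.
model = model_from_json(json_string,
                        custom_objects={'CapsuleLayer': CapsuleLayer,
                                        'Length': Length,
                                        'MaskCapsule': MaskCapsule,
                                        'Squash': Squash})
model.summary()
```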
[Alexnet import for Keras](https://github.com/Cloud-CV/Fabrik/blob/master/tutorials/keras_custom_layer_usage.md) diff --git a/example/keras/Capsnet.json b/example/keras/Capsnet.json new file mode 100644 index 000000000..d9ceda43f --- /dev/null +++ b/example/keras/Capsnet.json @@ -0,0 +1,414 @@ +{ + "class_name": "Model", + "keras_version": "2.0.8", + "config": { + "layers": [ + { + "class_name": "InputLayer", + "config": { + "dtype": "float32", + "batch_input_shape": [ + null, + 28, + 28, + 1 + ], + "name": "Input", + "sparse": false + }, + "inbound_nodes": [], + "name": "Input" + }, + { + "class_name": "Conv2D", + "config": { + "kernel_constraint": null, + "kernel_initializer": { + "class_name": "VarianceScaling", + "config": { + "distribution": "uniform", + "scale": 1.0, + "seed": null, + "mode": "fan_avg" + } + }, + "name": "conv1", + "bias_regularizer": null, + "bias_constraint": null, + "activation": "relu", + "trainable": true, + "data_format": "channels_last", + "padding": "valid", + "strides": [ + 1, + 1 + ], + "dilation_rate": [ + 1, + 1 + ], + "kernel_regularizer": null, + "filters": 256, + "bias_initializer": { + "class_name": "Zeros", + "config": {} + }, + "use_bias": true, + "activity_regularizer": null, + "kernel_size": [ + 9, + 9 + ] + }, + "inbound_nodes": [ + [ + [ + "Input", + 0, + 0, + {} + ] + ] + ], + "name": "conv1" + }, + { + "class_name": "Conv2D", + "config": { + "kernel_constraint": null, + "kernel_initializer": { + "class_name": "VarianceScaling", + "config": { + "distribution": "uniform", + "scale": 1.0, + "seed": null, + "mode": "fan_avg" + } + }, + "name": "primarycap_conv2d", + "bias_regularizer": null, + "bias_constraint": null, + "activation": "linear", + "trainable": true, + "data_format": "channels_last", + "padding": "valid", + "strides": [ + 2, + 2 + ], + "dilation_rate": [ + 1, + 1 + ], + "kernel_regularizer": null, + "filters": 256, + "bias_initializer": { + "class_name": "Zeros", + "config": {} + }, + "use_bias": true, + "activity_regularizer": null, + "kernel_size": [ + 9, + 9 + ] + }, + "inbound_nodes": [ + [ + [ + "conv1", + 0, + 0, + {} + ] + ] + ], + "name": "primarycap_conv2d" + }, + { + "class_name": "Reshape", + "config": { + "target_shape": [ + 1152, + 8 + ], + "trainable": true, + "name": "primarycap_reshape" + }, + "inbound_nodes": [ + [ + [ + "primarycap_conv2d", + 0, + 0, + {} + ] + ] + ], + "name": "primarycap_reshape" + }, + { + "class_name": "Squash", + "config": { + "trainable": true, + "name": "primarycap_sqash", + "axis": -1 + }, + "inbound_nodes": [ + [ + [ + "primarycap_reshape", + 0, + 0, + {} + ] + ] + ], + "name": "primarycap_sqash" + }, + { + "class_name": "CapsuleLayer", + "config": { + "num_routing": 3, + "dim_capsule": 16, + "trainable": true, + "name": "digitcaps", + "num_capsule": 10 + }, + "inbound_nodes": [ + [ + [ + "primarycap_sqash", + 0, + 0, + {} + ] + ] + ], + "name": "digitcaps" + }, + { + "class_name": "InputLayer", + "config": { + "dtype": "float32", + "batch_input_shape": [ + null, + 10 + ], + "name": "input_1", + "sparse": false + }, + "inbound_nodes": [], + "name": "input_1" + }, + { + "class_name": "MaskCapsule", + "config": { + "trainable": true, + "name": "mask_capsule_1" + }, + "inbound_nodes": [ + [ + [ + "digitcaps", + 0, + 0, + {} + ], + [ + "input_1", + 0, + 0, + {} + ] + ] + ], + "name": "mask_capsule_1" + }, + { + "class_name": "Dense", + "config": { + "kernel_initializer": { + "class_name": "VarianceScaling", + "config": { + "distribution": "uniform", + "scale": 1.0, + "seed": null, + "mode": 
"fan_avg" + } + }, + "name": "dense_1", + "kernel_constraint": null, + "bias_regularizer": null, + "bias_constraint": null, + "activation": "relu", + "trainable": true, + "kernel_regularizer": null, + "bias_initializer": { + "class_name": "Zeros", + "config": {} + }, + "units": 512, + "use_bias": true, + "activity_regularizer": null + }, + "inbound_nodes": [ + [ + [ + "mask_capsule_1", + 0, + 0, + {} + ] + ] + ], + "name": "dense_1" + }, + { + "class_name": "Dense", + "config": { + "kernel_initializer": { + "class_name": "VarianceScaling", + "config": { + "distribution": "uniform", + "scale": 1.0, + "seed": null, + "mode": "fan_avg" + } + }, + "name": "dense_2", + "kernel_constraint": null, + "bias_regularizer": null, + "bias_constraint": null, + "activation": "relu", + "trainable": true, + "kernel_regularizer": null, + "bias_initializer": { + "class_name": "Zeros", + "config": {} + }, + "units": 1024, + "use_bias": true, + "activity_regularizer": null + }, + "inbound_nodes": [ + [ + [ + "dense_1", + 0, + 0, + {} + ] + ] + ], + "name": "dense_2" + }, + { + "class_name": "Dense", + "config": { + "kernel_initializer": { + "class_name": "VarianceScaling", + "config": { + "distribution": "uniform", + "scale": 1.0, + "seed": null, + "mode": "fan_avg" + } + }, + "name": "dense_3", + "kernel_constraint": null, + "bias_regularizer": null, + "bias_constraint": null, + "activation": "sigmoid", + "trainable": true, + "kernel_regularizer": null, + "bias_initializer": { + "class_name": "Zeros", + "config": {} + }, + "units": 784, + "use_bias": true, + "activity_regularizer": null + }, + "inbound_nodes": [ + [ + [ + "dense_2", + 0, + 0, + {} + ] + ] + ], + "name": "dense_3" + }, + { + "class_name": "Length", + "config": { + "trainable": true, + "name": "capsnet" + }, + "inbound_nodes": [ + [ + [ + "digitcaps", + 0, + 0, + {} + ] + ] + ], + "name": "capsnet" + }, + { + "class_name": "Reshape", + "config": { + "target_shape": [ + 28, + 28, + 1 + ], + "trainable": true, + "name": "out_recon" + }, + "inbound_nodes": [ + [ + [ + "dense_3", + 0, + 0, + {} + ] + ] + ], + "name": "out_recon" + } + ], + "input_layers": [ + [ + "Input", + 0, + 0 + ], + [ + "input_1", + 0, + 0 + ] + ], + "output_layers": [ + [ + "capsnet", + 0, + 0 + ], + [ + "out_recon", + 0, + 0 + ] + ], + "name": "model_1" + }, + "backend": "tensorflow" +} \ No newline at end of file diff --git a/ide/static/js/data.js b/ide/static/js/data.js index b9bb5d410..3bb5ec1a1 100644 --- a/ide/static/js/data.js +++ b/ide/static/js/data.js @@ -1343,6 +1343,48 @@ export default { }, learn: true }, + CapsuleLayer: { + name:'capsule_layer', + color:'#ffeb3b', + endpoint:{ + src:['Bottom'], + trg:['Top'] + }, + params: { + num_capsule: { + name:'No. of Capsules', + value: '', + type: 'number', + required: true + }, + dim_capsule: { + name:'Dimension of the Capsule', + value: '', + type: 'number', + required: true + }, + num_routing: { + name: 'No. 
of routings', + value: '', + type: 'number', + required: true + }, + caffe: { + name: 'Available Caffe', + value: false, + type: 'checkbox', + required: false + } + }, + props: { + name: { + name:'Name', + value: '', + type: 'text' + } + }, + learn: true + }, /* ********** Recurrent Layers ********** */ Recurrent: { // Only Caffe name: 'recurrent', @@ -3065,6 +3107,36 @@ export default { }, learn: true }, + Squash: { + name: 'squash', + color: '#0329f4', + endpoint: { + src: ['Bottom'], + trg: ['Top'] + }, + params: { + axis: { + name: 'Axis', + value: -1, + type: 'number', + required:false + }, + caffe: { + name: 'Available Caffe', + value: false, + type: 'checkbox', + required: false + } + }, + props:{ + name: { + name:'Name', + value: '', + type: 'text' + } + }, + learn: false + }, /* ********** Utility Layers ********** */ Flatten: { name: 'flatten', @@ -3603,6 +3675,54 @@ export default { }, learn: false }, + Length: { + name:'length', + color: '#ffeb3c', + endpoint: { + src:['Bottom'], + trg:['Top'] + }, + params: { + caffe: { + name:'Available Caffe', + value: false, + type: 'checkbox', + required: false + } + }, + props: { + name: { + name: 'Name', + value: '', + type: 'text' + } + }, + learn: false + }, + MaskCapsule: { + name:'mask_capsule', + color: '#ff98c0', + endpoint: { + src:['Bottom'], + trg:['Top'] + }, + params: { + caffe: { + name:'Available Caffe', + value: false, + type: 'checkbox', + required: false + } + }, + props:{ + name: { + name: "Name", + value: '', + type: 'text' + } + }, + learn: false + }, /* ********** Loss Layers ********** */ MultinomialLogisticLoss: { // Only Caffe name: 'multinomial logistic loss', diff --git a/ide/static/js/modelZoo.js b/ide/static/js/modelZoo.js index 320360427..66d2a98bc 100644 --- a/ide/static/js/modelZoo.js +++ b/ide/static/js/modelZoo.js @@ -34,7 +34,10 @@ class ModelZoo extends React.Component {
              IMDB CNN LSTM
-             SimpleNet
+             SimpleNet
+             CapsNet

              Detection

diff --git a/ide/static/js/pane.js b/ide/static/js/pane.js index 2f1f96217..4375b73f9 100644 --- a/ide/static/js/pane.js +++ b/ide/static/js/pane.js @@ -161,6 +161,9 @@ class Pane extends React.Component { Depthwise Convolution + Capsule Layer
@@ -250,6 +253,12 @@ class Pane extends React.Component { Masking + Length + Mask Capsule @@ -319,6 +328,9 @@ class Pane extends React.Component { Scale + Squash diff --git a/keras_app/custom_layers/capsule_layer.py b/keras_app/custom_layers/capsule_layer.py new file mode 100644 index 000000000..28aa6448d --- /dev/null +++ b/keras_app/custom_layers/capsule_layer.py @@ -0,0 +1,60 @@ +from keras.layers.core import Layer +from keras import backend as K +import keras + + +def squash_activation(vectors, axis=-1): + s_squared_norm = K.sum(K.square(vectors), axis, keepdims=True) + scale = s_squared_norm / (1 + s_squared_norm) / K.sqrt(s_squared_norm + 1e-7) + return scale * vectors + + +class CapsuleLayer(Layer): + def __init__(self, num_capsule, dim_capsule, num_routing=3, **kwargs): + super(CapsuleLayer, self).__init__(**kwargs) + self.num_capsule = num_capsule + self.dim_capsule = dim_capsule + self.num_routing = num_routing + self.kernel_initializer = keras.initializers.random_uniform(-1, 1) + self.bias_initializer = keras.initializers.Zeros() + super(CapsuleLayer, self).__init__(**kwargs) + + def build(self, input_shape): + assert len(input_shape) >= 3 + self.W = self.add_weight(shape=[input_shape[1], self.num_capsule, input_shape[2], self.dim_capsule], + initializer=self.kernel_initializer, + name='W') + self.b = self.add_weight(shape=[input_shape[1], self.num_capsule], + initializer=self.bias_initializer, + name='b') + super(CapsuleLayer, self).build(input_shape) + + def call(self, inputs, training=None): + inputs_expand = K.expand_dims(inputs, 2) + inputs_tiled = K.repeat_elements(inputs_expand, self.num_capsule, axis=2) + inputs_hat = K.map_fn(lambda x: K.batch_dot(x, self.W, [2, 2]), inputs_tiled) + input_shape = K.shape(inputs_hat) + b = self.b + b = K.expand_dims(b, axis=0) + assert self.num_routing > 0 + for i in range(self.num_routing): + c = K.softmax(b) + c = K.expand_dims(c, axis=-1) + c = K.repeat_elements(c, rep=self.dim_capsule, axis=-1) + S = K.sum(c * inputs_hat, axis=1) + V = squash_activation(S) + if i != self.num_routing-1: + V_expanded = K.expand_dims(V, axis=1) + V_expanded = K.tile(V_expanded, [1, input_shape[1], 1, 1]) + b = b + K.sum(inputs_hat * V_expanded, axis=-1) + return V + + def compute_output_shape(self, input_shape): + return tuple([None, self.num_capsule, self.dim_capsule]) + + def get_config(self): + base_config = super(CapsuleLayer, self).get_config() + base_config['num_capsule'] = self.num_capsule + base_config['num_routing'] = self.num_routing + base_config['dim_capsule'] = self.dim_capsule + return base_config diff --git a/keras_app/custom_layers/config.py b/keras_app/custom_layers/config.py index 1a6afc315..335d786e2 100644 --- a/keras_app/custom_layers/config.py +++ b/keras_app/custom_layers/config.py @@ -4,5 +4,21 @@ 'LRN': { 'filename': 'lrn.py', 'url': '/media/lrn.py' + }, + 'CapsuleLayer': { + 'filename': 'capsule_layer.py', + 'url': '/media/capsule_layer.py' + }, + 'Length': { + 'filename': 'length.py', + 'url': 'media/length.py' + }, + 'MaskCapsule': { + 'filename': 'mask_capsule.py', + 'url': 'media/mask_capsule.py' + }, + 'Squash': { + 'filename': 'squash.py', + 'url': 'media/squash.py' } } diff --git a/keras_app/custom_layers/length.py b/keras_app/custom_layers/length.py new file mode 100644 index 000000000..ae9bf315e --- /dev/null +++ b/keras_app/custom_layers/length.py @@ -0,0 +1,20 @@ +from keras.layers.core import Layer +from keras import backend as K + + +class Length(Layer): + def __init__(self, **kwargs): + super(Length, 
self).__init__(**kwargs) + + def build(self, input_shape): + super(Length, self).build(input_shape) + + def call(self, input): + return K.sqrt(K.sum(K.square(input), -1)) + + def compute_output_shape(self, input_shape): + return input_shape[:-1] + + def get_config(self): + base_config = super(Length, self).get_config() + return base_config diff --git a/keras_app/custom_layers/mask_capsule.py b/keras_app/custom_layers/mask_capsule.py new file mode 100644 index 000000000..d7f3ce6e3 --- /dev/null +++ b/keras_app/custom_layers/mask_capsule.py @@ -0,0 +1,29 @@ +from keras.layers.core import Layer +from keras import backend as K + + +class MaskCapsule(Layer): + def __init__(self, **kwargs): + super(MaskCapsule, self).__init__(**kwargs) + + def build(self, input_shape): + super(MaskCapsule, self).build(input_shape) + + def call(self, inputs): + if type(inputs) == list: + assert len(inputs) == 2 + inputs, mask = inputs + else: + x = K.sqrt(K.sum(K.square(inputs), -1)) + mask = K.one_hot(indices=K.argmax(x, 1), num_classes=x.get_shape().as_list()[1]) + masked = K.batch_flatten(inputs*K.expand_dims(mask, -1)) + return masked + + def compute_output_shape(self, input_shape): + if type(input_shape[0]) is tuple: + return tuple([None, input_shape[0][1]*input_shape[0][2]]) + else: + return tuple([None, input_shape[1]*input_shape[2]]) + + def get_config(self): + return super(MaskCapsule, self).get_config() diff --git a/keras_app/custom_layers/squash.py b/keras_app/custom_layers/squash.py new file mode 100644 index 000000000..67cb9d2da --- /dev/null +++ b/keras_app/custom_layers/squash.py @@ -0,0 +1,24 @@ +from keras.layers.core import Layer +from keras import backend as K + + +class Squash(Layer): + def __init__(self, axis=-1, **kwargs): + self.axis = axis + super(Squash, self).__init__(**kwargs) + + def build(self, input_shape): + super(Squash, self).build(input_shape) + + def call(self, inputs): + s_squared_norm = K.sum(K.square(inputs), self.axis, keepdims=True) + scale = s_squared_norm / (1 + s_squared_norm) / K.sqrt(s_squared_norm + 1e-7) + return scale * inputs + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + base_config = super(Squash, self).get_config() + base_config['axis'] = self.axis + return base_config diff --git a/keras_app/views/export_json.py b/keras_app/views/export_json.py index af8913fea..c0ee6f6ad 100644 --- a/keras_app/views/export_json.py +++ b/keras_app/views/export_json.py @@ -10,7 +10,7 @@ from layers_export import data, convolution, deconvolution, pooling, dense, dropout, embed,\ recurrent, batch_norm, activation, flatten, reshape, eltwise, concat, upsample, locally_connected,\ permute, repeat_vector, regularization, masking, gaussian_noise, gaussian_dropout, alpha_dropout, \ - bidirectional, time_distributed, lrn, depthwiseConv + bidirectional, time_distributed, lrn, depthwiseConv, capsule_layer, length, mask_capsule, squash from ..custom_layers import config as custom_layers_config @@ -81,7 +81,11 @@ def export_json(request, is_tf=False): } custom_layers_map = { - 'LRN': lrn + 'LRN': lrn, + 'CapsuleLayer': capsule_layer, + 'Length': length, + 'MaskCapsule': mask_capsule, + 'Squash': squash } # Remove any duplicate activation layers (timedistributed and bidirectional layers) diff --git a/keras_app/views/import_json.py b/keras_app/views/import_json.py index d8b919884..be29162db 100644 --- a/keras_app/views/import_json.py +++ b/keras_app/views/import_json.py @@ -10,10 +10,14 @@ Recurrent, BatchNorm, Activation, LeakyReLU, PReLU, ELU, 
Scale, Flatten, Reshape, Concat, \ Eltwise, Padding, Upsample, LocallyConnected, ThresholdedReLU, Permute, RepeatVector,\ ActivityRegularization, Masking, GaussianNoise, GaussianDropout, AlphaDropout, \ - TimeDistributed, Bidirectional, DepthwiseConv, lrn + TimeDistributed, Bidirectional, DepthwiseConv, lrn, capsule_layer, length, mask_capsule, squash from keras.models import model_from_json, Sequential from keras.layers import deserialize from ..custom_layers.lrn import LRN +from ..custom_layers.capsule_layer import CapsuleLayer +from ..custom_layers.length import Length +from ..custom_layers.mask_capsule import MaskCapsule +from ..custom_layers.squash import Squash @csrf_exempt @@ -49,7 +53,9 @@ def import_json(request): except Exception: return JsonResponse({'result': 'error', 'error': 'Invalid JSON'}) - model = model_from_json(json.dumps(model), custom_objects={'LRN': LRN}) + model = model_from_json(json.dumps(model), + custom_objects={'LRN': LRN, 'Length': Length, 'MaskCapsule': MaskCapsule, + 'Squash': Squash, 'CapsuleLayer': CapsuleLayer}) layer_map = { 'InputLayer': Input, 'Dense': Dense, @@ -113,7 +119,11 @@ def import_json(request): 'AlphaDropout': AlphaDropout, 'TimeDistributed': TimeDistributed, 'Bidirectional': Bidirectional, - 'LRN': lrn + 'LRN': lrn, + 'CapsuleLayer': capsule_layer, + 'Length': length, + 'MaskCapsule': mask_capsule, + 'Squash': squash } hasActivation = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose', 'Dense', 'LocallyConnected1D', diff --git a/keras_app/views/layers_export.py b/keras_app/views/layers_export.py index 257065610..f82c3aa6e 100644 --- a/keras_app/views/layers_export.py +++ b/keras_app/views/layers_export.py @@ -18,6 +18,10 @@ from keras.layers import TimeDistributed, Bidirectional from keras import regularizers from ..custom_layers.lrn import LRN +from ..custom_layers.capsule_layer import CapsuleLayer +from ..custom_layers.length import Length +from ..custom_layers.mask_capsule import MaskCapsule +from ..custom_layers.squash import Squash fillerMap = { 'constant': 'Constant', @@ -128,6 +132,12 @@ def activation(layer, layer_in, layerId, tensor=True): return out +def squash(layer, layer_in, layerId): + axis = layer['params']['axis'] + out = {} + out[layerId] = Squash(axis=axis)(*layer_in) + + def dropout(layer, layer_in, layerId, tensor=True): out = {layerId: Dropout(0.5)} if tensor: @@ -180,6 +190,18 @@ def masking(layer, layer_in, layerId, tensor=True): return out +def length(layer, layer_in, layerId): + out = {} + out[layerId] = Length()(*layer_in) + return out + + +def mask_capsule(layer, layer_in, layerId): + out = {} + out[layerId] = MaskCapsule()(*layer_in) + return out + + # ********** Convolution Layers ********** def convolution(layer, layer_in, layerId, tensor=True): convMap = { @@ -366,6 +388,17 @@ def upsample(layer, layer_in, layerId, tensor=True): return out +# ********** Capsule Layers ********** +def capsule_layer(layer, layer_in, layerId): + num_capsule = layer['params']['num_capsule'] + dim_capsule = layer['params']['dim_capsule'] + num_routing = layer['params']['num_routing'] + out = {} + out[layerId] = CapsuleLayer(num_capsule=num_capsule, dim_capsule=dim_capsule, + num_routing=num_routing)(*layer_in) + return out + + # ********** Pooling Layers ********** def pooling(layer, layer_in, layerId, tensor=True): poolMap = { diff --git a/keras_app/views/layers_import.py b/keras_app/views/layers_import.py index 7be30f85c..508238a02 100644 --- a/keras_app/views/layers_import.py +++ b/keras_app/views/layers_import.py @@ -52,6 +52,12 
@@ def Activation(layer): return jsonLayer(activationMap[layer.activation.func_name], {}, tempLayer) +def squash(layer): + params = {} + params['axis'] = layer.axis + return jsonLayer('Squash', params, layer) + + def Dropout(layer): params = {} if (layer.rate is not None): @@ -99,6 +105,16 @@ def Masking(layer): return jsonLayer('Masking', params, layer) +def length(layer): + params = {} + return jsonLayer('Length', params, layer) + + +def mask_capsule(layer): + params = {} + return jsonLayer('MaskCapsule', params, layer) + + # ********** Convolutional Layers ********** def Convolution(layer): params = {} @@ -225,6 +241,15 @@ def Upsample(layer): return jsonLayer('Upsample', params, layer) +# ********** Capsule Layers ********** +def capsule_layer(layer): + params = {} + params['num_capsule'] = layer.num_capsule + params['dim_capsule'] = layer.dim_capsule + params['num_routing'] = layer.num_routing + return jsonLayer('CapsuleLayer', params, layer) + + # ********** Pooling Layers ********** def Pooling(layer): params = {}
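As a usage sketch (not part of this patch), the capsule stack in example/keras/Capsnet.json can be rebuilt directly with the layers added above; this assumes the modules under keras_app/custom_layers are importable as a package:

```python
# Sketch only: rebuilds the capsule branch of example/keras/Capsnet.json with the
# custom layers added in this patch. Assumes keras_app.custom_layers is importable.
from keras import layers, models

from keras_app.custom_layers.capsule_layer import CapsuleLayer
from keras_app.custom_layers.length import Length
from keras_app.custom_layers.squash import Squash

inp = layers.Input(shape=(28, 28, 1), name='Input')
x = layers.Conv2D(256, (9, 9), activation='relu', name='conv1')(inp)         # 28x28x1 -> 20x20x256
x = layers.Conv2D(256, (9, 9), strides=(2, 2), name='primarycap_conv2d')(x)  # -> 6x6x256
x = layers.Reshape((1152, 8), name='primarycap_reshape')(x)                  # 6*6*256 = 1152 capsules of dim 8
x = Squash(name='primarycap_squash')(x)                                      # capsule non-linearity
digitcaps = CapsuleLayer(num_capsule=10, dim_capsule=16, num_routing=3,
                         name='digitcaps')(x)                                # routing-by-agreement
out = Length(name='capsnet')(digitcaps)                                      # capsule length ~ class score

model = models.Model(inp, out)
model.summary()
```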