#!/usr/bin/python3
# -*- coding: utf-8 -*-
# Standard library modules
import json

# External library modules
import torch

# Local library modules
from utils import init_hyper_params
from utils import num_corrects
from utils import device


class MaxoutMLP(torch.nn.Module):
"""Maxout using Multilayer Perceptron"""
def __init__(self, input_size,
linear_layers, linear_neurons):
"""
Define layers of maxout unit
:param input_size: number of values(pixels or hidden unit's output)
that will be inputted to the layer
:type input_size: :py:obj:`int`
:param linear_layers: number of linear layers before
max operation
:type linear_layers: :py:obj:`int`
:param linear_neurons: number of neurons in each linear
layer before max operation
:type linear_neurons: :py:obj:`int`
:Example:
>>> from time import time
>>> import torch
>>> def fun(times):
... s = time()
... for i in range(times):
... torch.nn.Linear(784, 2048)(torch.randn(10, 784))
... print(time() - s)
...
>>> fun(100)
1.0891399383544922
>>> def fun1(times):
... s = time()
... torch.nn.Linear(784, 2048 * times)(torch.randn(10, 784))
... print(time() - s)
...
>>> fun1(100)
1.425891399383545
"""
super(MaxoutMLP, self).__init__()
# initialize variables
self.input_size = input_size
self.linear_layers = linear_layers
self.linear_neurons = linear_neurons
# batch normalization layer
self.BN = torch.nn.BatchNorm1d(self.linear_neurons)

        # PyTorch does not register the parameters of modules stored in a
        # plain Python list, so collect them in a ParameterList explicitly.
        self.params = torch.nn.ParameterList()
self.z = []
for layer in range(self.linear_layers):
self.z.append(torch.nn.Linear(self.input_size,
self.linear_neurons))
self.params.extend(list(self.z[layer].parameters()))

    def forward(self, input_, is_norm=False, **kargs):
        """
        Forward the input through the maxout layer.

        :param input_: input to the maxout layer
        :type input_: :py:class:`torch.Tensor`
        :param is_norm: whether to apply batch normalization (and the norm
                        constraint) before the max operation
        :type is_norm: :py:obj:`bool`
        :param kargs: keyword arguments containing
                      1. ``norm_constraint``: the maximum value allowed through
                      to the next layer after normalization
                      2. ``norm_upper``: replacement values for entries that
                      exceed the constraint; its size is the same as, or
                      broadcastable to, the normalized layer output
        :type kargs: :py:obj:`dict`
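
        :Example: (a minimal usage sketch; the sizes here are illustrative,
                  not taken from any particular configuration)

        >>> maxout = MaxoutMLP(784, 2, 1200)
        >>> x = torch.randn(16, 784)
        >>> maxout(x).shape
        torch.Size([16, 1200])
        >>> # clamp any normalized activation above 0.9 down to 0.9
        >>> out = maxout(x, is_norm=True,
        ...              norm_constraint=torch.tensor(0.9),
        ...              norm_upper=torch.full((16, 1200), 0.9))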
"""
h = None
for layer in range(self.linear_layers):
z = self.z[layer](input_)
# norm + norm constraint
if is_norm:
z = self.BN(z)
z = torch.where(z <= kargs['norm_constraint'],
z, kargs['norm_upper'])
if layer == 0:
h = z
else:
h = torch.max(h, z)
return h


class MaxoutConv(torch.nn.Module):
"""Maxout layer with convolution"""
def __init__(self, in_channels,
out_channels, kernel_size, padding):
"""
Define layers of maxout unit
:param in_channels: number of channel of input convolution
:type in_channels: :py:obj:`int`
:param out_channels: number of channel of output convolution
:type out_channels: :py:obj:`int`
:param kernel_size: size of the weight matrix to convolve
:type kernel_size: (:py:obj:`int`, :py:obj:`int`)
"""
super(MaxoutConv, self).__init__()
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size,
padding=padding)
self.BN = torch.nn.BatchNorm2d(out_channels)

    def forward(self, _input, is_norm=False):
        """
        Pass the input through the maxout layer.

        :param _input: input to the maxout layer; the input is expected to
                       have a channel dimension
        :type _input: :py:class:`torch.Tensor`
        :param is_norm: whether to apply batch normalization before the
                        max operation
        :type is_norm: :py:obj:`bool`
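
        :Example: (a minimal usage sketch; the shapes here are illustrative)

        >>> maxconv = MaxoutConv(3, 48, (8, 8), padding=3)
        >>> images = torch.randn(4, 3, 32, 32)
        >>> maxconv(images).shape
        torch.Size([4, 1, 31, 31])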
"""
z = self.conv(_input)
if is_norm:
z = self.BN(z)
# (batch size, channels, height, width)
        h = torch.max(z, 1).values  # element-wise max over the channel dimension (dim 1)

        # Re-insert a singleton channel dimension into h
hshape = h.shape
h = h.reshape(*([hshape[0]] + [1] + list(hshape[1:])))
return h
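

if __name__ == "__main__":
    # Minimal end-to-end sketch chaining the two units; the shapes and sizes
    # here are illustrative, and this block runs only when the module is
    # executed directly.
    conv = MaxoutConv(3, 48, (8, 8), padding=3)
    mlp = MaxoutMLP(31 * 31, 2, 100)

    images = torch.randn(4, 3, 32, 32)
    features = conv(images)                         # -> (4, 1, 31, 31)
    flat = features.reshape(features.shape[0], -1)  # -> (4, 961)
    print(mlp(flat).shape)                          # -> torch.Size([4, 100])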