### See http://papers.nips.cc/paper/5647-binaryconnect-training-deep-neural-networks-with-binary-weights-during-propagations
### for a complete description of the algorithm
#
import torch.nn as nn


class BC:
    def __init__(self, model):
        # First we need to count the number of Conv2d and Linear modules.
        # This count is used next to build the list of all
        # target parameters of the model.
        count_targets = 0
        for m in model.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                count_targets = count_targets + 1
        self.bin_range = list(range(count_targets))
        # Now we can initialize the list of parameters
        self.num_of_params = len(self.bin_range)
        self.saved_params = []    # used to save the full-precision weights
        self.target_modules = []  # the list of modules to be modified
        self.model = model        # the model that will be trained and quantized

        ### This builds the initial copy of all parameters and target modules
        index = -1
        for m in model.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                index = index + 1
                if index in self.bin_range:
                    tmp = m.weight.data.clone()
                    self.saved_params.append(tmp)
                    self.target_modules.append(m.weight)
    def save_params(self):
        ### Go through the list of target modules and save the corresponding
        ### weights into saved_params
        for index in range(self.num_of_params):
            self.saved_params[index].copy_(self.target_modules[index].data)
    def binarization(self):
        ### (1) Save the current full-precision parameters using save_params
        self.save_params()
        ### (2) Binarize the weights: iterate through the target modules and
        ### overwrite each weight tensor with its sign (+1 / -1)
        for index in range(self.num_of_params):
            self.target_modules[index].data.copy_(
                self.target_modules[index].data.sign())
    def restore(self):
        ### Restore the copy from self.saved_params into the model
        for index in range(self.num_of_params):
            self.target_modules[index].data.copy_(self.saved_params[index])
    def clip(self):
        ## Clip all parameters to the range [-1, 1] using Hard Tanh
        clip_fn = nn.Hardtanh()  # hardtanh(x) = max(-1, min(1, x)) by default
        for index in range(self.num_of_params):
            self.target_modules[index].data.copy_(
                clip_fn(self.target_modules[index].data))
    def forward(self, x):
        ### Run the wrapped model, so a BC instance can be called directly
        ### during training and evaluation
        out = self.model(x)
        return out
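

### The sketch below is not part of the original assignment; it is a minimal,
### hedged usage example of one BinaryConnect training step as described in
### the paper: binarize before the forward/backward pass, restore the
### full-precision weights before the optimizer update, then clip them.
### The toy model, random data, and hyperparameters are illustrative only.
if __name__ == "__main__":
    import torch

    model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
    bc_model = BC(model)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    inputs = torch.randn(16, 8)           # dummy batch of 16 samples
    targets = torch.randint(0, 2, (16,))  # dummy class labels

    bc_model.binarization()   # save full-precision weights, set binary ones
    loss = criterion(bc_model.forward(inputs), targets)
    optimizer.zero_grad()
    loss.backward()           # gradients are computed with the binary weights
    bc_model.restore()        # put the full-precision weights back
    optimizer.step()          # update the full-precision weights
    bc_model.clip()           # keep them in [-1, 1]
    print("one BinaryConnect step done, loss =", loss.item())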