Skip to content

Commit

Permalink
formatting + use nn.MaxPool2d if no SpatialMaxUnpooling
Browse files Browse the repository at this point in the history
  • Loading branch information
Anatoly Baksheev committed Jul 6, 2018
1 parent e77709f commit a673b35
Showing 1 changed file with 59 additions and 34 deletions.
93 changes: 59 additions & 34 deletions convert_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,29 @@ def add_submodule(seq, *args):

class Convertor(object):

def __init__(self):
def __init__(self, model):
self.prefix_code = []
self.t2pt_names = dict()
self.t2pt_layers = dict()

def search_max_unpool(model):
modules = []
modules.extend(model.modules)
containers = ['Sequential', 'Concat']

while modules:
m = modules.pop()
name = type(m).__name__
if name in containers:
modules.extend(m.modules)

if name == 'SpatialMaxUnpooling':
return True

return False

self.have_max_unpool = search_max_unpool(model)

def lua_recursive_model(self, module, seq):
for m in module.modules:
name = type(m).__name__
Expand Down Expand Up @@ -69,9 +87,11 @@ def lua_recursive_model(self, module, seq):
n = nn.Sigmoid()
add_submodule(seq, n)
elif name == 'SpatialMaxPooling':
# n = nn.MaxPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), ceil_mode=m.ceil_mode)
n = StatefulMaxPool2d((m.kH, m.kW), (m.dH, m.dW), (m.padH, m.padW), ceil_mode=m.ceil_mode)
self.t2pt_layers[m] = n
if not self.have_max_unpool:
n = nn.MaxPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), ceil_mode=m.ceil_mode)
else:
n = StatefulMaxPool2d((m.kH, m.kW), (m.dH, m.dW), (m.padH, m.padW), ceil_mode=m.ceil_mode)
self.t2pt_layers[m] = n
add_submodule(seq, n)
elif name == 'SpatialMaxUnpooling':
if m.pooling in self.t2pt_layers:
Expand Down Expand Up @@ -164,30 +184,33 @@ def lua_recursive_source(self, module):

if name == 'SpatialConvolution' or name == 'nn.SpatialConvolutionMM':
if not hasattr(m, 'groups') or m.groups is None: m.groups = 1
s += ['nn.Conv2d({},{},{},{},{},{},{},bias={}),#Conv2d'.format(m.nInputPlane,
m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), 1, m.groups, m.bias is not None)]
s += ['nn.Conv2d({}, {}, {}, {}, {}, {}, {},bias={}), #Conv2d'.format(m.nInputPlane,
m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), 1, m.groups,
m.bias is not None)]
elif name == 'SpatialBatchNormalization':
s += ['nn.BatchNorm2d({},{},{},{}),#BatchNorm2d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)]
s += ['nn.BatchNorm2d({}, {}, {}, {}), #BatchNorm2d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)]
elif name == 'VolumetricBatchNormalization':
s += ['nn.BatchNorm3d({},{},{},{}),#BatchNorm3d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)]
elif name == 'ReLU':
s += ['nn.ReLU()']
elif name == 'Sigmoid':
s += ['nn.Sigmoid()']
elif name == 'SpatialMaxPooling':
# s += ['nn.MaxPool2d({},{},{},ceil_mode={}),#MaxPool2d'.format((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)]
suffixes = sorted(int(re.match('pooling_(\d*)', v).group(1)) for v in self.t2pt_names.values())
name = 'pooling_{}'.format(suffixes[-1] + 1 if suffixes else 1)
s += [name]
self.t2pt_names[m] = name
self.prefix_code += ['{} = StatefulMaxPool2d({}, {}, {}, ceil_mode={})'.format(name, (m.kH, m.kW), (m.dH, m.dW), (m.padH, m.padW), m.ceil_mode)]
if not self.have_max_unpool:
s += ['nn.MaxPool2d({}, {}, {}, ceil_mode={}), #MaxPool2d'.format((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)]
else:
suffixes = sorted(int(re.match('pooling_(\d*)', v).group(1)) for v in self.t2pt_names.values())
name = 'pooling_{}'.format(suffixes[-1] + 1 if suffixes else 1)
s += [name]
self.t2pt_names[m] = name
self.prefix_code += ['{} = StatefulMaxPool2d({}, {}, {}, ceil_mode={})'.format(name, (m.kH, m.kW), (m.dH, m.dW), (m.padH, m.padW), m.ceil_mode)]
elif name == 'SpatialMaxUnpooling':
if m.pooling in self.t2pt_names:
s += ['StatefulMaxUnpool2d({}), #SpatialMaxUnpooling'.format(self.t2pt_names[m.pooling])]
else:
s += ['# ' + name + ' Not Implement (can\'t find corresponding SpatialMaxUnpooling,\n']
elif name == 'SpatialAveragePooling':
s += ['nn.AvgPool2d({},{},{},ceil_mode={}),#AvgPool2d'.format((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)]
s += ['nn.AvgPool2d({}, {}, {}, ceil_mode={}), #AvgPool2d'.format((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)]
elif name == 'SpatialUpSamplingNearest':
s += ['nn.UpsamplingNearest2d(scale_factor={})'.format(m.scale_factor)]
elif name == 'View':
Expand All @@ -197,7 +220,7 @@ def lua_recursive_source(self, module):
elif name == 'Linear':
s1 = 'Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x )'
s2 = 'nn.Linear({},{},bias={})'.format(m.weight.size(1), m.weight.size(0), (m.bias is not None))
s += ['nn.Sequential({},{}),#Linear'.format(s1, s2)]
s += ['nn.Sequential({}, {}), #Linear'.format(s1, s2)]
elif name == 'Dropout':
s += ['nn.Dropout({})'.format(m.p)]
elif name == 'SoftMax':
Expand Down Expand Up @@ -245,20 +268,20 @@ def lua_recursive_source(self, module):

@staticmethod
def simplify_source(s):
s = map(lambda x: x.replace(',(1, 1),(0, 0),1,1,bias=True),#Conv2d', ')'), s)
s = map(lambda x: x.replace(',(0, 0),1,1,bias=True),#Conv2d', ')'), s)
s = map(lambda x: x.replace(',1,1,bias=True),#Conv2d', ')'), s)
s = map(lambda x: x.replace(',bias=True),#Conv2d', ')'), s)
s = map(lambda x: x.replace('),#Conv2d', ')'), s)
s = map(lambda x: x.replace(',1e-05,0.1,True),#BatchNorm2d', ')'), s)
s = map(lambda x: x.replace('),#BatchNorm2d', ')'), s)
s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#MaxPool2d', ')'), s)
s = map(lambda x: x.replace(',ceil_mode=False),#MaxPool2d', ')'), s)
s = map(lambda x: x.replace('),#MaxPool2d', ')'), s)
s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#AvgPool2d', ')'), s)
s = map(lambda x: x.replace(',ceil_mode=False),#AvgPool2d', ')'), s)
s = map(lambda x: x.replace(',bias=True)),#Linear', ')), # Linear'), s)
s = map(lambda x: x.replace(')),#Linear', ')), # Linear'), s)
s = map(lambda x: x.replace(', (1, 1), (0, 0), 1, 1, bias=True), #Conv2d', ')'), s)
s = map(lambda x: x.replace(', (0, 0), 1, 1, bias=True), #Conv2d', ')'), s)
s = map(lambda x: x.replace(', 1, 1, bias=True), #Conv2d', ')'), s)
s = map(lambda x: x.replace(', bias=True), #Conv2d', ')'), s)
s = map(lambda x: x.replace('), #Conv2d', ')'), s)
s = map(lambda x: x.replace(', 1e-05, 0.1, True), #BatchNorm2d', ')'), s)
s = map(lambda x: x.replace('), #BatchNorm2d', ')'), s)
s = map(lambda x: x.replace(', (0, 0), ceil_mode=False), #MaxPool2d', ')'), s)
s = map(lambda x: x.replace(', ceil_mode=False), #MaxPool2d', ')'), s)
s = map(lambda x: x.replace('), #MaxPool2d', ')'), s)
s = map(lambda x: x.replace(', (0, 0), ceil_mode=False), #AvgPool2d', ')'), s)
s = map(lambda x: x.replace(', ceil_mode=False), #AvgPool2d', ')'), s)
s = map(lambda x: x.replace(', bias=True)), #Linear', ')), # Linear'), s)
s = map(lambda x: x.replace(')), #Linear', ')), # Linear'), s)

s = map(lambda x: '{},\n'.format(x), s)
s = map(lambda x: x[1:], s)
Expand All @@ -272,17 +295,19 @@ def torch_to_pytorch(t7_filename, outputname=None):
model = model.model
model.gradInput = None

cvt = Convertor()
slist = cvt.lua_recursive_source(lnn.Sequential().add(model))
s = cvt.simplify_source(slist)
cvt = Convertor(model)
s = cvt.lua_recursive_source(lnn.Sequential().add(model))
s = cvt.simplify_source(s)

varname = os.path.basename(t7_filename).replace('.t7', '').replace('.', '_').replace('-', '_')

with open("header.py") as f:
header = f.read()
s = '{}\n{}\n\n{} = {}'.format(header, '\n'.join(cvt.prefix_code), varname, s[:-2])

if outputname is None: outputname = varname
if outputname is None:
outputname = varname

with open(outputname + '.py', "w") as pyfile:
pyfile.write(s)

Expand All @@ -294,7 +319,7 @@ def torch_to_pytorch(t7_filename, outputname=None):
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Convert torch t7 model to pytorch')
parser.add_argument('--model', '-m', type=str, required=True, help='torch model file in t7 format')
parser.add_argument('--output', '-o', type=str, default=None, help='output file name prefix, xxx.py xxx.pth')
parser.add_argument('--output', '-o', type=str, default='/tmp/model', help='output file name prefix, xxx.py xxx.pth')
args = parser.parse_args()

torch_to_pytorch(args.model, args.output)

0 comments on commit a673b35

Please sign in to comment.