From a673b35b2d58bce02c139542cb207a5f5b89cde4 Mon Sep 17 00:00:00 2001 From: Anatoly Baksheev Date: Fri, 6 Jul 2018 13:40:39 +0300 Subject: [PATCH] formatting + use nn.MaxPool2d if no SpatialMaxUnpooling --- convert_torch.py | 93 ++++++++++++++++++++++++++++++------------------ 1 file changed, 59 insertions(+), 34 deletions(-) diff --git a/convert_torch.py b/convert_torch.py index 53d49a8..dd0a687 100644 --- a/convert_torch.py +++ b/convert_torch.py @@ -36,11 +36,29 @@ def add_submodule(seq, *args): class Convertor(object): - def __init__(self): + def __init__(self, model): self.prefix_code = [] self.t2pt_names = dict() self.t2pt_layers = dict() + def search_max_unpool(model): + modules = [] + modules.extend(model.modules) + containers = ['Sequential', 'Concat'] + + while modules: + m = modules.pop() + name = type(m).__name__ + if name in containers: + modules.extend(m.modules) + + if name == 'SpatialMaxUnpooling': + return True + + return False + + self.have_max_unpool = search_max_unpool(model) + def lua_recursive_model(self, module, seq): for m in module.modules: name = type(m).__name__ @@ -69,9 +87,11 @@ def lua_recursive_model(self, module, seq): n = nn.Sigmoid() add_submodule(seq, n) elif name == 'SpatialMaxPooling': - # n = nn.MaxPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), ceil_mode=m.ceil_mode) - n = StatefulMaxPool2d((m.kH, m.kW), (m.dH, m.dW), (m.padH, m.padW), ceil_mode=m.ceil_mode) - self.t2pt_layers[m] = n + if not self.have_max_unpool: + n = nn.MaxPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), ceil_mode=m.ceil_mode) + else: + n = StatefulMaxPool2d((m.kH, m.kW), (m.dH, m.dW), (m.padH, m.padW), ceil_mode=m.ceil_mode) + self.t2pt_layers[m] = n add_submodule(seq, n) elif name == 'SpatialMaxUnpooling': if m.pooling in self.t2pt_layers: @@ -164,10 +184,11 @@ def lua_recursive_source(self, module): if name == 'SpatialConvolution' or name == 'nn.SpatialConvolutionMM': if not hasattr(m, 'groups') or m.groups is None: m.groups = 1 - s += ['nn.Conv2d({},{},{},{},{},{},{},bias={}),#Conv2d'.format(m.nInputPlane, - m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), 1, m.groups, m.bias is not None)] + s += ['nn.Conv2d({}, {}, {}, {}, {}, {}, {},bias={}), #Conv2d'.format(m.nInputPlane, + m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), 1, m.groups, + m.bias is not None)] elif name == 'SpatialBatchNormalization': - s += ['nn.BatchNorm2d({},{},{},{}),#BatchNorm2d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)] + s += ['nn.BatchNorm2d({}, {}, {}, {}), #BatchNorm2d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)] elif name == 'VolumetricBatchNormalization': s += ['nn.BatchNorm3d({},{},{},{}),#BatchNorm3d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)] elif name == 'ReLU': @@ -175,19 +196,21 @@ def lua_recursive_source(self, module): elif name == 'Sigmoid': s += ['nn.Sigmoid()'] elif name == 'SpatialMaxPooling': - # s += ['nn.MaxPool2d({},{},{},ceil_mode={}),#MaxPool2d'.format((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)] - suffixes = sorted(int(re.match('pooling_(\d*)', v).group(1)) for v in self.t2pt_names.values()) - name = 'pooling_{}'.format(suffixes[-1] + 1 if suffixes else 1) - s += [name] - self.t2pt_names[m] = name - self.prefix_code += ['{} = StatefulMaxPool2d({}, {}, {}, ceil_mode={})'.format(name, (m.kH, m.kW), (m.dH, m.dW), (m.padH, m.padW), m.ceil_mode)] + if not self.have_max_unpool: + s += ['nn.MaxPool2d({}, {}, {}, ceil_mode={}), #MaxPool2d'.format((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)] + else: + suffixes = sorted(int(re.match('pooling_(\d*)', v).group(1)) for v in self.t2pt_names.values()) + name = 'pooling_{}'.format(suffixes[-1] + 1 if suffixes else 1) + s += [name] + self.t2pt_names[m] = name + self.prefix_code += ['{} = StatefulMaxPool2d({}, {}, {}, ceil_mode={})'.format(name, (m.kH, m.kW), (m.dH, m.dW), (m.padH, m.padW), m.ceil_mode)] elif name == 'SpatialMaxUnpooling': if m.pooling in self.t2pt_names: s += ['StatefulMaxUnpool2d({}), #SpatialMaxUnpooling'.format(self.t2pt_names[m.pooling])] else: s += ['# ' + name + ' Not Implement (can\'t find corresponding SpatialMaxUnpooling,\n'] elif name == 'SpatialAveragePooling': - s += ['nn.AvgPool2d({},{},{},ceil_mode={}),#AvgPool2d'.format((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)] + s += ['nn.AvgPool2d({}, {}, {}, ceil_mode={}), #AvgPool2d'.format((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)] elif name == 'SpatialUpSamplingNearest': s += ['nn.UpsamplingNearest2d(scale_factor={})'.format(m.scale_factor)] elif name == 'View': @@ -197,7 +220,7 @@ def lua_recursive_source(self, module): elif name == 'Linear': s1 = 'Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x )' s2 = 'nn.Linear({},{},bias={})'.format(m.weight.size(1), m.weight.size(0), (m.bias is not None)) - s += ['nn.Sequential({},{}),#Linear'.format(s1, s2)] + s += ['nn.Sequential({}, {}), #Linear'.format(s1, s2)] elif name == 'Dropout': s += ['nn.Dropout({})'.format(m.p)] elif name == 'SoftMax': @@ -245,20 +268,20 @@ def lua_recursive_source(self, module): @staticmethod def simplify_source(s): - s = map(lambda x: x.replace(',(1, 1),(0, 0),1,1,bias=True),#Conv2d', ')'), s) - s = map(lambda x: x.replace(',(0, 0),1,1,bias=True),#Conv2d', ')'), s) - s = map(lambda x: x.replace(',1,1,bias=True),#Conv2d', ')'), s) - s = map(lambda x: x.replace(',bias=True),#Conv2d', ')'), s) - s = map(lambda x: x.replace('),#Conv2d', ')'), s) - s = map(lambda x: x.replace(',1e-05,0.1,True),#BatchNorm2d', ')'), s) - s = map(lambda x: x.replace('),#BatchNorm2d', ')'), s) - s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#MaxPool2d', ')'), s) - s = map(lambda x: x.replace(',ceil_mode=False),#MaxPool2d', ')'), s) - s = map(lambda x: x.replace('),#MaxPool2d', ')'), s) - s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#AvgPool2d', ')'), s) - s = map(lambda x: x.replace(',ceil_mode=False),#AvgPool2d', ')'), s) - s = map(lambda x: x.replace(',bias=True)),#Linear', ')), # Linear'), s) - s = map(lambda x: x.replace(')),#Linear', ')), # Linear'), s) + s = map(lambda x: x.replace(', (1, 1), (0, 0), 1, 1, bias=True), #Conv2d', ')'), s) + s = map(lambda x: x.replace(', (0, 0), 1, 1, bias=True), #Conv2d', ')'), s) + s = map(lambda x: x.replace(', 1, 1, bias=True), #Conv2d', ')'), s) + s = map(lambda x: x.replace(', bias=True), #Conv2d', ')'), s) + s = map(lambda x: x.replace('), #Conv2d', ')'), s) + s = map(lambda x: x.replace(', 1e-05, 0.1, True), #BatchNorm2d', ')'), s) + s = map(lambda x: x.replace('), #BatchNorm2d', ')'), s) + s = map(lambda x: x.replace(', (0, 0), ceil_mode=False), #MaxPool2d', ')'), s) + s = map(lambda x: x.replace(', ceil_mode=False), #MaxPool2d', ')'), s) + s = map(lambda x: x.replace('), #MaxPool2d', ')'), s) + s = map(lambda x: x.replace(', (0, 0), ceil_mode=False), #AvgPool2d', ')'), s) + s = map(lambda x: x.replace(', ceil_mode=False), #AvgPool2d', ')'), s) + s = map(lambda x: x.replace(', bias=True)), #Linear', ')), # Linear'), s) + s = map(lambda x: x.replace(')), #Linear', ')), # Linear'), s) s = map(lambda x: '{},\n'.format(x), s) s = map(lambda x: x[1:], s) @@ -272,9 +295,9 @@ def torch_to_pytorch(t7_filename, outputname=None): model = model.model model.gradInput = None - cvt = Convertor() - slist = cvt.lua_recursive_source(lnn.Sequential().add(model)) - s = cvt.simplify_source(slist) + cvt = Convertor(model) + s = cvt.lua_recursive_source(lnn.Sequential().add(model)) + s = cvt.simplify_source(s) varname = os.path.basename(t7_filename).replace('.t7', '').replace('.', '_').replace('-', '_') @@ -282,7 +305,9 @@ def torch_to_pytorch(t7_filename, outputname=None): header = f.read() s = '{}\n{}\n\n{} = {}'.format(header, '\n'.join(cvt.prefix_code), varname, s[:-2]) - if outputname is None: outputname = varname + if outputname is None: + outputname = varname + with open(outputname + '.py', "w") as pyfile: pyfile.write(s) @@ -294,7 +319,7 @@ def torch_to_pytorch(t7_filename, outputname=None): if __name__ == '__main__': parser = argparse.ArgumentParser(description='Convert torch t7 model to pytorch') parser.add_argument('--model', '-m', type=str, required=True, help='torch model file in t7 format') - parser.add_argument('--output', '-o', type=str, default=None, help='output file name prefix, xxx.py xxx.pth') + parser.add_argument('--output', '-o', type=str, default='/tmp/model', help='output file name prefix, xxx.py xxx.pth') args = parser.parse_args() torch_to_pytorch(args.model, args.output)