diff --git a/convert_torch.py b/convert_torch.py
index b639f73..2bbd528 100644
--- a/convert_torch.py
+++ b/convert_torch.py
@@ -149,6 +149,19 @@ def lua_recursive_model(module,seq):
             n = LambdaReduce(lambda x,y,dim=dim: torch.cat((x,y),dim))
             lua_recursive_model(m,n)
             add_submodule(seq,n)
+        elif name == 'Tanh':
+            n = nn.Tanh()
+            add_submodule(seq, n)
+        elif name == 'MulConstant':
+            n = Lambda(lambda x: x * m.constant_scalar)
+            add_submodule(seq, n)
+        elif name == 'SpatialZeroPadding':
+            n = nn.ConstantPad2d((m.pad_l, m.pad_r, m.pad_t, m.pad_b), 0)
+            add_submodule(seq, n)
+        elif name == 'InstanceNormalization':
+            n = m._instance_norm
+            copy_param(m, n)  # TODO: verify copy_param maps InstanceNormalization weights correctly
+            add_submodule(seq, n)
         elif name == 'TorchObject':
             print('Not Implement',name,real._typename)
         else:
@@ -229,6 +242,14 @@ def lua_recursive_source(module):
             s += ['LambdaReduce(lambda x,y,dim={}: torch.cat((x,y),dim), # Concat'.format(m.dimension)]
             s += lua_recursive_source(m)
             s += [')']
+        elif name == 'Tanh':
+            s += ['nn.Tanh(),']
+        elif name == 'MulConstant':
+            s += ['Lambda(lambda x: x*{}), # MulConstant'.format(m.constant_scalar)]
+        elif name == 'SpatialZeroPadding':
+            s += ['nn.ConstantPad2d({}, 0),'.format((m.pad_l, m.pad_r, m.pad_t, m.pad_b))]
+        elif name == 'InstanceNormalization':
+            s += ['nn.InstanceNorm2d({},{},{},{}), # InstanceNorm2d'.format(m._instance_norm.num_features, m.eps, m._instance_norm.momentum, m._instance_norm.affine)]
         else:
             s += '# ' + name + ' Not Implement,\n'
     s = map(lambda x: '\t{}'.format(x),s)
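
The mappings above can be sanity-checked outside convert_torch.py by exercising the target PyTorch modules directly. Below is a minimal sketch, assuming only torch is installed; the constant, pad sizes, and feature count are arbitrary illustration values, and the converter's Lambda wrapper is approximated with a plain closure:

import torch
import torch.nn as nn

x = torch.randn(1, 3, 8, 8)

# Tanh: direct 1:1 mapping.
assert torch.allclose(nn.Tanh()(x), torch.tanh(x))

# MulConstant -> Lambda(lambda x: x * c); a plain closure behaves the same
# (c = 0.5 is an arbitrary example constant).
mul_constant = lambda t, c=0.5: t * c
assert torch.allclose(mul_constant(x), x * 0.5)

# SpatialZeroPadding(pad_l, pad_r, pad_t, pad_b) -> nn.ConstantPad2d((l, r, t, b), 0):
# width grows by l+r, height by t+b.
pad = nn.ConstantPad2d((1, 2, 3, 4), 0)
assert pad(x).shape == (1, 3, 8 + 3 + 4, 8 + 1 + 2)

# InstanceNormalization -> nn.InstanceNorm2d(num_features, eps, momentum, affine),
# matching the positional argument order emitted by lua_recursive_source.
norm = nn.InstanceNorm2d(3, 1e-5, 0.1, True)
assert norm(x).shape == x.shape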