From 13638ab03003483252b8798a31da4177ed9c6cc9 Mon Sep 17 00:00:00 2001 From: Yury Adamov Date: Sat, 6 Oct 2018 01:20:34 +0300 Subject: [PATCH] Added SpatialDilatedConvolution handling --- convert_torch.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/convert_torch.py b/convert_torch.py index b639f73..479ce2f 100644 --- a/convert_torch.py +++ b/convert_torch.py @@ -149,6 +149,11 @@ def lua_recursive_model(module,seq): n = LambdaReduce(lambda x,y,dim=dim: torch.cat((x,y),dim)) lua_recursive_model(m,n) add_submodule(seq,n) + elif name == 'SpatialDilatedConvolution': + if not hasattr(m,'groups') or m.groups is None: m.groups=1 + n = nn.Conv2d(m.nInputPlane,m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),(m.dilationW,m.dilationH),m.groups,bias=(m.bias is not None)) + copy_param(m,n) + add_submodule(seq,n) elif name == 'TorchObject': print('Not Implement',name,real._typename) else: @@ -168,6 +173,10 @@ def lua_recursive_source(module): if not hasattr(m,'groups') or m.groups is None: m.groups=1 s += ['nn.Conv2d({},{},{},{},{},{},{},bias={}),#Conv2d'.format(m.nInputPlane, m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),1,m.groups,m.bias is not None)] + elif name == 'SpatialDilatedConvolution': + if not hasattr(m,'groups') or m.groups is None: m.groups=1 + s += ['nn.Conv2d({},{},{},{},{},{},{},bias={}),#Conv2d'.format(m.nInputPlane, + m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),(m.dilationW,m.dilationH),m.groups,m.bias is not None)] elif name == 'SpatialBatchNormalization': s += ['nn.BatchNorm2d({},{},{},{}),#BatchNorm2d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)] elif name == 'VolumetricBatchNormalization':