From 191b902d213be2185b2567fd82b9ea45a512221d Mon Sep 17 00:00:00 2001
From: LH23
Date: Mon, 10 Jul 2023 20:04:36 -0300
Subject: [PATCH 1/4] Added tanh

---
 convert_torch.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/convert_torch.py b/convert_torch.py
index b639f73..96e672a 100644
--- a/convert_torch.py
+++ b/convert_torch.py
@@ -149,6 +149,9 @@ def lua_recursive_model(module,seq):
             n = LambdaReduce(lambda x,y,dim=dim: torch.cat((x,y),dim))
             lua_recursive_model(m,n)
             add_submodule(seq,n)
+        elif name == 'Tanh':
+            n = nn.Tanh()
+            add_submodule(seq, n)
         elif name == 'TorchObject':
             print('Not Implement',name,real._typename)
         else:
@@ -229,6 +232,8 @@ def lua_recursive_source(module):
             s += ['LambdaReduce(lambda x,y,dim={}: torch.cat((x,y),dim), # Concat'.format(m.dimension)]
             s += lua_recursive_source(m)
             s += [')']
+        elif name == 'Tanh':
+            s += ['nn.Tanh()']
         else:
             s += '# ' + name + ' Not Implement,\n'
     s = map(lambda x: '\t{}'.format(x),s)

From 39fdc1b1ee28ec12a4ab1e797d4d596586283529 Mon Sep 17 00:00:00 2001
From: LH23
Date: Mon, 10 Jul 2023 20:05:07 -0300
Subject: [PATCH 2/4] Added MulConstant

---
 convert_torch.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/convert_torch.py b/convert_torch.py
index 96e672a..f080e75 100644
--- a/convert_torch.py
+++ b/convert_torch.py
@@ -152,6 +152,9 @@ def lua_recursive_model(module,seq):
         elif name == 'Tanh':
             n = nn.Tanh()
             add_submodule(seq, n)
+        elif name == 'MulConstant':
+            n = Lambda(lambda x: x * m.constant_scalar)
+            add_submodule(seq, n)
         elif name == 'TorchObject':
             print('Not Implement',name,real._typename)
         else:
@@ -234,6 +237,8 @@ def lua_recursive_source(module):
             s += [')']
         elif name == 'Tanh':
             s += ['nn.Tanh()']
+        elif name == 'MulConstant':
+            s += ['Lambda(lambda x: x*{}), # MulConstant'.format(m.constant_scalar)]
         else:
             s += '# ' + name + ' Not Implement,\n'
     s = map(lambda x: '\t{}'.format(x),s)

From eda8861e9a74611aabce820289026997b2b083d2 Mon Sep 17 00:00:00 2001
From: LH23
Date: Mon, 10 Jul 2023 20:05:41 -0300
Subject: [PATCH 3/4] Added SpatialZeroPadding

---
 convert_torch.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/convert_torch.py b/convert_torch.py
index f080e75..be519a1 100644
--- a/convert_torch.py
+++ b/convert_torch.py
@@ -155,6 +155,9 @@ def lua_recursive_model(module,seq):
         elif name == 'MulConstant':
             n = Lambda(lambda x: x * m.constant_scalar)
             add_submodule(seq, n)
+        elif name == 'SpatialZeroPadding':
+            n = nn.ConstantPad2d((m.pad_l, m.pad_r, m.pad_t, m.pad_b), 0)
+            add_submodule(seq, n)
         elif name == 'TorchObject':
             print('Not Implement',name,real._typename)
         else:
@@ -239,6 +242,8 @@ def lua_recursive_source(module):
             s += ['nn.Tanh()']
         elif name == 'MulConstant':
             s += ['Lambda(lambda x: x*{}), # MulConstant'.format(m.constant_scalar)]
+        elif name == 'SpatialZeroPadding':
+            s += ['nn.ConstantPad2d({}, 0)'.format((m.pad_l, m.pad_r, m.pad_t, m.pad_b))]
         else:
             s += '# ' + name + ' Not Implement,\n'
     s = map(lambda x: '\t{}'.format(x),s)

From d8e88a70ab318e7f0de27fc198625c3aab85255a Mon Sep 17 00:00:00 2001
From: LH23
Date: Mon, 10 Jul 2023 20:06:57 -0300
Subject: [PATCH 4/4] Added InstanceNormalization (2d)

---
 convert_torch.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/convert_torch.py b/convert_torch.py
index be519a1..2bbd528 100644
--- a/convert_torch.py
+++ b/convert_torch.py
@@ -158,6 +158,10 @@ def lua_recursive_model(module,seq):
         elif name == 'SpatialZeroPadding':
             n = nn.ConstantPad2d((m.pad_l, m.pad_r, m.pad_t, m.pad_b), 0)
             add_submodule(seq, n)
+        elif name == 'InstanceNormalization':
+            n = m._instance_norm
+            copy_param(m, n) # check this
+            add_submodule(seq, n)
         elif name == 'TorchObject':
             print('Not Implement',name,real._typename)
         else:
@@ -244,6 +248,8 @@ def lua_recursive_source(module):
             s += ['Lambda(lambda x: x*{}), # MulConstant'.format(m.constant_scalar)]
         elif name == 'SpatialZeroPadding':
             s += ['nn.ConstantPad2d({}, 0)'.format((m.pad_l, m.pad_r, m.pad_t, m.pad_b))]
+        elif name == 'InstanceNormalization':
+            s += ['nn.InstanceNorm2d({},{},{},{}), # InstanceNorm2d'.format(m._instance_norm.num_features, m.eps, m._instance_norm.momentum, m._instance_norm.affine)]
         else:
             s += '# ' + name + ' Not Implement,\n'
     s = map(lambda x: '\t{}'.format(x),s)
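
As a quick sanity check of what these four patches map to on the PyTorch side,
here is a minimal standalone sketch (not part of the patch series itself). The
Lambda wrapper below is an assumed stand-in for the helper that convert_torch.py
emits into its generated model source, and the input shape, scalar constant,
padding sizes, and InstanceNorm2d arguments are arbitrary example values:

    import torch
    import torch.nn as nn

    # Assumed minimal stand-in for the Lambda helper used by the generated source.
    class Lambda(nn.Module):
        def __init__(self, fn):
            super(Lambda, self).__init__()
            self.fn = fn
        def forward(self, x):
            return self.fn(x)

    x = torch.randn(1, 3, 8, 8)                    # arbitrary NCHW input

    tanh  = nn.Tanh()                              # Lua nn.Tanh           -> nn.Tanh
    mul   = Lambda(lambda t: t * 0.5)              # MulConstant(0.5)      -> Lambda multiply
    pad   = nn.ConstantPad2d((1, 1, 2, 2), 0)      # SpatialZeroPadding(pad_l=1, pad_r=1, pad_t=2, pad_b=2)
    inorm = nn.InstanceNorm2d(3, 1e-5, 0.1, True)  # InstanceNormalization -> nn.InstanceNorm2d

    print(tanh(x).shape)                  # torch.Size([1, 3, 8, 8])
    print(torch.allclose(mul(x), x * 0.5))  # True
    print(pad(x).shape)                   # torch.Size([1, 3, 12, 10]): H += pad_t+pad_b, W += pad_l+pad_r
    print(inorm(x).shape)                 # torch.Size([1, 3, 8, 8])

Note that ConstantPad2d takes its padding as (left, right, top, bottom), which is
why the patches pass (m.pad_l, m.pad_r, m.pad_t, m.pad_b) in that order.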