# -*- coding: utf-8 -*-
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.nn.init import kaiming_normal_, calculate_gain
from PIL import Image
import numpy as np
import copy
__author__ = 'Rahul Bhalley'

class ConcatTable(nn.Module):
    '''Concatenates the outputs of two layers into a list
    '''
def __init__(self, layer1, layer2):
super(ConcatTable, self).__init__()
self.layer1 = layer1
self.layer2 = layer2
def forward(self, x):
return [self.layer1(x), self.layer2(x)]

class Flatten(nn.Module):
    '''Flattens a convolutional feature map into a 2-D tensor of shape (batch, features)
    '''
def __init__(self):
super(Flatten, self).__init__()
def forward(self, x):
return x.view(x.size(0), -1)

class FadeInLayer(nn.Module):
    '''Fades the output of a newly added block into the network as `alpha` grows
    linearly from 0 to 1
    '''
def __init__(self, config):
super(FadeInLayer, self).__init__()
self.alpha = 0.0
def update_alpha(self, delta):
self.alpha = self.alpha + delta
self.alpha = max(0, min(self.alpha, 1.0))
    # The input `x` to `forward()` is the output of `ConcatTable`
    def forward(self, x):
        # `x[0]` is the `prev_block` output, faded out with weight (1.0 - `alpha`)
        # `x[1]` is the `next_block` output, faded in with weight `alpha`
        # `alpha` increases linearly, so the old path smoothly hands over to the new one
        # Both `x[0]` and `x[1]` are image-shaped tensors (the last block is `to_rgb_block`),
        # so `add()` can blend them into a single weighted output
return torch.add(x[0].mul(1.0 - self.alpha), x[1].mul(self.alpha)) # outputs one value
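
# Illustrative usage sketch (not part of the original module): during a fade-in
# phase, `ConcatTable` and `FadeInLayer` are typically stacked so the old and new
# resolution blocks are blended while `alpha` is stepped toward 1. The names
# `prev_to_rgb_block`, `next_block`, and `fade_in_iters` below are hypothetical.
#
#   fade = FadeInLayer(config=None)
#   blend = nn.Sequential(ConcatTable(prev_to_rgb_block, next_block), fade)
#   for _ in range(fade_in_iters):
#       fade.update_alpha(1.0 / fade_in_iters)
#       out = blend(x)  # (1 - alpha) * old path + alpha * new path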

class MinibatchSTDConcatLayer(nn.Module):
    '''Concatenates a minibatch standard deviation statistic to the input as an
    extra feature map (the minibatch stddev trick from progressive GAN training)
    '''
def __init__(self, averaging='all'):
super(MinibatchSTDConcatLayer, self).__init__()
self.averaging = averaging.lower()
if 'group' in self.averaging:
self.n = int(self.averaging[5:])
else:
            assert self.averaging in ['all', 'flat', 'spatial', 'none', 'gpool'], 'Invalid averaging mode %s' % self.averaging
self.adjusted_std = lambda x, **kwargs: torch.sqrt(torch.mean((x - torch.mean(x, **kwargs)) ** 2, **kwargs) + 1e-8)
def forward(self, x):
shape = list(x.size())
target_shape = copy.deepcopy(shape)
vals = self.adjusted_std(x, dim=0, keepdim=True)
if self.averaging == 'all':
target_shape[1] = 1
vals = torch.mean(vals, dim=1, keepdim=True)
elif self.averaging == 'spatial':
if len(shape) == 4:
vals = torch.mean(torch.mean(vals, 2, keepdim=True), 3, keepdim=True)
elif self.averaging == 'none':
target_shape = [target_shape[0]] + [s for s in target_shape[1:]]
elif self.averaging == 'gpool':
if len(shape) == 4:
vals = torch.mean(torch.mean(torch.mean(vals, 2, keepdim=True), 3, keepdim=True), 0, keepdim=True)
elif self.averaging == 'flat':
target_shape[1] = 1
vals = torch.FloatTensor([self.adjusted_std(x)])
else: # self.averaging == 'group'
target_shape[1] = self.n
            vals = vals.view(self.n, shape[1] // self.n, shape[2], shape[3])
            vals = torch.mean(vals, dim=0, keepdim=True).view(1, self.n, 1, 1)
vals = vals.expand(*target_shape)
return torch.cat([x, vals], 1)
def __repr__(self):
return self.__class__.__name__ + '(averaging = {})'.format(self.averaging)
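
# Illustrative sanity check (assumed shapes, not part of the original module): the
# layer appends one extra feature map holding the (averaged) minibatch standard
# deviation, so the channel count grows by one while the spatial size is unchanged.
#
#   >>> layer = MinibatchSTDConcatLayer(averaging='all')
#   >>> layer(torch.randn(4, 512, 4, 4)).size()
#   torch.Size([4, 513, 4, 4])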

class PixelwiseNormLayer(nn.Module):
    '''Normalizes each pixel's feature vector to approximately unit length across
    the channel dimension (pixelwise feature vector normalization)
    '''
def __init__(self):
super(PixelwiseNormLayer, self).__init__()
self.eps = 1e-8
def forward(self, x):
return x / (torch.mean(x ** 2, dim=1, keepdim=True) + self.eps) ** 0.5
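
# Illustrative note (not part of the original module): each pixel's feature vector
# is divided by its own root mean square over channels, so after the layer the
# channel-wise mean of squares at every spatial location is approximately 1.
#
#   >>> norm = PixelwiseNormLayer()
#   >>> y = norm(torch.randn(2, 16, 8, 8))
#   >>> torch.mean(y ** 2, dim=1)  # close to all ones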

class EqualizedConv2d(nn.Module):
    '''Equalizes the learning rate for a convolutional layer
    '''
def __init__(self, c_in, c_out, k_size, stride, pad, bias=False):
super(EqualizedConv2d, self).__init__()
self.conv = nn.Conv2d(c_in, c_out, k_size, stride, pad, bias=False)
kaiming_normal_(self.conv.weight, a=calculate_gain('conv2d'))
# Scaling the weights for equalized learning
conv_w = self.conv.weight.data.clone()
self.bias = torch.nn.Parameter(torch.FloatTensor(c_out).fill_(0))
self.scale = (torch.mean(self.conv.weight.data ** 2)) ** 0.5
self.conv.weight.data.copy_(self.conv.weight.data / self.scale) # for equalized learning rate
def forward(self, x):
x = self.conv(x.mul(self.scale))
return x + self.bias.view(1, -1, 1, 1).expand_as(x)

class EqualizedDeconv2d(nn.Module):
    '''Equalizes the learning rate for a transposed convolutional layer
    '''
def __init__(self, c_in, c_out, k_size, stride, pad):
super(EqualizedDeconv2d, self).__init__()
self.deconv = nn.ConvTranspose2d(c_in, c_out, k_size, stride, pad, bias=False)
kaiming_normal_(self.deconv.weight, a=calculate_gain('conv2d'))
# Scaling the weights for equalized learning
deconv_w = self.deconv.weight.data.clone()
self.bias = torch.nn.Parameter(torch.FloatTensor(c_out).fill_(0))
self.scale = (torch.mean(self.deconv.weight.data ** 2)) ** 0.5
self.deconv.weight.data.copy_(self.deconv.weight.data / self.scale)
def forward(self, x):
x = self.deconv(x.mul(self.scale))
return x + self.bias.view(1, -1, 1, 1).expand_as(x)

class EqualizedLinear(nn.Module):
    '''Equalizes the learning rate for a linear (fully connected) layer
    '''
def __init__(self, c_in, c_out):
super(EqualizedLinear, self).__init__()
self.linear = nn.Linear(c_in, c_out, bias=False)
kaiming_normal_(self.linear.weight, a=calculate_gain('linear'))
# Scaling the weights for equalized learning
linear_w = self.linear.weight.data.clone()
self.bias = torch.nn.Parameter(torch.FloatTensor(c_out).fill_(0))
self.scale = (torch.mean(self.linear.weight.data ** 2)) ** 0.5
self.linear.weight.data.copy_(self.linear.weight.data / self.scale)
def forward(self, x):
x = self.linear(x.mul(self.scale))
return x + self.bias.view(1, -1).expand_as(x)
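
# Illustrative usage (assumed shapes, not part of the original module): the
# equalized layers are drop-in replacements for `nn.Conv2d`, `nn.ConvTranspose2d`,
# and `nn.Linear`; weights are rescaled to unit RMS at construction and the scale
# is multiplied back at runtime, equalizing the effective learning rate across layers.
#
#   >>> conv = EqualizedConv2d(c_in=3, c_out=16, k_size=3, stride=1, pad=1)
#   >>> conv(torch.randn(1, 3, 32, 32)).size()
#   torch.Size([1, 16, 32, 32])
#   >>> fc = EqualizedLinear(c_in=512, c_out=1)
#   >>> fc(torch.randn(4, 512)).size()
#   torch.Size([4, 1])
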
# The GeneralizedDropout class has been removed as it's not commonly used in PyTorch 2.x
# If needed, you can implement a custom dropout layer using torch.nn.functional.dropout
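
# A minimal sketch of such a dropout layer, built on `torch.nn.functional.dropout`
# (an illustrative assumption, not a restoration of the removed GeneralizedDropout):
class SimpleDropout(nn.Module):
    '''Applies standard dropout with probability `p`, active only in training mode
    '''
    def __init__(self, p=0.5):
        super(SimpleDropout, self).__init__()
        self.p = p
    def forward(self, x):
        # `self.training` is toggled by `model.train()` / `model.eval()`
        return F.dropout(x, p=self.p, training=self.training)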