I write some code about HDRNetCurves #24

Open
alexliyang opened this issue Feb 12, 2022 · 1 comment

@alexliyang

I think a 1x1 conv can replace the CCM (color correction matrix) block in google/hdrnet:
```python
# Color space change
idtity = np.identity(nchans, dtype=np.float32) + np.random.randn(1).astype(np.float32)*1e-4
ccm = tf.get_variable('ccm', dtype=tf.float32, initializer=idtity)
with tf.name_scope('ccm'):
  ccm_bias = tf.get_variable('ccm_bias', shape=[nchans,], dtype=tf.float32, initializer=tf.constant_initializer(0.0))

  guidemap = tf.matmul(tf.reshape(input_tensor, [-1, nchans]), ccm)
  guidemap = tf.nn.bias_add(guidemap, ccm_bias, name='ccm_bias_add')

  guidemap = tf.reshape(guidemap, tf.shape(input_tensor))
```
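The equivalence is easy to check: the TF block flattens the image to one row per pixel, right-multiplies every row by the same `nchans x nchans` matrix `ccm` and adds `ccm_bias`, which is exactly what a 1x1 `nn.Conv2d` with bias computes at every pixel. The sketch below assumes PyTorch's NCHW layout (hdrnet's TF code is NHWC, so the reshape works there directly) and uses random stand-ins for `ccm` and `ccm_bias`, not trained hdrnet weights. Because `pixels @ ccm` treats `ccm` as `[in, out]` while `Conv2d.weight` is `[out, in, 1, 1]`, the matrix is transposed when copied into the conv.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
nchans = 3

# Stand-ins for hdrnet's variables: a near-identity colour matrix and a small bias.
ccm = torch.eye(nchans) + 1e-4 * torch.randn(nchans, nchans)
ccm_bias = 0.01 * torch.randn(nchans)

x = torch.rand(2, nchans, 8, 8)  # NCHW batch

# TF-style path: move channels last, one row per pixel, right-multiply by ccm, add bias.
pixels = x.permute(0, 2, 3, 1).reshape(-1, nchans)  # [B*H*W, C]
matmul_out = (pixels @ ccm + ccm_bias).reshape(2, 8, 8, nchans).permute(0, 3, 1, 2)

# Conv path: a 1x1 convolution holding the same parameters.
conv = nn.Conv2d(nchans, nchans, kernel_size=1, bias=True)
with torch.no_grad():
    # Conv computes y[o] = sum_i W[o, i] * x[i] + b[o]; the matmul computes
    # y[o] = sum_i x[i] * ccm[i, o] + b[o], so W[o, i] = ccm[i, o] (transpose).
    conv.weight.copy_(ccm.t().reshape(nchans, nchans, 1, 1))
    conv.bias.copy_(ccm_bias)
    conv_out = conv(x)

print(torch.allclose(matmul_out, conv_out, atol=1e-5))  # True
```

Since the conv is initialised freely rather than from a near-identity matrix, it learns the same family of per-pixel affine colour transforms as the CCM; only the initialisation differs.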

So the PyTorch code looks like this:

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# ConvBlock is the existing conv helper from this repo (assumed already imported).


class GuideCurves(nn.Module):
    def __init__(self, npts=16):
        super().__init__()
        self.guide_pts = npts
        # 1x1 conv replacing hdrnet's colour-correction matrix (ccm + ccm_bias).
        self.ccm = ConvBlock(3, 3, kernel_size=1, padding=0, use_bias=True, activation=None, batch_norm=False)

        # Knot positions of the per-channel piecewise-linear curves: shape [3, 1, 1, npts].
        shifts = np.linspace(0, 1, self.guide_pts, endpoint=False, dtype=np.float32)
        shifts = np.tile(shifts[np.newaxis, np.newaxis, np.newaxis, :], (3, 1, 1, 1))
        self.shifts = nn.Parameter(torch.from_numpy(shifts))

        # Only the first segment starts with slope 1, so each curve is initially relu(x).
        slopes = np.zeros([1, 3, 1, 1, self.guide_pts], dtype=np.float32)
        slopes[:, :, :, :, 0] = 1.0
        self.slopes = nn.Parameter(torch.from_numpy(slopes))

        # Project the 3 curved channels down to a single-channel guide map.
        self.projection = ConvBlock(3, 1, kernel_size=1, padding=0, use_bias=True, activation=None, batch_norm=False)

    def forward(self, x):
        guidemap = self.ccm(x)                # [B, 3, H, W]
        guidemap = guidemap.unsqueeze(dim=4)  # [B, 3, H, W, 1]
        # Sum of slope-weighted ReLUs = one piecewise-linear curve per channel.
        guidemap = (self.slopes * F.relu(guidemap - self.shifts)).sum(dim=4)
        guidemap = self.projection(guidemap)  # [B, 1, H, W]
        # Clamp the guide to [0, 1].
        return F.hardtanh(guidemap, min_val=0, max_val=1)
```
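A quick standalone sanity check, sketched under the assumption that `ConvBlock(..., kernel_size=1, padding=0, use_bias=True, activation=None, batch_norm=False)` behaves like a plain 1x1 `nn.Conv2d`; the class name `GuideCurvesSketch` is hypothetical. Each channel passes through the curve `sum_k slopes[k] * relu(x - shifts[k])`. With the initial `slopes = [1, 0, ..., 0]` and `shifts[0] = 0` this reduces to `relu(x)`, the identity on `[0, 1]`, so at initialisation the guide is just a clamped affine projection of the colour-corrected input.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GuideCurvesSketch(nn.Module):
    """Hypothetical standalone variant of GuideCurves: nn.Conv2d stands in for ConvBlock."""

    def __init__(self, npts=16, nchans=3):
        super().__init__()
        self.ccm = nn.Conv2d(nchans, nchans, kernel_size=1, bias=True)
        # Knots 0, 1/npts, ..., (npts-1)/npts, one set per channel: shape [C, 1, 1, npts].
        shifts = torch.linspace(0, 1, npts + 1)[:-1]
        self.shifts = nn.Parameter(shifts.repeat(nchans, 1, 1, 1))
        # Only the first segment has slope 1, so each curve starts as relu(x).
        slopes = torch.zeros(1, nchans, 1, 1, npts)
        slopes[..., 0] = 1.0
        self.slopes = nn.Parameter(slopes)
        self.projection = nn.Conv2d(nchans, 1, kernel_size=1, bias=True)

    def curves(self, x):
        # [B, C, H, W] -> [B, C, H, W, npts] -> sum over knots -> [B, C, H, W]
        return (self.slopes * F.relu(x.unsqueeze(4) - self.shifts)).sum(dim=4)

    def forward(self, x):
        guidemap = self.curves(self.ccm(x))
        return F.hardtanh(self.projection(guidemap), min_val=0, max_val=1)


torch.manual_seed(0)
model = GuideCurvesSketch(npts=16)
x = torch.rand(2, 3, 32, 32)

# At initialisation the curves are the identity for non-negative inputs.
print(torch.allclose(model.curves(x), x))  # True

out = model(x)
print(out.shape)  # torch.Size([2, 1, 32, 32])
print(out.min().item() >= 0.0, out.max().item() <= 1.0)  # True True
```

Starting from the identity curve means the curves only bend the response once training moves the later slopes away from zero, which keeps early training close to the plain conv guide.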
@creotiv
Owner

creotiv commented Feb 12, 2022

Curves work worse than Conv, but it's still a good addition. You can make a PR and I'll add it to the repo. Thanks!
