Skip to content

Commit

Permalink
fix atol/rtol cnf & turn off dkl
Browse files Browse the repository at this point in the history
  • Loading branch information
she3r committed Nov 20, 2022
1 parent 77308d9 commit 2bb175e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 23 deletions.
7 changes: 4 additions & 3 deletions regressionFlow/models/cnf.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(self, odefunc, conditional=True, T=1.0, train_T=False, regularizati
self.solver_options = {}
self.conditional = conditional


def forward(self, x, context=None, logpx=None, integration_times=None, reverse=False):
if logpx is None:
_logpx = torch.zeros(*x.shape[:-1], 1).to(x)
Expand All @@ -64,9 +65,9 @@ def forward(self, x, context=None, logpx=None, integration_times=None, reverse=F
assert context is not None
states = (x, _logpx, context.to(x))
# atol = [self.atol] * 3
# rtol = [self.rtol] * 3
atol = [self.atol] * 3
rtol = [self.rtol] * 3
# rtol = [self.rtol] * 3
atol = self.atol
rtol = self.rtol
else:
states = (x, _logpx)
atol = [self.atol] * 2
Expand Down
42 changes: 22 additions & 20 deletions regressionFlow/models/networks_regression_SDD_conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,26 +148,28 @@ def forward(self, x: torch.tensor): # x.shape = 5,65 = 5,64 + bias
# 3) przerzuc przez flow -> w_i := F_{\theta}(z_i) tu z: embedding rozmiaru 100
target_networks_weights = self.point_cnf(y.reshape(1,-1), z.reshape(1,-1), reverse=True).view(*y.size())

# Liczenie DKL. Pytanie- na spotkaniu co rozumielismy przez batche? W konktescie TN to batchem
# oznaczymy chyba zbior wag na kanaly wyjscia, ktorych jest piec, kazdy z nich bierze 64 wartosci + bias
y2, delta_log_py = self.point_cnf(target_networks_weights, z,
torch.zeros(batch_size, y.size(1), 1).to(y))

log_py = standard_normal_logprob(y2).view(batch_size, -1).sum(1, keepdim=True)
delta_log_py = delta_log_py.view(batch_size, y2.size(1), 1).sum(1)
log_px = log_py - delta_log_py
# policzyc gestosci flowa log p_0(F^{-1}_\theta(w_i) + J
loss = log_px.mean()
# policzyc gestosci priora log N(w_i | (0,I)) ale teraz inaczej niz ostatnio. Liczymy kazdy z pieciu wierszy osobno
# na rozkladzie normalnym wymiaru 65 (64 + bias poki co)
multivariate_normal_distrib = torch.distributions.MultivariateNormal(
torch.zeros(tn_num_values).to(loss), torch.eye(tn_num_values).to(loss))
# todo sprawdzic czy to zadziala, tzn czy cala macierz mozna tak wrzucic,
# mamy odwzorowanie R^65 -> R na macierz R^5 X R^65
# tutaj biore srednia po 5 zestawach N_65
loss_density = multivariate_normal_distrib.log_prob(target_networks_weights).mean()
loss = loss - loss_density

# ------- LOSS ----------
# # Liczenie DKL. Pytanie- na spotkaniu co rozumielismy przez batche? W konktescie TN to batchem
# # oznaczymy chyba zbior wag na kanaly wyjscia, ktorych jest piec, kazdy z nich bierze 64 wartosci + bias
# y2, delta_log_py = self.point_cnf(target_networks_weights, z,
# torch.zeros(batch_size, y.size(1), 1).to(y))
#
# log_py = standard_normal_logprob(y2).view(batch_size, -1).sum(1, keepdim=True)
# delta_log_py = delta_log_py.view(batch_size, y2.size(1), 1).sum(1)
# log_px = log_py - delta_log_py
# # policzyc gestosci flowa log p_0(F^{-1}_\theta(w_i) + J
# loss = log_px.mean()
# # policzyc gestosci priora log N(w_i | (0,I)) ale teraz inaczej niz ostatnio. Liczymy kazdy z pieciu wierszy osobno
# # na rozkladzie normalnym wymiaru 65 (64 + bias poki co)
# multivariate_normal_distrib = torch.distributions.MultivariateNormal(
# torch.zeros(tn_num_values).to(loss), torch.eye(tn_num_values).to(loss))
# # todo sprawdzic czy to zadziala, tzn czy cala macierz mozna tak wrzucic,
# # mamy odwzorowanie R^65 -> R na macierz R^5 X R^65
# # tutaj biore srednia po 5 zestawach N_65
# loss_density = multivariate_normal_distrib.log_prob(target_networks_weights).mean()
# loss = loss - loss_density
loss = torch.tensor([0])
target_networks_weights = target_networks_weights.reshape(5,65)
return target_networks_weights, loss


Expand Down

0 comments on commit 2bb175e

Please sign in to comment.