From 2bb175eb9d70d6ad33a8d602ba90a0ef174a78c5 Mon Sep 17 00:00:00 2001 From: Piotr Kubacki <74492669+she3r@users.noreply.github.com> Date: Sun, 20 Nov 2022 16:01:59 +0100 Subject: [PATCH] fix atol/rtol cnf & turn off dkl https://github.com/stevenygd/PointFlow/issues/23 --- regressionFlow/models/cnf.py | 7 ++-- .../networks_regression_SDD_conditional.py | 42 ++++++++++--------- 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/regressionFlow/models/cnf.py b/regressionFlow/models/cnf.py index 6a9444d..92aee96 100644 --- a/regressionFlow/models/cnf.py +++ b/regressionFlow/models/cnf.py @@ -52,6 +52,7 @@ def __init__(self, odefunc, conditional=True, T=1.0, train_T=False, regularizati self.solver_options = {} self.conditional = conditional + def forward(self, x, context=None, logpx=None, integration_times=None, reverse=False): if logpx is None: _logpx = torch.zeros(*x.shape[:-1], 1).to(x) @@ -64,9 +65,9 @@ def forward(self, x, context=None, logpx=None, integration_times=None, reverse=F assert context is not None states = (x, _logpx, context.to(x)) # atol = [self.atol] * 3 - # rtol = [self.rtol] * 3 - atol = [self.atol] * 3 - rtol = [self.rtol] * 3 + # rtol = [self.rtol] 3 + atol = self.atol + rtol = self.rtol else: states = (x, _logpx) atol = [self.atol] * 2 diff --git a/regressionFlow/models/networks_regression_SDD_conditional.py b/regressionFlow/models/networks_regression_SDD_conditional.py index 285b737..ef33740 100644 --- a/regressionFlow/models/networks_regression_SDD_conditional.py +++ b/regressionFlow/models/networks_regression_SDD_conditional.py @@ -148,26 +148,28 @@ def forward(self, x: torch.tensor): # x.shape = 5,65 = 5,64 + bias # 3) przerzuc przez flow -> w_i := F_{\theta}(z_i) tu z: embedding rozmiaru 100 target_networks_weights = self.point_cnf(y.reshape(1,-1), z.reshape(1,-1), reverse=True).view(*y.size()) - # Liczenie DKL. Pytanie- na spotkaniu co rozumielismy przez batche? W konktescie TN to batchem - # oznaczymy chyba zbior wag na kanaly wyjscia, ktorych jest piec, kazdy z nich bierze 64 wartosci + bias - y2, delta_log_py = self.point_cnf(target_networks_weights, z, - torch.zeros(batch_size, y.size(1), 1).to(y)) - - log_py = standard_normal_logprob(y2).view(batch_size, -1).sum(1, keepdim=True) - delta_log_py = delta_log_py.view(batch_size, y2.size(1), 1).sum(1) - log_px = log_py - delta_log_py - # policzyc gestosci flowa log p_0(F^{-1}_\theta(w_i) + J - loss = log_px.mean() - # policzyc gestosci priora log N(w_i | (0,I)) ale teraz inaczej niz ostatnio. Liczymy kazdy z pieciu wierszy osobno - # na rozkladzie normalnym wymiaru 65 (64 + bias poki co) - multivariate_normal_distrib = torch.distributions.MultivariateNormal( - torch.zeros(tn_num_values).to(loss), torch.eye(tn_num_values).to(loss)) - # todo sprawdzic czy to zadziala, tzn czy cala macierz mozna tak wrzucic, - # mamy odwzorowanie R^65 -> R na macierz R^5 X R^65 - # tutaj biore srednia po 5 zestawach N_65 - loss_density = multivariate_normal_distrib.log_prob(target_networks_weights).mean() - loss = loss - loss_density - + # ------- LOSS ---------- + # # Liczenie DKL. Pytanie- na spotkaniu co rozumielismy przez batche? W konktescie TN to batchem + # # oznaczymy chyba zbior wag na kanaly wyjscia, ktorych jest piec, kazdy z nich bierze 64 wartosci + bias + # y2, delta_log_py = self.point_cnf(target_networks_weights, z, + # torch.zeros(batch_size, y.size(1), 1).to(y)) + # + # log_py = standard_normal_logprob(y2).view(batch_size, -1).sum(1, keepdim=True) + # delta_log_py = delta_log_py.view(batch_size, y2.size(1), 1).sum(1) + # log_px = log_py - delta_log_py + # # policzyc gestosci flowa log p_0(F^{-1}_\theta(w_i) + J + # loss = log_px.mean() + # # policzyc gestosci priora log N(w_i | (0,I)) ale teraz inaczej niz ostatnio. Liczymy kazdy z pieciu wierszy osobno + # # na rozkladzie normalnym wymiaru 65 (64 + bias poki co) + # multivariate_normal_distrib = torch.distributions.MultivariateNormal( + # torch.zeros(tn_num_values).to(loss), torch.eye(tn_num_values).to(loss)) + # # todo sprawdzic czy to zadziala, tzn czy cala macierz mozna tak wrzucic, + # # mamy odwzorowanie R^65 -> R na macierz R^5 X R^65 + # # tutaj biore srednia po 5 zestawach N_65 + # loss_density = multivariate_normal_distrib.log_prob(target_networks_weights).mean() + # loss = loss - loss_density + loss = torch.tensor([0]) + target_networks_weights = target_networks_weights.reshape(5,65) return target_networks_weights, loss