Skip to content

Commit

Permalink
minor edits;
Browse files Browse the repository at this point in the history
planning to try a new formulation with QR nullspace projection now
  • Loading branch information
enzbus committed Sep 14, 2024
1 parent a7b5c2b commit 2eb3d93
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 15 deletions.
2 changes: 1 addition & 1 deletion project_euromir/equilibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _cones_separation_matrix(zero, nonneg, second_order):

def hsde_ruiz_equilibration( # pylint: disable=too-many-arguments
matrix, b, c, dimensions, d=None, e=None, rho=1., sigma=1.,
eps_rows=1E-4, eps_cols=1E-4, max_iters=25):
eps_rows=1E-1, eps_cols=1E-1, max_iters=25):
"""Ruiz equilibration of problem matrices for the HSDE system.
:param matrix: Problem matrix.
Expand Down
73 changes: 59 additions & 14 deletions project_euromir/solver_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from project_euromir.direction_calculator import (
CGNewton, DenseNewton, DiagPreconditionedCGNewton,
ExactDiagPreconditionedCGNewton, LSMRLevenbergMarquardt,
LSQRLevenbergMarquardt, MinResQLPTest, WarmStartedCGNewton,
LSQRLevenbergMarquardt, MinResQLPTest, WarmStartedCGNewton, _densify,
nocedal_wright_termination)
from project_euromir.line_searcher import (BacktrackingLineSearcher,
LogSpaceLineSearcher,
Expand All @@ -42,7 +42,7 @@

logger = logging.getLogger(__name__)

QR_PRESOLVE = False
QR_PRESOLVE = True

def solve(matrix, b, c, zero, nonneg, soc=(),
# xy = None, # need to import logic for equilibration
Expand All @@ -67,15 +67,19 @@ def solve(matrix, b, c, zero, nonneg, soc=(),
equilibrate.hsde_ruiz_equilibration(
matrix, b, c, dimensions={
'zero': zero, 'nonneg': nonneg, 'second_order': soc},
max_iters=25)
max_iters=1000)#, eps_rows=1E-5, eps_cols=1E-5,)

if QR_PRESOLVE:
q, r = np.linalg.qr(
np.vstack([matrix_transf.todense(), c_transf.reshape((1, n))]))
matrix_transf = q[:-1].A
c_transf = q[-1].A1
# q, r = np.linalg.qr(
# np.vstack([matrix_transf.todense(), c_transf.reshape((1, n))]))
# matrix_transf = q[:-1].A
# c_transf = q[-1].A1
q, r = np.linalg.qr(matrix_transf.todense())
matrix_transf = q.A
c_transf = np.linalg.solve(r, c_transf)

sigma_qr = np.linalg.norm(
b_transf) / np.mean(np.linalg.norm(matrix_transf, axis=1))
b_transf) #/ np.mean(np.linalg.norm(matrix_transf, axis=1))
b_transf = b_transf/sigma_qr

workspace = create_workspace(m, n, zero)
Expand All @@ -100,9 +104,9 @@ def _local_hessian_x_nogap(x):
return hessian_x_nogap(
x, m=m, n=n, zero=zero, matrix=matrix_transf, b=b_transf)

def _local_hessian_y_nogap(x):
def _local_hessian_y_nogap(y):
return hessian_y_nogap(
x, m=m, n=n, zero=zero, matrix=matrix_transf)
y, m=m, n=n, zero=zero, matrix=matrix_transf)

def _local_residual(xy):
return residual(
Expand Down Expand Up @@ -137,7 +141,7 @@ def _local_derivative_residual(xy):
# direction_calculator = WarmStartedCGNewton(
# # warm start causes issues if null space changes b/w iterations
# hessian_function=_local_hessian,
# rtol_termination=lambda x, g: min(0.5, np.linalg.norm(g)**0.5),
# rtol_termination=lambFalseda x, g: min(0.5, np.linalg.norm(g)**0.5),
# max_cg_iters=None,
# minres=False,
# regularizer=1e-8, # it seems 1e-10 is best, but it's too sensitive to it :(
Expand Down Expand Up @@ -186,7 +190,7 @@ def _local_derivative_residual(xy):
# direction_calculator = LSMRLevenbergMarquardt(
# residual_function=_local_residual,
# derivative_residual_function=_local_derivative_residual,
# warm_start=True, # also doesn't work with warm start
# warm_start=True, # also doesn't work with warm start
# )

# direction_calculator = DenseNewton( #WarmStartedCGNewton(
Expand All @@ -201,6 +205,8 @@ def _local_derivative_residual(xy):
_start = time.time()
# extra_iters=5
# all_losses = []
# all_dirnorms = []
# all_dirnorms_times_steplen = []

for newton_iterations in range(1000):

Expand All @@ -222,16 +228,51 @@ def _local_derivative_residual(xy):
# logger.info('Converged in %d iterations.', newton_iterations)
# break

# dense_hess = _densify(_local_hessian(xy))
# dense_hessx_nogap = _densify(_local_hessian_x_nogap(xy[:n]))
# dense_hessy_nogap = _densify(_local_hessian_y_nogap(xy[n:]))
# eivals = np.linalg.eigh(dense_hess)[0]

# #diag_precond = np.diag(1./(np.diag(dense_hess)))
# #dense_hess_diag_precond = dense_hess @ diag_precond
# #eivals_diag_precond = np.linalg.eigh(dense_hess_diag_precond)[0]

# import matplotlib.pyplot as plt
# plt.plot(eivals, label='hess eivals')
# # plt.plot(eivals_diag_precond)
# plt.plot(np.diag(dense_hess), label='hess diag')
# diag_nogap = np.concatenate([np.diag(dense_hessx_nogap),np.diag(dense_hessy_nogap)])
# plt.plot(diag_nogap, label='hess diag nogap')

# gap = np.concatenate([c_transf, b_transf])
# approx_hess = np.diag(diag_nogap) + np.outer(gap, gap) + np.eye(n+m) * 1e-2
# eivals_approx = np.linalg.eigh(approx_hess)[0]
# plt.plot(eivals_approx, label='eivals approx hess')

# pinv_approx_hess = np.linalg.pinv(approx_hess)
# precond_hess = dense_hess @ pinv_approx_hess
# eivals_precond = np.linalg.eigh(precond_hess)[0]
# plt.plot(eivals_precond, label='eivals pinv precond hess')

# plt.legend()
# plt.show()
# breakpoint()

direction = direction_calculator.get_direction(
current_point=xy,
current_gradient=grad_xy)

logger.info('direction norm %.2e', np.linalg.norm(direction))
# all_dirnorms.append(np.linalg.norm(direction))
# oldxy = np.copy(xy)
# all_losses.append(loss_xy)

xy, loss_xy, grad_xy = \
line_searcher.get_next(current_point=xy,
current_loss=loss_xy,
current_gradient=grad_xy, direction=direction)

# all_losses.append(loss_xy)
# all_dirnorms_times_steplen.append(np.linalg.norm(xy-oldxy))

# import matplotlib.pyplot as plt
# iter_x = xy[:n]
Expand Down Expand Up @@ -268,8 +309,12 @@ def _local_derivative_residual(xy):
f'Solver did not converge in {newton_iterations} iterations.')

# import matplotlib.pyplot as plt
# plt.semilogy(all_losses)
# plt.semilogy(all_dirnorms)
# plt.semilogy(all_dirnorms_times_steplen)
# plt.semilogy(np.sqrt(all_losses))

# plt.show()
# breakpoint()

if loss_xy > np.finfo(float).eps:
raise NotImplementedError(
Expand Down

0 comments on commit 2eb3d93

Please sign in to comment.