diff --git a/pyomo/contrib/gdpopt/ldsda.py b/pyomo/contrib/gdpopt/ldsda.py index e15e2756afb..e4466822dcf 100644 --- a/pyomo/contrib/gdpopt/ldsda.py +++ b/pyomo/contrib/gdpopt/ldsda.py @@ -120,9 +120,7 @@ def _solve_gdp(self, model, config): add_transformed_boolean_variable_list(self.working_model_util_block) self._log_header(logger) - self.working_model_external_var_info_list = self.get_external_information( - self.working_model_util_block, config - ) + self.get_external_information(self.working_model_util_block, config) self.directions = self.get_directions(self.number_of_external_variables, config) self.best_direction = None self.current_point = config.starting_point @@ -133,7 +131,6 @@ def _solve_gdp(self, model, config): if not hasattr(self.working_model_util_block, 'BigM'): self.working_model_util_block.BigM = Suffix() - locally_optimal = False # Solve the initial point self.fix_disjunctions_with_external_var( self.working_model_util_block, self.current_point @@ -141,6 +138,7 @@ def _solve_gdp(self, model, config): _ = self._solve_rnGDP_subproblem(self.working_model, config, 'Initial point') # Main loop + locally_optimal = False while not locally_optimal: self.iteration += 1 if self.any_termination_criterion_met(config): @@ -296,15 +294,13 @@ def get_directions(self, dimension, config): directions.remove((0,) * dimension) return directions - def check_valid_neighbor(self, neighbor, external_var_info_list): + def check_valid_neighbor(self, neighbor): """Function that checks if a given neighbor is valid. Parameters ---------- neighbor : list - the neighbor - external_var_info_list : list - the list of the external variable information + the neighbor to be checked Returns ------- @@ -317,7 +313,7 @@ def check_valid_neighbor(self, neighbor, external_var_info_list): external_var_value >= external_var_info.LB and external_var_value <= external_var_info.UB for external_var_value, external_var_info in zip( - neighbor, external_var_info_list + neighbor, self.working_model_util_block.external_var_info_list ) ): return True @@ -340,9 +336,7 @@ def neighbor_search(self, model, config): self.best_direction = None for direction in self.directions: neighbor = tuple(map(sum, zip(self.current_point, direction))) - if self.check_valid_neighbor( - neighbor, self.working_model_util_block.external_var_info_list - ): + if self.check_valid_neighbor(neighbor): self.fix_disjunctions_with_external_var( self.working_model_util_block, neighbor ) @@ -370,16 +364,17 @@ def line_search(self, model, config): primal_improved = True while primal_improved: next_point = tuple(map(sum, zip(self.current_point, self.best_direction))) - if not self.check_valid_neighbor( - next_point, self.working_model_util_block.external_var_info_list - ): + if self.check_valid_neighbor(next_point): + self.fix_disjunctions_with_external_var( + self.working_model_util_block, next_point + ) + primal_improved = self._solve_rnGDP_subproblem( + model, config, 'Line search' + ) + if primal_improved: + self.current_point = next_point + else: break - self.fix_disjunctions_with_external_var( - self.working_model_util_block, next_point - ) - primal_improved = self._solve_rnGDP_subproblem(model, config, 'Line search') - if primal_improved: - self.current_point = next_point print("Line search finished.") def handle_subproblem_result(