diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index b519b3a0ae..7ab7893cb1 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -369,7 +369,6 @@ static std::shared_ptr createHybridGaussianFactor( static std::pair> hybridElimination(const HybridGaussianFactorGraph &factors, const Ordering &frontalKeys, - const KeyVector &continuousSeparator, const std::set &discreteSeparatorSet) { // NOTE: since we use the special JunctionTree, // only possibility is continuous conditioned on discrete. @@ -386,13 +385,18 @@ hybridElimination(const HybridGaussianFactorGraph &factors, factorGraphTree = removeEmpty(factorGraphTree); // This is the elimination method on the leaf nodes + bool someContinuousLeft = false; auto eliminate = [&](const GaussianFactorGraph &graph) -> Result { if (graph.empty()) { return {nullptr, nullptr}; } + // Expensive elimination of product factor. auto result = EliminatePreferCholesky(graph, frontalKeys); + // Record whether there any continuous variables left + someContinuousLeft |= !result.second->empty(); + return result; }; @@ -403,9 +407,9 @@ hybridElimination(const HybridGaussianFactorGraph &factors, // error for each discrete choice. Otherwise, create a HybridGaussianFactor // on the separator, taking care to correct for conditional constants. auto newFactor = - continuousSeparator.empty() - ? createDiscreteFactor(eliminationResults, discreteSeparator) - : createHybridGaussianFactor(eliminationResults, discreteSeparator); + someContinuousLeft + ? createHybridGaussianFactor(eliminationResults, discreteSeparator) + : createDiscreteFactor(eliminationResults, discreteSeparator); // Create the HybridGaussianConditional from the conditionals HybridGaussianConditional::Conditionals conditionals( @@ -514,22 +518,12 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, // Case 3: We are now in the hybrid land! KeySet frontalKeysSet(frontalKeys.begin(), frontalKeys.end()); - // Find all the keys in the set of continuous keys - // which are not in the frontal keys. This is our continuous separator. - KeyVector continuousSeparator; - auto continuousKeySet = factors.continuousKeySet(); - std::set_difference( - continuousKeySet.begin(), continuousKeySet.end(), - frontalKeysSet.begin(), frontalKeysSet.end(), - std::inserter(continuousSeparator, continuousSeparator.begin())); - - // Similarly for the discrete separator. + // Find all discrete keys. // Since we eliminate all continuous variables first, // the discrete separator will be *all* the discrete keys. std::set discreteSeparator = factors.discreteKeys(); - return hybridElimination(factors, frontalKeys, continuousSeparator, - discreteSeparator); + return hybridElimination(factors, frontalKeys, discreteSeparator); } }