Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert committed Sep 29, 2024
1 parent 3d8603b commit 2cf2100
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 36 deletions.
32 changes: 14 additions & 18 deletions gtsam/hybrid/HybridGaussianConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ struct HybridGaussianConditional::Helper {
std::vector<GC::shared_ptr> gcs;
fvs.reserve(p.size());
gcs.reserve(p.size());
for (const auto &[mean, sigma] : p) {
for (auto &&[mean, sigma] : p) {
auto gaussianConditional =
GC::sharedMeanAndStddev(std::forward<Args>(args)..., mean, sigma);
double value = gaussianConditional->negLogConstant();
Expand Down Expand Up @@ -96,38 +96,34 @@ HybridGaussianConditional::HybridGaussianConditional(
conditionals_(helper.conditionals),
negLogConstant_(helper.minNegLogConstant) {}

/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKey &mode,
const DiscreteKey &discreteParent,
const std::vector<GaussianConditional::shared_ptr> &conditionals)
: HybridGaussianConditional(DiscreteKeys{mode},
Conditionals({mode}, conditionals)) {}
: HybridGaussianConditional(DiscreteKeys{discreteParent},
Conditionals({discreteParent}, conditionals)) {}

/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKey &mode, Key key, //
const DiscreteKey &discreteParent, Key key, //
const std::vector<std::pair<Vector, double>> &parameters)
: HybridGaussianConditional(DiscreteKeys{mode},
Helper(mode, parameters, key)) {}
: HybridGaussianConditional(DiscreteKeys{discreteParent},
Helper(discreteParent, parameters, key)) {}

/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKey &mode, Key key, //
const DiscreteKey &discreteParent, Key key, //
const Matrix &A, Key parent,
const std::vector<std::pair<Vector, double>> &parameters)
: HybridGaussianConditional(DiscreteKeys{mode},
Helper(mode, parameters, key, A, parent)) {}
: HybridGaussianConditional(
DiscreteKeys{discreteParent},
Helper(discreteParent, parameters, key, A, parent)) {}

/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKey &mode, Key key, //
const DiscreteKey &discreteParent, Key key, //
const Matrix &A1, Key parent1, const Matrix &A2, Key parent2,
const std::vector<std::pair<Vector, double>> &parameters)
: HybridGaussianConditional(
DiscreteKeys{mode},
Helper(mode, parameters, key, A1, parent1, A2, parent2)) {}
DiscreteKeys{discreteParent},
Helper(discreteParent, parameters, key, A1, parent1, A2, parent2)) {}

/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals)
Expand Down
16 changes: 8 additions & 8 deletions gtsam/hybrid/HybridGaussianConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,45 +79,45 @@ class GTSAM_EXPORT HybridGaussianConditional
/**
* @brief Construct from one discrete key and vector of conditionals.
*
* @param mode Single discrete parent variable
* @param discreteParent Single discrete parent variable
* @param conditionals Vector of conditionals with the same size as the
* cardinality of the discrete parent.
*/
HybridGaussianConditional(
const DiscreteKey &mode,
const DiscreteKey &discreteParent,
const std::vector<GaussianConditional::shared_ptr> &conditionals);

/**
* @brief Constructs a HybridGaussianConditional with means mu_i and
* standard deviations sigma_i.
*
* @param mode The discrete mode key.
* @param discreteParent The discrete parent or "mode" key.
* @param key The key for this conditional variable.
* @param parameters A vector of pairs (mu_i, sigma_i).
*/
HybridGaussianConditional(
const DiscreteKey &mode, Key key,
const DiscreteKey &discreteParent, Key key,
const std::vector<std::pair<Vector, double>> &parameters);

/**
* @brief Constructs a HybridGaussianConditional with conditional means
* A × parent + b_i and standard deviations sigma_i.
*
* @param mode The discrete mode key.
* @param discreteParent The discrete parent or "mode" key.
* @param key The key for this conditional variable.
* @param A The matrix A.
* @param parent The key of the parent variable.
* @param parameters A vector of pairs (b_i, sigma_i).
*/
HybridGaussianConditional(
const DiscreteKey &mode, Key key, const Matrix &A, Key parent,
const DiscreteKey &discreteParent, Key key, const Matrix &A, Key parent,
const std::vector<std::pair<Vector, double>> &parameters);

/**
* @brief Constructs a HybridGaussianConditional with conditional means
* A1 × parent1 + A2 × parent2 + b_i and standard deviations sigma_i.
*
* @param mode The discrete mode key.
* @param discreteParent The discrete parent or "mode" key.
* @param key The key for this conditional variable.
* @param A1 The first matrix.
* @param parent1 The key of the first parent variable.
Expand All @@ -126,7 +126,7 @@ class GTSAM_EXPORT HybridGaussianConditional
* @param parameters A vector of pairs (b_i, sigma_i).
*/
HybridGaussianConditional(
const DiscreteKey &mode, Key key, //
const DiscreteKey &discreteParent, Key key, //
const Matrix &A1, Key parent1, const Matrix &A2, Key parent2,
const std::vector<std::pair<Vector, double>> &parameters);

Expand Down
22 changes: 12 additions & 10 deletions gtsam/hybrid/tests/testGaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,16 @@ TEST(GaussianMixture, GaussianMixtureModel) {
double mu0 = 1.0, mu1 = 3.0;
double sigma = 2.0;

HybridBayesNet hbn;
// Create a Gaussian mixture model p(z|m) with same sigma.
HybridBayesNet gmm;
std::vector<std::pair<Vector, double>> parameters{{Vector1(mu0), sigma},
{Vector1(mu1), sigma}};
hbn.emplace_shared<HybridGaussianConditional>(m, Z(0), parameters);
hbn.push_back(mixing);
gmm.emplace_shared<HybridGaussianConditional>(m, Z(0), parameters);
gmm.push_back(mixing);

// At the halfway point between the means, we should get P(m|z)=0.5
double midway = mu1 - mu0;
auto pMid = SolveHBN(hbn, midway);
auto pMid = SolveHBN(gmm, midway);
EXPECT(assert_equal(DiscreteConditional(m, "60/40"), pMid));

// Everywhere else, the result should be a sigmoid.
Expand All @@ -96,7 +97,7 @@ TEST(GaussianMixture, GaussianMixtureModel) {
const double expected = prob_m_z(mu0, mu1, sigma, sigma, z);

// Workflow 1: convert HBN to HFG and solve
auto posterior1 = SolveHBN(hbn, z);
auto posterior1 = SolveHBN(gmm, z);
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);

// Workflow 2: directly specify HFG and solve
Expand All @@ -117,16 +118,17 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
double mu0 = 1.0, mu1 = 3.0;
double sigma0 = 8.0, sigma1 = 4.0;

HybridBayesNet hbn;
// Create a Gaussian mixture model p(z|m) with same sigma.
HybridBayesNet gmm;
std::vector<std::pair<Vector, double>> parameters{{Vector1(mu0), sigma0},
{Vector1(mu1), sigma1}};
hbn.emplace_shared<HybridGaussianConditional>(m, Z(0), parameters);
hbn.push_back(mixing);
gmm.emplace_shared<HybridGaussianConditional>(m, Z(0), parameters);
gmm.push_back(mixing);

// We get zMax=3.1333 by finding the maximum value of the function, at which
// point the mode m==1 is about twice as probable as m==0.
double zMax = 3.133;
auto pMax = SolveHBN(hbn, zMax);
auto pMax = SolveHBN(gmm, zMax);
EXPECT(assert_equal(DiscreteConditional(m, "42/58"), pMax, 1e-4));

// Everywhere else, the result should be a bell curve like function.
Expand All @@ -135,7 +137,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
const double expected = prob_m_z(mu0, mu1, sigma0, sigma1, z);

// Workflow 1: convert HBN to HFG and solve
auto posterior1 = SolveHBN(hbn, z);
auto posterior1 = SolveHBN(gmm, z);
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);

// Workflow 2: directly specify HFG and solve
Expand Down

0 comments on commit 2cf2100

Please sign in to comment.